# Hi, here is the code we use to determine the best possible F1 threshold on
# the toy dataset:

import os

import numpy as np
from scipy.ndimage import center_of_mass
from scipy.ndimage import label as find_label


def get_best_thres(
    toy_pred_dir=None,
    toy_label_dir=None,
    n_coarse_steps=20,
    max_refine_steps=4,
):
    """Return the binarization threshold that maximizes the object-level F1 score.

    Every file in ``toy_pred_dir`` is loaded with nibabel as a prediction map
    and paired with the file of the same name in ``toy_label_dir``.

    If the predictions take more than two distinct values:

    * a coarse grid of ``n_coarse_steps`` thresholds spanning the global
      [min, max] of all predictions is scored with
      :func:`get_f1_score_clean_list`; each threshold and its F1 is printed;
    * the best grid point is then refined with :func:`find_best_val` within
      one grid spacing on either side.

    With at most two distinct values the mean of those values is returned
    (e.g. 0.5 for {0, 1} maps).

    Args:
        toy_pred_dir: Directory of prediction volumes.
        toy_label_dir: Directory of label volumes with matching file names.
        n_coarse_steps: Number of thresholds on the coarse grid.
        max_refine_steps: Number of bisection steps used for refinement.

    Returns:
        The selected threshold.

    Raises:
        ValueError: If ``toy_pred_dir`` contains no files.
    """
    # Loader / progress-bar imports are local so the metric helpers below can
    # be imported and used with only numpy and scipy installed.
    import nibabel as nib
    from tqdm import tqdm

    toy_pred_list = []
    toy_label_list = []
    for fl_ in sorted(os.listdir(toy_pred_dir)):
        pred_file = os.path.join(toy_pred_dir, fl_)
        label_file = os.path.join(toy_label_dir, fl_)
        toy_pred_list.append(nib.load(pred_file).get_fdata(dtype=float))
        toy_label_list.append(nib.load(label_file).get_fdata(dtype=float))
    if not toy_pred_list:
        raise ValueError(f"No prediction files found in {toy_pred_dir!r}")

    # Only the global set of values is needed; flattening each volume means
    # they do not all have to share the same shape.
    unique_preds = np.unique(np.concatenate([pred.ravel() for pred in toy_pred_list]))
    if len(unique_preds) <= 2:
        return np.mean(unique_preds)

    lo, hi = unique_preds[0], unique_preds[-1]  # np.unique output is sorted

    # Start below any real F1 so best_thres is always an actual grid point.
    best_f1 = -1.0
    best_thres = lo
    for bin_thres in tqdm(np.linspace(lo, hi, n_coarse_steps)):
        f1s = get_f1_score_clean_list(toy_pred_list, toy_label_list, bin_thres)
        if f1s > best_f1:
            best_f1 = f1s
            best_thres = bin_thres
        print(bin_thres, f1s)

    # linspace with n points has n - 1 intervals, so this is the real spacing.
    step_size = (hi - lo) / max(n_coarse_steps - 1, 1)
    # Seed the refinement with the coarse optimum so it can only improve on it.
    _, reconst_thres = find_best_val(
        toy_pred_list,
        toy_label_list,
        get_f1_score_clean_list,
        val_range=(best_thres - step_size, best_thres + step_size),
        max_steps=max_refine_steps,
        max_val=best_f1,
        max_point=best_thres,
    )
    return reconst_thres


def get_f1_score_clean_list(anomaly_map_list, seg_objects_list, bin_thres, size_thres=600):
    """Return the object-level F1 score pooled over several volumes.

    The TP/FP/FN counts from :func:`get_f1_score` are summed over all
    (prediction, label) pairs before F1 is computed (micro-averaging), so
    every object carries the same weight whichever volume it is in.

    Returns:
        ``2*TP / (2*TP + FP + FN)``, or 0 when all three counts are zero.
    """
    tps, fps, fns = 0, 0, 0
    for pred, label in zip(anomaly_map_list, seg_objects_list):
        _, tp, fp, fn = get_f1_score(pred, label, bin_thres, size_thres=size_thres)
        tps += tp
        fps += fp
        fns += fn

    denom = 2 * tps + fps + fns
    return 2 * tps / denom if denom else 0


def get_f1_score(anomaly_map, seg_objects, bin_thres, size_thres):
    """Return object-level ``(f1, tp, fp, fn)`` for one prediction/label pair.

    ``anomaly_map > bin_thres`` is split into connected components using full
    connectivity (all 3**ndim - 1 neighbours). The non-zero part of
    ``seg_objects`` is split the same way. Predicted components with fewer
    than ``size_thres`` voxels are ignored. Each remaining component is
    reduced to its centre of mass, with coordinates truncated to ints. Note
    that the centre of a non-convex component can lie outside the component.

    * If that voxel is background in ``seg_objects``, the component is a false
      positive.
    * Otherwise it hits a ground-truth object. The component matches that
      object when its voxel count is strictly between ``gt_size // 2`` and
      ``2 * gt_size``; if not, it is a false positive.

    Every ground-truth object matched at least once is a true positive.
    Further matches of the same object are not penalised. Unmatched objects
    are false negatives.

    Returns:
        ``(f1, tp, fp, fn)``; ``f1`` is 0 when all counts are zero.
    """
    pred_thres = anomaly_map > bin_thres
    pred_labeled, n_labels = find_label(pred_thres, np.ones((3,) * pred_thres.ndim))
    seg_labeled, seg_labels = find_label(seg_objects, np.ones((3,) * np.ndim(seg_objects)))

    # Voxel count per component id (index 0 is background). This replaces the
    # per-component full-volume copies and scans.
    label_counts = np.bincount(pred_labeled.ravel())
    seg_counts = np.bincount(seg_labeled.ravel())

    kept_ids = [idx for idx in range(1, n_labels + 1) if label_counts[idx] >= size_thres]
    # One labeled pass computes the centres of all kept components at once.
    centers = (
        center_of_mass(pred_thres.astype(np.float64), pred_labeled, kept_ids)
        if kept_ids
        else []
    )

    matched = set()
    fp = 0
    for lbl_idx, center in zip(kept_ids, centers):
        coord = tuple(int(c) for c in center)  # truncation, as before
        if seg_objects[coord] == 0:
            fp += 1
            continue
        gt_id = seg_labeled[coord]
        gt_sum = seg_counts[gt_id]
        pred_sum = label_counts[lbl_idx]
        if gt_sum // 2 < pred_sum < gt_sum * 2:
            matched.add(gt_id)
        else:
            fp += 1

    tp = len(matched)
    fn = seg_labels - tp

    denom = 2 * tp + fp + fn
    f1_score = 2 * tp / denom if denom else 0
    return f1_score, tp, fp, fn


def find_best_val(x, y, val_fn, val_range=(0, 1), max_steps=4, step=0, max_val=0, max_point=0):
    """Search ``val_range`` for the argument that maximizes ``val_fn(x, y, t)``.

    This is a coarse-to-fine bisection. Each step evaluates ``val_fn`` at the
    1/4 and 3/4 points of the current interval and remembers the best value
    seen so far. The search then continues in the half that contains the
    better of the two points. The method assumes the objective is roughly
    unimodal on the interval, so it is a heuristic and not a guaranteed global
    maximum. ``(step, max_val, max_point)`` is printed at the start of every
    step and once more before returning.

    Args:
        x, y: Passed unchanged as the first two arguments of ``val_fn``.
        val_fn: Objective, called as ``val_fn(x, y, t)``.
        val_range: ``(bottom, top)`` interval to search. A degenerate interval
            ``(a, a)`` is widened to ``(a, 1)``.
        max_steps: Step index at which the search stops.
        step: Starting step index.
        max_val, max_point: Best value and argument known before the search.
            Only a strictly larger value replaces them.

    Returns:
        ``(max_val, max_point)``: the best value found and where it was found.
    """
    while True:
        print(step, max_val, max_point)
        if step >= max_steps:  # ">=" so a start beyond max_steps still stops
            return max_val, max_point

        if val_range[0] == val_range[1]:
            # NOTE(review): widening to 1 presumes a [0, 1] value domain.
            val_range = (val_range[0], 1)

        bottom, top = val_range
        width = top - bottom
        center = bottom + width * 0.5
        q_bottom = bottom + width * 0.25
        q_top = bottom + width * 0.75

        val_bottom = val_fn(x, y, q_bottom)
        val_top = val_fn(x, y, q_top)

        if val_bottom > val_top:
            if val_bottom > max_val:
                max_val, max_point = val_bottom, q_bottom
            val_range = (bottom, center)
        else:
            # The upper quarter-point won, so it is the one to record.
            if val_top > max_val:
                max_val, max_point = val_top, q_top
            val_range = (center, top)
        step += 1

# Cheers, David
Created by
David Zimmerer d.zimmerer
Drop files to upload
Your web browser must have JavaScript enabled in order for this application to display correctly.
If you are an automated web crawler from a search engine, follow this AJAX application crawl link.
Code to determine the best threshold
page is loading…