diff --git a/NNsegmentation/data.py b/NNsegmentation/data.py index ac0e6e5..ea161a7 100644 --- a/NNsegmentation/data.py +++ b/NNsegmentation/data.py @@ -443,8 +443,8 @@ def predict_whole_map(file_path, outputs = [] for r in range(rows - 1): for c in range(columns - 1): - patch_inp = inp[..., - (x_offset + r*x_size):(x_offset + (r+1)*x_size), + patch_inp = inp[..., + (x_offset + r*x_size):(x_offset + (r+1)*x_size), (y_offset + c*y_size):(y_offset + (c+1)*y_size)] if time_slices == 1: patch_inp = patch_inp[0] diff --git a/NNsegmentation/layers.py b/NNsegmentation/layers.py index a56e6c3..cb77dff 100644 --- a/NNsegmentation/layers.py +++ b/NNsegmentation/layers.py @@ -92,7 +92,7 @@ class weighted_binary_cross_entropy(object): def __init__(self, n_classes=2): self.n_classes = n_classes self.__name__ = "weighted_binary_cross_entropy" - + def __call__(self, y_true, y_pred): """ Args: @@ -101,16 +101,16 @@ def __call__(self, y_true, y_pred): last slice of the last dimension is weight y_pred (tensor): in shape (batch_size, x_size, y_size, n_classes) model predictions - + """ w = y_true[:, -1] y_true = y_true[:, :-1] - + # Switch to channel last form y_true = keras.backend.permute_dimensions(y_true, (0, 2, 3, 1)) y_pred = keras.backend.permute_dimensions(y_pred, (0, 2, 3, 1)) - + loss = keras.backend.categorical_crossentropy(y_true, y_pred, from_logits=True) * w return loss @@ -120,7 +120,7 @@ class ValidMetrics(keras.callbacks.Callback): Calculate ROC-AUC and F1 on validation data and test data (if applicable) after each epoch - + """ def __init__(self, valid_data=None, test_data=None): @@ -141,5 +141,5 @@ def on_epoch_end(self, epoch, logs={}): f1 = f1_score(y_true.flatten(), y_pred.flatten()>0.5) print('\r test-roc-auc: %f test-f1: %f\n' % (roc, f1)) return - + diff --git a/NNsegmentation/models.py b/NNsegmentation/models.py index 5de43c8..68f950d 100644 --- a/NNsegmentation/models.py +++ b/NNsegmentation/models.py @@ -68,7 +68,7 @@ def __init__(self, self.valid_score_callback = ValidMetrics() self.loss_func = weighted_binary_cross_entropy(n_classes=self.n_classes) self.build_model() - + def build_model(self): """ Define model structure and compile """ @@ -77,7 +77,7 @@ def build_model(self): self.pre_conv = keras.layers.Conv2D(3, (1, 1), activation=None, name='pre_conv')(self.input) self.unet = segmentation_models.Unet( - backbone_name='resnet34', + backbone_name='resnet34', input_shape=(3, self.x_size, self.y_size), classes=self.n_classes, activation='linear', @@ -86,11 +86,11 @@ def build_model(self): decoder_block_type='upsampling', decoder_filters=(256, 128, 64, 32, 16), decoder_use_batchnorm=True) - + output = self.unet(self.pre_conv) - + self.model = keras.models.Model(self.input, output) - self.model.compile(optimizer='Adam', + self.model.compile(optimizer='Adam', loss=self.loss_func, metrics=[]) @@ -127,14 +127,14 @@ def fit(self, os.mkdir(self.model_path) # `X` and `y` should originally be 5 dimensional: (batch, c, z, x, y), # in default model z=1 will be neglected - X, y = preprocess(patches, - n_classes=self.n_classes, - label_input=label_input, + X, y = preprocess(patches, + n_classes=self.n_classes, + label_input=label_input, class_weights=class_weights) X = X.reshape(self.batch_input_shape) y = y.reshape(self.batch_label_shape) assert X.shape[0] == y.shape[0] - + validation_data = None if valid_patches is not None: valid_X, valid_y = preprocess(valid_patches, @@ -145,7 +145,7 @@ def fit(self, assert valid_X.shape[0] == valid_y.shape[0] self.valid_score_callback.valid_data = (valid_X, valid_y) validation_data = (valid_X, valid_y) - + self.model.fit(x=X, y=y, batch_size=batch_size, @@ -212,13 +212,13 @@ def __init__(self, """ Define model Args: - unet_feat (int, optional): output dimension of unet (used as + unet_feat (int, optional): output dimension of unet (used as hidden units) **kwargs: keyword arguments for `Segment` note that `input_shape` should have 4 dimensions """ - + self.unet_feat = unet_feat super(SegmentWithMultipleSlice, self).__init__(**kwargs) self.n_slices = self.input_shape[1] # Input shape (c, z, x, y) @@ -233,9 +233,9 @@ def build_model(self): # Combine time slice dimension and batch dimension inp = SplitSlice(self.n_channels, self.x_size, self.y_size)(self.input) self.pre_conv = keras.layers.Conv2D(3, (1, 1), activation=None, name='pre_conv')(inp) - + self.unet = segmentation_models.Unet( - backbone_name='resnet34', + backbone_name='resnet34', input_shape=(3, self.x_size, self.y_size), classes=self.unet_feat, activation='linear', @@ -251,7 +251,7 @@ def build_model(self): output = MergeSlices(self.n_slices, self.unet_feat)(output) output = keras.layers.Conv2D(self.unet_feat, (1, 1), activation='relu', name='post_conv')(output) output = keras.layers.Conv2D(self.n_classes, (1, 1), activation=None, name='pred_head')(output) - + self.model = keras.models.Model(self.input, output) self.model.compile(optimizer='Adam', loss=self.loss_func, diff --git a/SingleCellPatch/extract_patches.py b/SingleCellPatch/extract_patches.py index 93f745d..164b2d4 100755 --- a/SingleCellPatch/extract_patches.py +++ b/SingleCellPatch/extract_patches.py @@ -189,16 +189,20 @@ def process_site_extract_patches(site_path, # Load data image_stack = np.load(site_path) if channels is None: - channels = list(range(len(image_stack))) - image_stack = image_stack[channels] + channels = list(range(len(image_stack[0]))) + image_stack = image_stack[:, channels] segmentation_stack = np.load(site_segmentation_path) with open(os.path.join(site_supp_files_folder, 'cell_positions.pkl'), 'rb') as f: cell_positions = pickle.load(f) with open(os.path.join(site_supp_files_folder, 'cell_pixel_assignments.pkl'), 'rb') as f: cell_pixel_assignments = pickle.load(f) - n_frames, n_channels, n_z, x_full_size, y_full_size = image_stack.shape - for t_point in range(n_frames): + + # if the number of timepoints between images and predictions mismatch, choose the smaller one + endpt = image_stack.shape[0] if image_stack.shape[0] < segmentation_stack.shape[0] else segmentation_stack.shape[0] + print(f"total timepoints to process = {endpt}") + for t_point in range(endpt): + # for t_point in range(n_frames): print(f"processing timepoint {t_point}") stack_dat_path = os.path.join(site_supp_files_folder, 'stacks_%d.pkl' % t_point) if reload and os.path.exists(stack_dat_path): diff --git a/SingleCellPatch/instance_clustering.py b/SingleCellPatch/instance_clustering.py index f570b09..424d2a5 100644 --- a/SingleCellPatch/instance_clustering.py +++ b/SingleCellPatch/instance_clustering.py @@ -44,6 +44,11 @@ def check_segmentation_dim(segmentation): segmentation: (np.array): segmentation mask for the frame """ + # reshape if numpy file is too large, assume index=0 is redundant + if len(segmentation.shape) > 4: + shp = segmentation.shape + # skip index 1 which is blank + segmentation = segmentation.reshape((shp[1], shp[2], shp[3], shp[4])) assert len(segmentation.shape) == 4, "Semantic segmentation should be formatted with dimension (c, z, x, y)" n_channels, _, _, _ = segmentation.shape @@ -60,7 +65,8 @@ def instance_clustering(cell_segmentation, instance_map=True, map_path=None, fg_thr=0.3, - DBSCAN_thr=(10, 250)): + dbscan_thr=(10, 250), + channel=0): """ Perform instance clustering on a static frame Args: @@ -75,7 +81,7 @@ def instance_clustering(cell_segmentation, fg_thr (float, optional): threshold of foreground, any pixel with predicted background prob less than this value would be regarded as foreground (MG or Non-MG) - DBSCAN_thr (tuple, optional): parameters for DBSCAN, (eps, min_samples) + dbscan_thr (tuple, optional): parameters for DBSCAN, (eps, min_samples) Returns: (list * 3): 3 lists (MG, Non-MG, intermediate) of cell identifiers @@ -85,14 +91,14 @@ def instance_clustering(cell_segmentation, """ cell_segmentation = check_segmentation_dim(cell_segmentation) - all_cells = np.mean(cell_segmentation[0], axis=0) < fg_thr + all_cells = np.mean(cell_segmentation[channel], axis=0) < fg_thr positions = np.array(list(zip(*np.where(all_cells)))) if len(positions) < 1000: # No cell detected return [], np.zeros((0, 2), dtype=int), np.zeros((0,), dtype=int) # DBSCAN clustering of cell pixels - clustering = DBSCAN(eps=DBSCAN_thr[0], min_samples=DBSCAN_thr[1]).fit(positions) + clustering = DBSCAN(eps=dbscan_thr[0], min_samples=dbscan_thr[1]).fit(positions) positions_labels = clustering.labels_ cell_ids, point_cts = np.unique(positions_labels, return_counts=True) @@ -140,7 +146,7 @@ def instance_clustering(cell_segmentation, def process_site_instance_segmentation(raw_data, raw_data_segmented, site_supp_files_folder, - **kwargs): + config_): """ Wrapper method for instance segmentation @@ -153,27 +159,49 @@ def process_site_instance_segmentation(raw_data, :param raw_data: (str) path to image stack (.npy) :param raw_data_segmented: (str) path to semantic segmentation stack (.npy) :param site_supp_files_folder: (str) path to the folder where supplementary files will be saved - :param kwargs: + :param config_: config file parameters :return: """ + ct_thr = (config_.patch.count_threshold_low, config_.patch.count_threshold_high) + fg_thr = config_.patch.foreground_threshold + DBSCAN_thr = (config_.patch.dbscan_eps, config_.patch.dbscan_min_samples) + channel = config_.patch.channel + # TODO: Size is hardcoded here # Should be of size (n_frame, n_channels, z(1), x(2048), y(2048)), uint16 print(f"\tLoading {raw_data}") image_stack = np.load(raw_data) + print(f"\traw_data has shape {image_stack.shape}") # Should be of size (n_frame, n_classes, z(1), x(2048), y(2048)), float print(f"\tLoading {raw_data_segmented}") segmentation_stack = np.load(raw_data_segmented) + print(f"\tsegmentation stack has shape {segmentation_stack.shape}") + + # reshape if numpy file is too large, assume index=1 is redundant + if len(segmentation_stack.shape) > 4: + shp = segmentation_stack.shape + # skip index 1 which is blank + segmentation_stack = segmentation_stack.reshape((shp[0], shp[2], shp[3], shp[4], shp[5])) cell_positions = {} cell_pixel_assignments = {} - for t_point in range(image_stack.shape[0]): + + # if the number of timepoints between images and predictions mismatch, choose the smaller one + endpt = image_stack.shape[0] if image_stack.shape[0] < segmentation_stack.shape[0] else segmentation_stack.shape[0] + for t_point in range(endpt): print("\tClustering time %d" % t_point) cell_segmentation = segmentation_stack[t_point] instance_map_path = os.path.join(site_supp_files_folder, 'segmentation_%d.png' % t_point) - #TODO: expose instance clustering parameters in config - res = instance_clustering(cell_segmentation, instance_map=True, map_path=instance_map_path) - cell_positions[t_point] = res[0] # List of cell: (cell_id, mean_pos) + res = instance_clustering(cell_segmentation, + instance_map=True, + map_path=instance_map_path, + ct_thr=ct_thr, + fg_thr=fg_thr, + dbscan_thr=DBSCAN_thr, + channel=channel) + print(f"\tfound {len(res[0])} cells for timepoint {t_point}") + cell_positions[t_point] = res[0] # List of cell: (cell_id, mean_pos) cell_pixel_assignments[t_point] = res[1:] with open(os.path.join(site_supp_files_folder, 'cell_positions.pkl'), 'wb') as f: pickle.dump(cell_positions, f) diff --git a/configs/config_example.yml b/configs/config_example.yml index 2197770..32150a2 100644 --- a/configs/config_example.yml +++ b/configs/config_example.yml @@ -84,6 +84,18 @@ patch: # True to skip patches whose edges exceed the image boundaries # False to pad patches with mean background values + # instance segmentation parameters + channel: 0 + # segmentation output can produce multiple channels. Select which one to cluster on + foreground_threshold: 0.6 + # channel values under foreground_threshold will be considered + dbscan_eps: 10 + dbscan_min_samples: 250 + count_threshold_low: 100 + count_threshold_high: 12000 + + + latent_encoding: raw_dirs: [ diff --git a/configs/config_reader.py b/configs/config_reader.py index 0c35155..db32e6d 100644 --- a/configs/config_reader.py +++ b/configs/config_reader.py @@ -60,7 +60,15 @@ 'window_size', 'save_fig', 'reload', - 'skip_boundary' + 'skip_boundary', + + 'overwrite', + 'channel', + 'count_threshold_low', + 'count_threshold_high', + 'foreground_threshold', + 'dbscan_eps', + 'dbscan_min_samples' } # change this to "latent encoding" or similar diff --git a/pipeline/patch_VAE.py b/pipeline/patch_VAE.py index 33d917c..03df005 100644 --- a/pipeline/patch_VAE.py +++ b/pipeline/patch_VAE.py @@ -6,11 +6,15 @@ import matplotlib.pyplot as plt import importlib import inspect +import logging +log = logging.getLogger(__name__) + from configs.config_reader import YamlReader from torch.utils.data import TensorDataset from SingleCellPatch.extract_patches import process_site_extract_patches, im_adjust from SingleCellPatch.generate_trajectories import process_site_build_trajectory, process_well_generate_trajectory_relations +from SingleCellPatch.instance_clustering import process_site_instance_segmentation from pipeline.train_utils import zscore, zscore_patch import HiddenStateExtractor.vae as vae @@ -19,6 +23,66 @@ NETWORK_MODULE = 'run_training' + +def instance_segmentation(raw_folder: str, + supp_folder: str, + # val_folder: str, + sites: list, + config_: YamlReader, + rerun=False, + + **kwargs): + """ Helper function for instance segmentation + + Wrapper method `process_site_instance_segmentation` will be called, which + loads "*_NNProbabilities.npy" files and performs instance segmentation. + + Results will be saved in the supplementary data folder, including: + "cell_positions.pkl": dict of cells in each frame (IDs and positions); + "cell_pixel_assignments.pkl": dict of pixel compositions of cells + in each frame; + "segmentation_*.png": image of instance segmentation results. + + Args: + raw_folder (str): folder for raw data, segmentation and summarized results + supp_folder (str): folder for supplementary data + sites (list of str): list of site names + config_ (YamlReader): + rerun: + + """ + + for site in sites: + site_path = os.path.join(raw_folder, '%s.npy' % site) + site_segmentation_path = os.path.join(raw_folder, + '%s_NNProbabilities.npy' % site) + if not os.path.exists(site_path) or not os.path.exists(site_segmentation_path): + log.info("Site not found %s" % site_path) + continue + + log.info("Clustering %s" % site_path) + site_supp_files_folder = os.path.join(supp_folder, + '%s-supps' % site[:2], + '%s' % site) + + if os.path.exists(os.path.join(site_supp_files_folder, 'cell_pixel_assignments.pkl')) and not rerun: + log.info('Found previously saved instance clustering output in {}.' + .format(site_supp_files_folder)) + if config_.patch.overwrite: + print(f"\toverwriting ...") + else: + print(f"\tskip processing ...") + continue + elif not os.path.exists(site_supp_files_folder): + os.makedirs(site_supp_files_folder, exist_ok=True) + + process_site_instance_segmentation(site_path, + site_segmentation_path, + site_supp_files_folder, + config_) + return + + def extract_patches(raw_folder: str, supp_folder: str, # channels: list, @@ -59,6 +123,7 @@ def extract_patches(raw_folder: str, print("Site data not found %s" % site_segmentation_path, flush=True) if not os.path.exists(site_supp_files_folder): print("Site supp folder not found %s" % site_supp_files_folder, flush=True) + print(" ... did you remember to run 'instance segmentation'? ") else: print("Building patches %s" % site_path, flush=True) diff --git a/pipeline/preprocess.py b/pipeline/preprocess.py index 14e284c..ba2bb31 100644 --- a/pipeline/preprocess.py +++ b/pipeline/preprocess.py @@ -60,16 +60,27 @@ def load_raw(fullpaths: list, if not multipage: log.info(f"single-page tiffs specified") # load singlepage tiffs. String parse assuming time series and z### format + + if "RetardanceZavg" and "Retardance" in chans: + raise ValueError("only one of Retardance or RetardanceZavg can be used") + for chan in chans: # files maps (key:value) = (z_index, t_y_x array) # files = [] # for z in z_indicies: # files.append([c for c in sorted(os.listdir(fullpath)) if chan in c and f"z{z:03d}" in c]) # files = np.array(files).flatten() - files = [c for c in fullpaths if chan in c.split('/')[-1] and f"z{z_slice:03d}" in c.split('/')[-1]] - files = sorted(files) + if chan == "RetardanceZavg": + print('') + + # RetardanceZavg does not take z-slicing + if chan != "RetardanceZavg": + files = sorted([c for c in fullpaths if chan in c.split('/')[-1] and f"z{z_slice:03d}" in c.split('/')[-1]]) + else: + files = sorted([c for c in fullpaths if chan in c.split('/')[-1]]) + if not files: - log.warning(f"no files with {chan} identified") + log.warning(f"no files of any type with {chan} identified") continue # resulting shapes are in (t, y, x) order @@ -77,14 +88,18 @@ def load_raw(fullpaths: list, phase = np.stack([read_image(f) for f in files]) # phase = phase.reshape((len(z_indicies), -1, phase.shape[-2], phase.shape[-1])) shapes.append(phase.shape) - elif "Retardance" in chan: + elif "Retardance" == chan: ret = np.stack([read_image(f) for f in files]) # ret = ret.reshape((len(z_indicies), -1, ret.shape[-2], ret.shape[-1])) shapes.append(ret.shape) - elif "Brightfield" in chan: + elif "Brightfield" == chan: bf = np.stack([read_image(f) for f in files]) # bf = bf.reshape((len(z_indicies), -1, bf.shape[-2], bf.shape[-1])) shapes.append(bf.shape) + elif "RetardanceZavg" == chan: + ret = np.stack([read_image(f) for f in files]) + # ret = ret.reshape((len(z_indicies), -1, ret.shape[-2], ret.shape[-1])) + shapes.append(ret.shape) else: log.warning(f'not implemented: {chan} parse from single page files') @@ -92,8 +107,13 @@ def load_raw(fullpaths: list, log.info(f"multi-page tiffs specified") # load stabilized multipage tiffs. for chan in chans: - files = [c for c in fullpaths if chan in c.split('/')[-1] and '.tif' in c.split('/')[-1]] - files = sorted(files) + + # RetardanceZavg does not take z-slicing + if chan != "RetardanceZavg": + files = sorted([c for c in fullpaths if chan in c.split('/')[-1] and '.tif' in c.split('/')[-1]]) + else: + files = sorted([c for c in fullpaths if chan in c.split('/')[-1]]) + if not files: log.warning(f"no files with {chan} identified") continue @@ -107,18 +127,22 @@ def load_raw(fullpaths: list, flags=cv2.IMREAD_ANYDEPTH) phase = np.array(phase) shapes.append(phase.shape) - if "Retardance" in chan: + if "Retardance" == chan: # multi_tif_retard = 'img__Retardance__stabilized.tif' _, ret = cv2.imreadmulti(files[0], flags=cv2.IMREAD_ANYDEPTH) ret = np.array(ret) shapes.append(ret.shape) - if "Brightfield" in chan: + if "Brightfield" == chan: # multi_tif_bf = 'img_Brightfield_computed_stabilized.tif' _, bf = cv2.imreadmulti(files[0], flags=cv2.IMREAD_ANYDEPTH) bf = np.array(bf) shapes.append(bf.shape) + elif "RetardanceZavg" == chan: + ret = np.stack([read_image(f) for f in files]) + # ret = ret.reshape((len(z_indicies), -1, ret.shape[-2], ret.shape[-1])) + shapes.append(ret.shape) # check that all shapes are the same assert shapes.count(shapes[0]) == len(shapes) @@ -131,7 +155,7 @@ def load_raw(fullpaths: list, try: if "Phase" in chan: out[:, 0, 0] = phase - if "Retardance" in chan: + if "Retardance" == chan or "RetardanceZavg" == chan: out[:, 1, 0] = ret if "Brightfield" in chan: out[:, 2, 0] = bf diff --git a/run_patch.py b/run_patch.py index 036e271..67e0dc7 100644 --- a/run_patch.py +++ b/run_patch.py @@ -1,11 +1,13 @@ # bchhun, {2020-02-21} -from pipeline.patch_VAE import extract_patches, build_trajectories +from pipeline.patch_VAE import extract_patches, build_trajectories, instance_segmentation from multiprocessing import Pool, Queue, Process import os import numpy as np import argparse from configs.config_reader import YamlReader +import logging +log = logging.getLogger(__name__) class Worker(Process): @@ -16,7 +18,10 @@ def __init__(self, inputs, cpu_id=0, method='extract_patches'): self.method = method def run(self): - if self.method == 'extract_patches': + if self.method == 'instance_segmentation': + log.info(f"running instance segmentation") + instance_segmentation(*self.inputs) + elif self.method == 'extract_patches': extract_patches(*self.inputs) elif self.method == 'build_trajectories': build_trajectories(*self.inputs) @@ -44,6 +49,13 @@ def main(method_, raw_dir_, supp_dir_, config_): if not supp: raise AttributeError("supplementary directory must be specified when method = extract_patches") + # instance segmentation requires raw (stack, NNprob), supp (to write outputs) to be defined + elif method == 'instance_segmentation': + pass + + else: + raise AttributeError(f"method flag {method} not implemented") + if fov: sites = fov else: @@ -51,13 +63,14 @@ def main(method_, raw_dir_, supp_dir_, config_): img_names = [file for file in os.listdir(raw) if (file.endswith(".npy")) & ('_NN' not in file)] sites = [os.path.splitext(img_name)[0] for img_name in img_names] sites = list(set(sites)) + # if probabilities and formatted stack exist - segment_sites = [site for site in sites if os.path.exists(os.path.join(raw, "%s.npy" % site)) and \ + segment_sites = [site for site in sites if os.path.exists(os.path.join(raw, "%s.npy" % site)) and os.path.exists(os.path.join(raw, "%s_NNProbabilities.npy" % site))] if len(segment_sites) == 0: raise AttributeError("no sites found in raw directory with preprocessed data and matching NNProbabilities") - # process each site on a different GPU if using multi-gpu + # process each site on a different cpu if using multi-cpu sep = np.linspace(0, len(segment_sites), n_cpus + 1).astype(int) # TARGET is never used in either extract_patches or build_trajectory @@ -84,7 +97,7 @@ def parse_args(): '-m', '--method', type=str, required=False, - choices=['extract_patches', 'build_trajectories'], + choices=['extract_patches', 'build_trajectories', 'instance_segmentation'], default='extract_patches', help="Method: one of 'extract_patches', 'build_trajectories'", ) diff --git a/run_segmentation.py b/run_segmentation.py index 4715d4e..feb167a 100644 --- a/run_segmentation.py +++ b/run_segmentation.py @@ -1,6 +1,6 @@ # bchhun, {2020-02-21} -from pipeline.segmentation import segmentation, instance_segmentation +from pipeline.segmentation import segmentation # instance_segmentation from pipeline.segmentation_validation import segmentation_validation_michael from multiprocessing import Process import os @@ -26,9 +26,9 @@ def run(self): if self.method == 'segmentation': log.info(f"running segmentation worker on {self.gpuid}") segmentation(*self.inputs) - elif self.method == 'instance_segmentation': - log.info(f"running instance segmentation") - instance_segmentation(*self.inputs) + # elif self.method == 'instance_segmentation': + # log.info(f"running instance segmentation") + # instance_segmentation(*self.inputs) elif self.method == 'segmentation_validation': segmentation_validation_michael(*self.inputs) @@ -55,11 +55,11 @@ def main(method_, raw_dir_, supp_dir_, val_dir_, config_): if config_.segmentation.inference.weights is None: raise AttributeError("Weights supp_dir must be specified when method=segmentation") - # instance segmentation requires raw (stack, NNprob), supp (to write outputs) to be defined - elif method == 'instance_segmentation': - TARGET = '' - else: - raise AttributeError(f"method flag {method} not implemented") + # # instance segmentation requires raw (stack, NNprob), supp (to write outputs) to be defined + # elif method == 'instance_segmentation': + # TARGET = '' + # else: + # raise AttributeError(f"method flag {method} not implemented") # all methods all require if config_.segmentation.inference.fov: