Skip to content
This repository was archived by the owner on Apr 25, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
d1629c6
scripts and modifications to train new cell type segmentations from f…
bryantChhun Aug 17, 2021
2d95acf
segmentation training using horovod. custom script for local training
bryantChhun Aug 31, 2021
1b67722
custom loss functions compatible with segmentaiton_models repo
bryantChhun Aug 31, 2021
c0b0e75
segmentation model adjustment: allow fully conv predictions, use of c…
bryantChhun Aug 31, 2021
7d366c4
input data range scaling and clipping
bryantChhun Aug 31, 2021
e82a2b3
enable RetardanceZavg in preprocessing
bryantChhun Aug 31, 2021
3de7632
added warning about patch before instance segmentation
bryantChhun Aug 31, 2021
ed150d9
moved instance segmentation methods from "segmentation" to "patch_VAE"
bryantChhun Aug 31, 2021
ae376de
moved CLI for instance segmentation from "segmentation" to "extract p…
bryantChhun Aug 31, 2021
011c6ea
modified config reader to reflect instance seg changes
bryantChhun Aug 31, 2021
31642a1
exposed instance segmentation parameters in the config file
bryantChhun Aug 31, 2021
b15f05d
added instance seg as CLI choice
bryantChhun Aug 31, 2021
ffd5f7a
file name bugfix
bryantChhun Aug 31, 2021
6ee53ab
added channel to select which NNProb channel to segment on.
bryantChhun Aug 31, 2021
d378e68
removed some horovod training scripts that will go in another branch
bryantChhun Aug 31, 2021
d95804e
revert back to master for segmentation scripts
bryantChhun Aug 31, 2021
8e59988
fix bug in channel slicing for patch extraction
bryantChhun Sep 1, 2021
d4cf519
ignore a dimension if segmentation shape is unusual
bryantChhun Sep 1, 2021
6cea870
added option to overwite instance segmentation results
bryantChhun Sep 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions NNsegmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@ def predict_whole_map(file_path,
outputs = []
for r in range(rows - 1):
for c in range(columns - 1):
patch_inp = inp[...,
(x_offset + r*x_size):(x_offset + (r+1)*x_size),
patch_inp = inp[...,
(x_offset + r*x_size):(x_offset + (r+1)*x_size),
(y_offset + c*y_size):(y_offset + (c+1)*y_size)]
if time_slices == 1:
patch_inp = patch_inp[0]
Expand Down
12 changes: 6 additions & 6 deletions NNsegmentation/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class weighted_binary_cross_entropy(object):
def __init__(self, n_classes=2):
self.n_classes = n_classes
self.__name__ = "weighted_binary_cross_entropy"

def __call__(self, y_true, y_pred):
"""
Args:
Expand All @@ -101,16 +101,16 @@ def __call__(self, y_true, y_pred):
last slice of the last dimension is weight
y_pred (tensor): in shape (batch_size, x_size, y_size, n_classes)
model predictions

"""

w = y_true[:, -1]
y_true = y_true[:, :-1]

# Switch to channel last form
y_true = keras.backend.permute_dimensions(y_true, (0, 2, 3, 1))
y_pred = keras.backend.permute_dimensions(y_pred, (0, 2, 3, 1))

loss = keras.backend.categorical_crossentropy(y_true, y_pred, from_logits=True) * w
return loss

Expand All @@ -120,7 +120,7 @@ class ValidMetrics(keras.callbacks.Callback):

Calculate ROC-AUC and F1 on validation data and test data (if applicable)
after each epoch

"""

def __init__(self, valid_data=None, test_data=None):
Expand All @@ -141,5 +141,5 @@ def on_epoch_end(self, epoch, logs={}):
f1 = f1_score(y_true.flatten(), y_pred.flatten()>0.5)
print('\r test-roc-auc: %f test-f1: %f\n' % (roc, f1))
return


30 changes: 15 additions & 15 deletions NNsegmentation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self,
self.valid_score_callback = ValidMetrics()
self.loss_func = weighted_binary_cross_entropy(n_classes=self.n_classes)
self.build_model()


def build_model(self):
""" Define model structure and compile """
Expand All @@ -77,7 +77,7 @@ def build_model(self):
self.pre_conv = keras.layers.Conv2D(3, (1, 1), activation=None, name='pre_conv')(self.input)

self.unet = segmentation_models.Unet(
backbone_name='resnet34',
backbone_name='resnet34',
input_shape=(3, self.x_size, self.y_size),
classes=self.n_classes,
activation='linear',
Expand All @@ -86,11 +86,11 @@ def build_model(self):
decoder_block_type='upsampling',
decoder_filters=(256, 128, 64, 32, 16),
decoder_use_batchnorm=True)

output = self.unet(self.pre_conv)

self.model = keras.models.Model(self.input, output)
self.model.compile(optimizer='Adam',
self.model.compile(optimizer='Adam',
loss=self.loss_func,
metrics=[])

Expand Down Expand Up @@ -127,14 +127,14 @@ def fit(self,
os.mkdir(self.model_path)
# `X` and `y` should originally be 5 dimensional: (batch, c, z, x, y),
# in default model z=1 will be neglected
X, y = preprocess(patches,
n_classes=self.n_classes,
label_input=label_input,
X, y = preprocess(patches,
n_classes=self.n_classes,
label_input=label_input,
class_weights=class_weights)
X = X.reshape(self.batch_input_shape)
y = y.reshape(self.batch_label_shape)
assert X.shape[0] == y.shape[0]

validation_data = None
if valid_patches is not None:
valid_X, valid_y = preprocess(valid_patches,
Expand All @@ -145,7 +145,7 @@ def fit(self,
assert valid_X.shape[0] == valid_y.shape[0]
self.valid_score_callback.valid_data = (valid_X, valid_y)
validation_data = (valid_X, valid_y)

self.model.fit(x=X,
y=y,
batch_size=batch_size,
Expand Down Expand Up @@ -212,13 +212,13 @@ def __init__(self,
""" Define model

Args:
unet_feat (int, optional): output dimension of unet (used as
unet_feat (int, optional): output dimension of unet (used as
hidden units)
**kwargs: keyword arguments for `Segment`
note that `input_shape` should have 4 dimensions

"""

self.unet_feat = unet_feat
super(SegmentWithMultipleSlice, self).__init__(**kwargs)
self.n_slices = self.input_shape[1] # Input shape (c, z, x, y)
Expand All @@ -233,9 +233,9 @@ def build_model(self):
# Combine time slice dimension and batch dimension
inp = SplitSlice(self.n_channels, self.x_size, self.y_size)(self.input)
self.pre_conv = keras.layers.Conv2D(3, (1, 1), activation=None, name='pre_conv')(inp)

self.unet = segmentation_models.Unet(
backbone_name='resnet34',
backbone_name='resnet34',
input_shape=(3, self.x_size, self.y_size),
classes=self.unet_feat,
activation='linear',
Expand All @@ -251,7 +251,7 @@ def build_model(self):
output = MergeSlices(self.n_slices, self.unet_feat)(output)
output = keras.layers.Conv2D(self.unet_feat, (1, 1), activation='relu', name='post_conv')(output)
output = keras.layers.Conv2D(self.n_classes, (1, 1), activation=None, name='pred_head')(output)

self.model = keras.models.Model(self.input, output)
self.model.compile(optimizer='Adam',
loss=self.loss_func,
Expand Down
12 changes: 8 additions & 4 deletions SingleCellPatch/extract_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,20 @@ def process_site_extract_patches(site_path,
# Load data
image_stack = np.load(site_path)
if channels is None:
channels = list(range(len(image_stack)))
image_stack = image_stack[channels]
channels = list(range(len(image_stack[0])))
image_stack = image_stack[:, channels]
segmentation_stack = np.load(site_segmentation_path)
with open(os.path.join(site_supp_files_folder, 'cell_positions.pkl'), 'rb') as f:
cell_positions = pickle.load(f)
with open(os.path.join(site_supp_files_folder, 'cell_pixel_assignments.pkl'), 'rb') as f:
cell_pixel_assignments = pickle.load(f)

n_frames, n_channels, n_z, x_full_size, y_full_size = image_stack.shape
for t_point in range(n_frames):

# if the number of timepoints between images and predictions mismatch, choose the smaller one
endpt = image_stack.shape[0] if image_stack.shape[0] < segmentation_stack.shape[0] else segmentation_stack.shape[0]
print(f"total timepoints to process = {endpt}")
for t_point in range(endpt):
# for t_point in range(n_frames):
print(f"processing timepoint {t_point}")
stack_dat_path = os.path.join(site_supp_files_folder, 'stacks_%d.pkl' % t_point)
if reload and os.path.exists(stack_dat_path):
Expand Down
48 changes: 38 additions & 10 deletions SingleCellPatch/instance_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def check_segmentation_dim(segmentation):
segmentation: (np.array): segmentation mask for the frame

"""
# reshape if numpy file is too large, assume index=0 is redundant
if len(segmentation.shape) > 4:
shp = segmentation.shape
# skip index 1 which is blank
segmentation = segmentation.reshape((shp[1], shp[2], shp[3], shp[4]))

assert len(segmentation.shape) == 4, "Semantic segmentation should be formatted with dimension (c, z, x, y)"
n_channels, _, _, _ = segmentation.shape
Expand All @@ -60,7 +65,8 @@ def instance_clustering(cell_segmentation,
instance_map=True,
map_path=None,
fg_thr=0.3,
DBSCAN_thr=(10, 250)):
dbscan_thr=(10, 250),
channel=0):
""" Perform instance clustering on a static frame

Args:
Expand All @@ -75,7 +81,7 @@ def instance_clustering(cell_segmentation,
fg_thr (float, optional): threshold of foreground, any pixel with
predicted background prob less than this value would be regarded as
foreground (MG or Non-MG)
DBSCAN_thr (tuple, optional): parameters for DBSCAN, (eps, min_samples)
dbscan_thr (tuple, optional): parameters for DBSCAN, (eps, min_samples)

Returns:
(list * 3): 3 lists (MG, Non-MG, intermediate) of cell identifiers
Expand All @@ -85,14 +91,14 @@ def instance_clustering(cell_segmentation,

"""
cell_segmentation = check_segmentation_dim(cell_segmentation)
all_cells = np.mean(cell_segmentation[0], axis=0) < fg_thr
all_cells = np.mean(cell_segmentation[channel], axis=0) < fg_thr
positions = np.array(list(zip(*np.where(all_cells))))
if len(positions) < 1000:
# No cell detected
return [], np.zeros((0, 2), dtype=int), np.zeros((0,), dtype=int)

# DBSCAN clustering of cell pixels
clustering = DBSCAN(eps=DBSCAN_thr[0], min_samples=DBSCAN_thr[1]).fit(positions)
clustering = DBSCAN(eps=dbscan_thr[0], min_samples=dbscan_thr[1]).fit(positions)
positions_labels = clustering.labels_
cell_ids, point_cts = np.unique(positions_labels, return_counts=True)

Expand Down Expand Up @@ -140,7 +146,7 @@ def instance_clustering(cell_segmentation,
def process_site_instance_segmentation(raw_data,
raw_data_segmented,
site_supp_files_folder,
**kwargs):
config_):
"""
Wrapper method for instance segmentation

Expand All @@ -153,27 +159,49 @@ def process_site_instance_segmentation(raw_data,
:param raw_data: (str) path to image stack (.npy)
:param raw_data_segmented: (str) path to semantic segmentation stack (.npy)
:param site_supp_files_folder: (str) path to the folder where supplementary files will be saved
:param kwargs:
:param config_: config file parameters
:return:
"""

ct_thr = (config_.patch.count_threshold_low, config_.patch.count_threshold_high)
fg_thr = config_.patch.foreground_threshold
DBSCAN_thr = (config_.patch.dbscan_eps, config_.patch.dbscan_min_samples)
channel = config_.patch.channel

# TODO: Size is hardcoded here
# Should be of size (n_frame, n_channels, z(1), x(2048), y(2048)), uint16
print(f"\tLoading {raw_data}")
image_stack = np.load(raw_data)
print(f"\traw_data has shape {image_stack.shape}")
# Should be of size (n_frame, n_classes, z(1), x(2048), y(2048)), float
print(f"\tLoading {raw_data_segmented}")
segmentation_stack = np.load(raw_data_segmented)
print(f"\tsegmentation stack has shape {segmentation_stack.shape}")

# reshape if numpy file is too large, assume index=1 is redundant
if len(segmentation_stack.shape) > 4:
shp = segmentation_stack.shape
# skip index 1 which is blank
segmentation_stack = segmentation_stack.reshape((shp[0], shp[2], shp[3], shp[4], shp[5]))

cell_positions = {}
cell_pixel_assignments = {}
for t_point in range(image_stack.shape[0]):

# if the number of timepoints between images and predictions mismatch, choose the smaller one
endpt = image_stack.shape[0] if image_stack.shape[0] < segmentation_stack.shape[0] else segmentation_stack.shape[0]
for t_point in range(endpt):
print("\tClustering time %d" % t_point)
cell_segmentation = segmentation_stack[t_point]
instance_map_path = os.path.join(site_supp_files_folder, 'segmentation_%d.png' % t_point)
#TODO: expose instance clustering parameters in config
res = instance_clustering(cell_segmentation, instance_map=True, map_path=instance_map_path)
cell_positions[t_point] = res[0] # List of cell: (cell_id, mean_pos)
res = instance_clustering(cell_segmentation,
instance_map=True,
map_path=instance_map_path,
ct_thr=ct_thr,
fg_thr=fg_thr,
dbscan_thr=DBSCAN_thr,
channel=channel)
print(f"\tfound {len(res[0])} cells for timepoint {t_point}")
cell_positions[t_point] = res[0] # List of cell: (cell_id, mean_pos)
cell_pixel_assignments[t_point] = res[1:]
with open(os.path.join(site_supp_files_folder, 'cell_positions.pkl'), 'wb') as f:
pickle.dump(cell_positions, f)
Expand Down
12 changes: 12 additions & 0 deletions configs/config_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ patch:
# True to skip patches whose edges exceed the image boundaries
# False to pad patches with mean background values

# instance segmentation parameters
channel: 0
# segmentation output can produce multiple channels. Select which one to cluster on
foreground_threshold: 0.6
# channel values under foreground_threshold will be considered
dbscan_eps: 10
dbscan_min_samples: 250
count_threshold_low: 100
count_threshold_high: 12000




latent_encoding:
raw_dirs: [
Expand Down
10 changes: 9 additions & 1 deletion configs/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@
'window_size',
'save_fig',
'reload',
'skip_boundary'
'skip_boundary',

'overwrite',
'channel',
'count_threshold_low',
'count_threshold_high',
'foreground_threshold',
'dbscan_eps',
'dbscan_min_samples'
}

# change this to "latent encoding" or similar
Expand Down
Loading