diff --git a/lungmask/mask.py b/lungmask/mask.py index 90df7c2..73ea08e 100644 --- a/lungmask/mask.py +++ b/lungmask/mask.py @@ -50,9 +50,6 @@ def apply( else: directions = np.asarray([1, 0, 0, 0, 1, 0, 0, 0, 1]) - if model is None: - model = get_model("unet", "R231") - if force_cpu: device = torch.device("cpu") else: @@ -62,7 +59,10 @@ def apply( logging.info("No GPU support available, will use CPU. Note, that this is significantly slower!") batch_size = 1 device = torch.device("cpu") - model.to(device) + + if model is None: + model = get_model("unet", "R231") + model.to(device) if not noHU: tvolslices, xnew_box = utils.preprocess(inimg_raw, resolution=[256, 256]) @@ -117,7 +117,7 @@ def apply( if len(directions) == 9: outmask = np.flip(outmask, np.where(directions[[0, 4, 8]][::-1] < 0)[0]) - return outmask.astype(np.uint8) + return outmask def get_model(modeltype, modelname): diff --git a/lungmask/utils.py b/lungmask/utils.py index 2d36506..a319fb3 100644 --- a/lungmask/utils.py +++ b/lungmask/utils.py @@ -13,33 +13,32 @@ import skimage.morphology from torch.utils.data import Dataset from tqdm import tqdm +import pandas as pd + ORDER2OCVINTER = {0: cv2.INTER_NEAREST, 1: cv2.INTER_LINEAR, 2: cv2.INTER_AREA, 3: cv2.INTER_CUBIC} -def preprocess(img, label=None, resolution=[192, 192]): - imgmtx = np.copy(img) - lblsmtx = np.copy(label) +def preprocess( + imgmtx: np.ndarray, resolution: list = [192, 192] +) -> Tuple[np.ndarray, np.ndarray]: + """Preprocesses the image by clipping, cropping and resizing. Clipping at -1024 and 600 HU, cropping to the body + + Args: + imgmtx (np.ndarray): Image to be preprocessed + resolution (list, optional): Target size after preprocessing. Defaults to [192, 192]. - imgmtx[imgmtx < -1024] = -1024 - imgmtx[imgmtx > 600] = 600 + Returns: + Tuple[np.ndarray, np.ndarray]: Preprocessed image and the cropping bounding box + """ + imgmtx = np.clip(imgmtx, -1024, 600) cip_xnew = [] cip_box = [] - cip_mask = [] - for i in range(imgmtx.shape[0]): - if label is None: - (im, m, box) = crop_and_resize(imgmtx[i, :, :], width=resolution[0], height=resolution[1]) - else: - (im, m, box) = crop_and_resize( - imgmtx[i, :, :], mask=lblsmtx[i, :, :], width=resolution[0], height=resolution[1] - ) - cip_mask.append(m) + for imslice in imgmtx: + im, _, box = crop_and_resize(imslice, width=resolution[0], height=resolution[1]) cip_xnew.append(im) cip_box.append(box) - if label is None: - return np.asarray(cip_xnew), cip_box - else: - return np.asarray(cip_xnew), cip_box, np.asarray(cip_mask) + return np.asarray(cip_xnew), cip_box def simple_bodymask(img): @@ -210,6 +209,21 @@ def get_input_image(path): return input_image +def speedup_numpy_unique(array: np.ndarray, return_counts: bool = False): + """ + Alternative accelerated version of numpy.unique for integer-based arrays + """ + + if return_counts is True: + counts = np.bincount(array.ravel()) + unique = np.where(counts != 0)[0] + counts = counts[counts != 0] + return unique, counts + else: + unique = np.sort(pd.unique(array.ravel())) + return unique + + def postrocessing(label_image, spare=[], verbose=True): """some post-processing mapping small label patches to the neighbout whith which they share the largest border. All connected components smaller than min_area will be removed @@ -217,7 +231,7 @@ def postrocessing(label_image, spare=[], verbose=True): # merge small components to neighbours regionmask = skimage.measure.label(label_image) - origlabels = np.unique(label_image) + origlabels = speedup_numpy_unique(label_image) origlabels_maxsub = np.zeros((max(origlabels) + 1,), dtype=np.uint32) # will hold the largest component for a label regions = skimage.measure.regionprops(regionmask, label_image) regions.sort(key=lambda x: x.area) @@ -237,7 +251,7 @@ def postrocessing(label_image, spare=[], verbose=True): bb = bbox_3D(regionmask == r.label) sub = regionmask[bb[0] : bb[1], bb[2] : bb[3], bb[4] : bb[5]] dil = ndimage.binary_dilation(sub == r.label) - neighbours, counts = np.unique(sub[dil], return_counts=True) + neighbours, counts = speedup_numpy_unique(sub[dil], return_counts=True) mapto = r.label maxmap = 0 myarea = 0 @@ -265,7 +279,7 @@ def postrocessing(label_image, spare=[], verbose=True): holefiller = fill_voids.fill outmask = np.zeros(outmask_mapped.shape, dtype=np.uint8) - for i in np.unique(outmask_mapped)[1:]: + for i in speedup_numpy_unique(outmask_mapped)[1:]: outmask[holefiller(keep_largest_connected_component(outmask_mapped == i))] = i return outmask @@ -335,7 +349,7 @@ def cv2_zoom( out_shape = tuple((np.asarray(img.shape[:2]) * np.asarray(scale)).round().astype(int)[::-1]) if pseudo_linear: - uniques = np.unique(img) + uniques = speedup_numpy_unique(img) out_shape_with_channels = list(out_shape[::-1]) + list(img.shape[2:]) out_img = np.zeros(out_shape_with_channels, dtype=img.dtype) for value in uniques[uniques != 0]: diff --git a/requirements.txt b/requirements.txt index e2679cc..ff748b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ SimpleITK==1.2.4 skimage==0.0 fill_voids opencv-python-headless==4.5.3.56 +pandas \ No newline at end of file