From ff4cb6917ac641584840b07648308ff4aed1a883 Mon Sep 17 00:00:00 2001 From: Seongmin Choi Date: Wed, 13 Nov 2019 12:31:07 +0900 Subject: [PATCH 1/6] added dice coefficient loss the VNet study. --- keras_contrib/losses/__init__.py | 1 + keras_contrib/losses/dice.py | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 keras_contrib/losses/dice.py diff --git a/keras_contrib/losses/__init__.py b/keras_contrib/losses/__init__.py index 37a47a804..d90790c8a 100644 --- a/keras_contrib/losses/__init__.py +++ b/keras_contrib/losses/__init__.py @@ -1,3 +1,4 @@ from .dssim import DSSIMObjective from .jaccard import jaccard_distance from .crf_losses import crf_loss, crf_nll +from .dice import dice_loss diff --git a/keras_contrib/losses/dice.py b/keras_contrib/losses/dice.py new file mode 100644 index 000000000..22e024442 --- /dev/null +++ b/keras_contrib/losses/dice.py @@ -0,0 +1,36 @@ +from keras import backend as K + + +def dice_loss(y_true, y_pred, smooth=1): + """Dice similarity coefficient (DSC) loss. + + Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice + similarity coefficient is used as a metric to evaluate the performance of + image segmentation by comparing spatial overlap between the true and predicted + spaces. + + A smoothing factor, which is by default 1, is applied to avoid dividing by + zeros. + + Dice loss = 1 - (2 * |X & Y|)/ (X^2 + Y^2) + = 1 - 2 * sum(A*B) / sum(A^2 + B^2) + + # Arguments + y_true: The ground truth tensor. + y_pred: The predicted tensor + smooth: Smoothing factor. Default is 1. + + # Returns + The Dice coefficiet loss between the two tensors. + + # References + - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image + Segmentation](https://arxiv.org/pdf/1606.04797.pdf) + """ + y_true_flat, y_pred_flat = K.flatten(y_true), K.flatten(y_pred) + dice_nom = 2 * K.sum(y_true_flat * y_pred_flat) + dice_denom = K.sum(K.square(y_true_flat) + K.square(y_pred_flat)) + dice_coef = (dice_nom + smooth) / (dice_denom + smooth) + + return 1 - dice_coef + From 17dfe93eb092a53af07ee51abdd2c3cb8185db52 Mon Sep 17 00:00:00 2001 From: Seongmin Choi Date: Wed, 13 Nov 2019 14:05:39 +0900 Subject: [PATCH 2/6] added pytest functionality for dice_loss --- contrib_docs/_build/pydocmd/losses.md | 154 ++++++++++++++++++++++++ contrib_docs/pydocmd.yml | 1 + tests/keras_contrib/losses/dice_test.py | 39 ++++++ 3 files changed, 194 insertions(+) create mode 100644 contrib_docs/_build/pydocmd/losses.md create mode 100644 tests/keras_contrib/losses/dice_test.py diff --git a/contrib_docs/_build/pydocmd/losses.md b/contrib_docs/_build/pydocmd/losses.md new file mode 100644 index 000000000..65f58d442 --- /dev/null +++ b/contrib_docs/_build/pydocmd/losses.md @@ -0,0 +1,154 @@ +

keras_contrib.losses

+ + +

DSSIMObjective

+ +```python +DSSIMObjective(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0) +``` +Difference of Structural Similarity (DSSIM loss function). +Clipped between 0 and 0.5 + +Note : You should add a regularization term like a l2 loss in addition to this one. +Note : In theano, the `kernel_size` must be a factor of the output size. So 3 could + not be the `kernel_size` for an output of 32. + +__Arguments__ + +- __k1__: Parameter of the SSIM (default 0.01) +- __k2__: Parameter of the SSIM (default 0.03) +- __kernel_size__: Size of the sliding window (default 3) +- __max_value__: Max value of the output (default 1.0) + +

jaccard_distance

+ +```python +jaccard_distance(y_true, y_pred, smooth=100) +``` +Jaccard distance for semantic segmentation. + +Also known as the intersection-over-union loss. + +This loss is useful when you have unbalanced numbers of pixels within an image +because it gives all classes equal weight. However, it is not the defacto +standard for image segmentation. + +For example, assume you are trying to predict if +each pixel is cat, dog, or background. +You have 80% background pixels, 10% dog, and 10% cat. +If the model predicts 100% background +should it be be 80% right (as with categorical cross entropy) +or 30% (with this loss)? + +The loss has been modified to have a smooth gradient as it converges on zero. +This has been shifted so it converges on 0 and is smoothed to avoid exploding +or disappearing gradient. + +Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) + = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) + +__Arguments__ + +- __y_true__: The ground truth tensor. +- __y_pred__: The predicted tensor +- __smooth__: Smoothing factor. Default is 100. + +__Returns__ + + The Jaccard distance between the two tensors. + +__References__ + + - [What is a good evaluation measure for semantic segmentation?]( + http://www.bmva.org/bmvc/2013/Papers/paper0032/paper0032.pdf) + + +

crf_loss

+ +```python +crf_loss(y_true, y_pred) +``` +General CRF loss function depending on the learning mode. + +__Arguments__ + +- __y_true__: tensor with true targets. +- __y_pred__: tensor with predicted targets. + +__Returns__ + + If the CRF layer is being trained in the join mode, returns the negative + log-likelihood. Otherwise returns the categorical crossentropy implemented + by the underlying Keras backend. + +__About GitHub__ + + If you open an issue or a pull request about CRF, please + add `cc @lzfelix` to notify Luiz Felix. + +

crf_nll

+ +```python +crf_nll(y_true, y_pred) +``` +The negative log-likelihood for linear chain Conditional Random Field (CRF). + +This loss function is only used when the `layers.CRF` layer +is trained in the "join" mode. + +__Arguments__ + +- __y_true__: tensor with true targets. +- __y_pred__: tensor with predicted targets. + +__Returns__ + + A scalar representing corresponding to the negative log-likelihood. + +__Raises__ + +- `TypeError`: If CRF is not the last layer. + +__About GitHub__ + + If you open an issue or a pull request about CRF, please + add `cc @lzfelix` to notify Luiz Felix. + +

dice_loss

+ +```python +dice_loss(y_true, y_pred, smooth=1) +``` +Dice similarity coefficient (DSC) loss. + +Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice +similarity coefficient is used as a metric to evaluate the performance of +image segmentation by comparing spatial overlap between the true and predicted +spaces. + +A smoothing factor, which is by default 1, is applied to avoid dividing by +zeros. + +Dice loss = 1 - (2 * |X & Y|)/ (X^2 + Y^2) + = 1 - 2 * sum(A*B) / sum(A^2 + B^2) + +__Arguments__ + +- __y_true__: The ground truth tensor. +- __y_pred__: The predicted tensor +- __smooth__: Smoothing factor. Default is 1. + +__Returns__ + + The Dice coefficiet loss between the two tensors. + +__References__ + + - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation]( + https://arxiv.org/pdf/1606.04797.pdf) + +__About GitHub__ + + If you open an issue or a pull request about CRF, please + add `cc @alexbmp` to notify Seongmin Choi. + diff --git a/contrib_docs/pydocmd.yml b/contrib_docs/pydocmd.yml index eca7ba534..67b7524fd 100644 --- a/contrib_docs/pydocmd.yml +++ b/contrib_docs/pydocmd.yml @@ -23,6 +23,7 @@ generate: - keras_contrib.losses.jaccard_distance - keras_contrib.losses.crf_loss - keras_contrib.losses.crf_nll + - keras_contrib.losses.dice_loss - optimizers.md: - keras_contrib.optimizers: - keras_contrib.optimizers.FTML diff --git a/tests/keras_contrib/losses/dice_test.py b/tests/keras_contrib/losses/dice_test.py new file mode 100644 index 000000000..8c3fc3e5a --- /dev/null +++ b/tests/keras_contrib/losses/dice_test.py @@ -0,0 +1,39 @@ +import pytest + +from keras_contrib.losses import dice_loss +from keras_contrib.utils.test_utils import is_tf_keras +from keras import backend as K +import numpy as np + + +def test_dice_loss_shapes_scalar(): + y_true = np.random.randn(3, 4) + y_pred = np.random.randn(3, 4) + + L = dice_loss( + K.variable(y_true), + K.variable(y_pred), ) + assert K.is_tensor(L), 'should be a Tensor' + assert L.shape == () + assert K.eval(L).shape == () + + +def test_dice_loss_for_same_array(): + y_true = np.random.randn(3, 4) + y_pred = y_true.copy() + + L = dice_loss( + K.variable(y_true), + K.variable(y_pred), ) + assert K.eval(L) == 0, 'loss should be zero' + + +def test_dice_loss_for_zero_array(): + y_true = np.array([1]) + y_pred = np.array([0]) + + L = dice_loss( + K.variable(y_true), + K.variable(y_pred), ) + assert K.eval(L) == 0.5, 'loss should equal 0.5' + From dc39fbe175279fa0fedaff74ca5fb3cc1c2968b7 Mon Sep 17 00:00:00 2001 From: Seongmin Choi Date: Wed, 13 Nov 2019 14:12:37 +0900 Subject: [PATCH 3/6] typo fix --- contrib_docs/_build/pydocmd/losses.md | 7 +++---- keras_contrib/losses/dice.py | 5 +++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/contrib_docs/_build/pydocmd/losses.md b/contrib_docs/_build/pydocmd/losses.md index 65f58d442..027d013de 100644 --- a/contrib_docs/_build/pydocmd/losses.md +++ b/contrib_docs/_build/pydocmd/losses.md @@ -144,11 +144,10 @@ __Returns__ __References__ - - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation]( - https://arxiv.org/pdf/1606.04797.pdf) + - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image + Segmentation](https://arxiv.org/pdf/1606.04797.pdf) __About GitHub__ - If you open an issue or a pull request about CRF, please + If you open an issue or a pull request about Dice loss, please add `cc @alexbmp` to notify Seongmin Choi. - diff --git a/keras_contrib/losses/dice.py b/keras_contrib/losses/dice.py index 22e024442..db467aebd 100644 --- a/keras_contrib/losses/dice.py +++ b/keras_contrib/losses/dice.py @@ -26,6 +26,11 @@ def dice_loss(y_true, y_pred, smooth=1): # References - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/pdf/1606.04797.pdf) + + # About GitHub + If you open an issue or a pull request about Dice loss, please + add `cc @alexbmp` to notify Seongmin Choi. + """ y_true_flat, y_pred_flat = K.flatten(y_true), K.flatten(y_pred) dice_nom = 2 * K.sum(y_true_flat * y_pred_flat) From 76ed324e49252b980554dd28315cf8151c3c7c0c Mon Sep 17 00:00:00 2001 From: Seongmin Choi Date: Wed, 13 Nov 2019 14:18:11 +0900 Subject: [PATCH 4/6] self install test done --- contrib_docs/_build/pydocmd/losses.md | 2 ++ contrib_docs/mkdocs.yml | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 contrib_docs/mkdocs.yml diff --git a/contrib_docs/_build/pydocmd/losses.md b/contrib_docs/_build/pydocmd/losses.md index 027d013de..942beda20 100644 --- a/contrib_docs/_build/pydocmd/losses.md +++ b/contrib_docs/_build/pydocmd/losses.md @@ -151,3 +151,5 @@ __About GitHub__ If you open an issue or a pull request about Dice loss, please add `cc @alexbmp` to notify Seongmin Choi. + + diff --git a/contrib_docs/mkdocs.yml b/contrib_docs/mkdocs.yml new file mode 100644 index 000000000..8d8670f93 --- /dev/null +++ b/contrib_docs/mkdocs.yml @@ -0,0 +1,15 @@ +docs_dir: _build/pydocmd +pages: +- Home: index.md +- layers: + - Core layers: layers/core.md + - Convolutional layers: layers/convolutional.md + - normalization layers: layers/normalization.md + - Advanced activations layers: layers/advanced-activations.md + - CRF layers: layers/crf.md +- Losses: losses.md +- Optimizers: optimizers.md +- Callbacks: callbacks.md +site_dir: _build/site +site_name: Keras-contrib Documentation +theme: readthedocs From 3b1b02e3fa18651a23b6c1cd2e98209347304bd0 Mon Sep 17 00:00:00 2001 From: Seongmin Choi Date: Wed, 13 Nov 2019 14:22:26 +0900 Subject: [PATCH 5/6] commit after self install-test and pytest --- contrib_docs/mkdocs.yml | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 contrib_docs/mkdocs.yml diff --git a/contrib_docs/mkdocs.yml b/contrib_docs/mkdocs.yml deleted file mode 100644 index 8d8670f93..000000000 --- a/contrib_docs/mkdocs.yml +++ /dev/null @@ -1,15 +0,0 @@ -docs_dir: _build/pydocmd -pages: -- Home: index.md -- layers: - - Core layers: layers/core.md - - Convolutional layers: layers/convolutional.md - - normalization layers: layers/normalization.md - - Advanced activations layers: layers/advanced-activations.md - - CRF layers: layers/crf.md -- Losses: losses.md -- Optimizers: optimizers.md -- Callbacks: callbacks.md -site_dir: _build/site -site_name: Keras-contrib Documentation -theme: readthedocs From 6c6c991956420ba30faeb6170133fca9369b5084 Mon Sep 17 00:00:00 2001 From: Seongmin Choi Date: Wed, 13 Nov 2019 14:42:07 +0900 Subject: [PATCH 6/6] fixed scripts to match PEP8 --- keras_contrib/losses/dice.py | 21 ++++++++++----------- tests/keras_contrib/losses/dice_test.py | 1 - 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/keras_contrib/losses/dice.py b/keras_contrib/losses/dice.py index db467aebd..514c8ee39 100644 --- a/keras_contrib/losses/dice.py +++ b/keras_contrib/losses/dice.py @@ -4,12 +4,12 @@ def dice_loss(y_true, y_pred, smooth=1): """Dice similarity coefficient (DSC) loss. - Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice - similarity coefficient is used as a metric to evaluate the performance of - image segmentation by comparing spatial overlap between the true and predicted - spaces. + Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice + similarity coefficient is used as a metric to evaluate the performance of + image segmentation by comparing spatial overlap between the true and + predicted spaces. - A smoothing factor, which is by default 1, is applied to avoid dividing by + A smoothing factor, which is by default 1, is applied to avoid dividing by zeros. Dice loss = 1 - (2 * |X & Y|)/ (X^2 + Y^2) @@ -24,18 +24,17 @@ def dice_loss(y_true, y_pred, smooth=1): The Dice coefficiet loss between the two tensors. # References - - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image - Segmentation](https://arxiv.org/pdf/1606.04797.pdf) + - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical + Image Segmentation](https://arxiv.org/pdf/1606.04797.pdf) # About GitHub If you open an issue or a pull request about Dice loss, please - add `cc @alexbmp` to notify Seongmin Choi. + add `cc @alexbmp` to notify Seongmin Choi """ y_true_flat, y_pred_flat = K.flatten(y_true), K.flatten(y_pred) dice_nom = 2 * K.sum(y_true_flat * y_pred_flat) - dice_denom = K.sum(K.square(y_true_flat) + K.square(y_pred_flat)) + dice_denom = K.sum(K.square(y_true_flat) + K.square(y_pred_flat)) dice_coef = (dice_nom + smooth) / (dice_denom + smooth) - - return 1 - dice_coef + return 1 - dice_coef diff --git a/tests/keras_contrib/losses/dice_test.py b/tests/keras_contrib/losses/dice_test.py index 8c3fc3e5a..8b82b5444 100644 --- a/tests/keras_contrib/losses/dice_test.py +++ b/tests/keras_contrib/losses/dice_test.py @@ -36,4 +36,3 @@ def test_dice_loss_for_zero_array(): K.variable(y_true), K.variable(y_pred), ) assert K.eval(L) == 0.5, 'loss should equal 0.5' -