diff --git a/tf_unet/unet.py b/tf_unet/unet.py index 0182597..f8ef61a 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -2,12 +2,12 @@ # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # tf_unet is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with tf_unet. If not, see . @@ -204,7 +204,8 @@ def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={}, def _get_cost(self, logits, cost_name, cost_kwargs): """ - Constructs the cost function, either cross_entropy, weighted cross_entropy or dice_coefficient. + Constructs the cost function, either cross_entropy, weighted cross_entropy, + dice_coefficient, or iou (intersection over union). Optional arguments are: class_weights: weights for the different classes in case of multi-class imbalance regularizer: power of the L2 regularizers added to the loss function @@ -230,12 +231,17 @@ def _get_cost(self, logits, cost_name, cost_kwargs): else: loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits, labels=flat_labels)) - elif cost_name == "dice_coefficient": + elif cost_name == "dice_coefficient" or cost_name == "iou": eps = 1e-5 prediction = pixel_wise_softmax_2(logits) - intersection = tf.reduce_sum(prediction * self.y) - union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y) - loss = -(2 * intersection / (union)) + A_intersect_B = tf.reduce_sum(prediction * self.y, axis=[0, 1, 2]) + A_plus_B = tf.reduce_sum(prediction, axis=[0, 1, 2]) + tf.reduce_sum(self.y, axis=[0, 1, 2]) + if cost_name == "dice_coefficient": + denominator = A_plus_B + else: # intersection over union + A_union_B = A_plus_B - A_intersect_B + denominator = A_union_B + loss = tf.reduce_sum(-(2 * A_intersect_B / (eps + denominator))) else: raise ValueError("Unknown cost function: " % cost_name)