diff --git a/keras_contrib/losses/jaccard.py b/keras_contrib/losses/jaccard.py index 671887b71..1fcf990d2 100644 --- a/keras_contrib/losses/jaccard.py +++ b/keras_contrib/losses/jaccard.py @@ -37,7 +37,8 @@ def jaccard_distance(y_true, y_pred, smooth=100): http://www.bmva.org/bmvc/2013/Papers/paper0032/paper0032.pdf) """ - intersection = K.sum(K.abs(y_true * y_pred), axis=-1) - sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) + shape_dims = list(range(len(y_true.get_shape().as_list())-1)) + intersection = K.sum(K.abs(y_true * y_pred), axis=shape_dims) + sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=shape_dims) jac = (intersection + smooth) / (sum_ - intersection + smooth) return (1 - jac) * smooth