
Commit 977d96e

AakashKumarNain authored and Squadrick committed
Add kappa (#267)
* Add Cohens Kappa Metric
1 parent e605b52 commit 977d96e

File tree

5 files changed: +337 −3 lines changed

- tensorflow_addons/metrics/BUILD
- tensorflow_addons/metrics/README.md
- tensorflow_addons/metrics/__init__.py
- tensorflow_addons/metrics/cohens_kappa.py
- tensorflow_addons/metrics/cohens_kappa_test.py

tensorflow_addons/metrics/BUILD

Lines changed: 14 additions & 0 deletions
@@ -6,9 +6,23 @@ py_library(
     name = "metrics",
     srcs = [
         "__init__.py",
+        "cohens_kappa.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow_addons/utils",
     ],
 )
+
+py_test(
+    name = "cohens_kappa_test",
+    size = "small",
+    srcs = [
+        "cohens_kappa_test.py",
+    ],
+    main = "cohens_kappa_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":metrics",
+    ],
+)

tensorflow_addons/metrics/README.md

Lines changed: 3 additions & 3 deletions
@@ -3,12 +3,12 @@
 ## Maintainers
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
-| | | |
+| cohens_kappa| Aakash Nain | [email protected]|
 
 ## Contents
-| Submodule | Activation | Reference |
+| Submodule | Metric | Reference |
 |:----------------------- |:-------------------|:---------------|
-| | | |
+| cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)|
 
 
 ## Contribution Guidelines
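For context, the metric added by this commit can be used on its own once the package is installed. The following is a minimal sketch assuming `tensorflow_addons` is importable as `tfa`; the values reuse the example from the docstring further below.

```python
import numpy as np
import tensorflow_addons as tfa

# Labels from two raters over 5 classes (0..4).
actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)

# Quadratically weighted Cohen's Kappa.
metric = tfa.metrics.CohenKappa(num_classes=5, weightage='quadratic')
metric.update_state(actuals, preds)
print(metric.result().numpy())  # ~0.68932
```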

tensorflow_addons/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,3 +17,5 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
+from tensorflow_addons.metrics.cohens_kappa import CohenKappa
tensorflow_addons/metrics/cohens_kappa.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Cohen's Kappa."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class CohenKappa(Metric):
    """Computes Cohen's Kappa score between two raters.

    The score lies in the range [-1, 1]. A score of -1 represents
    complete disagreement between the two raters, whereas a score of 1
    represents complete agreement. A score of 0 means agreement by chance.

    Note: As of now, this implementation considers all labels
    while calculating the Cohen's Kappa score.

    Usage:

    ```python
    actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
    preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)

    m = tfa.metrics.CohenKappa(num_classes=5, weightage='quadratic')
    m.update_state(actuals, preds)
    print('Final result: ', m.result().numpy())  # Result: 0.68932
    ```

    Usage with the tf.keras API:

    ```python
    model = tf.keras.models.Model(inputs, outputs)
    model.add_metric(tfa.metrics.CohenKappa(num_classes=5)(outputs))
    model.compile('sgd', loss='mse')
    ```

    Args:
      num_classes: Number of unique classes in your dataset.
      weightage: Weighting to be used for calculating the kappa
        statistic. A valid value is one of
        [None, 'linear', 'quadratic']. Defaults to None.

    Returns:
      kappa_score: float
        The kappa statistic, which is a number between -1 and 1. The maximum
        value means complete agreement; zero or lower means chance agreement.

    Raises:
      ValueError: If the value passed for `weightage` is invalid,
        i.e. not one of [None, 'linear', 'quadratic'].
    """

    def __init__(self,
                 num_classes,
                 name='cohen_kappa',
                 weightage=None,
                 dtype=tf.float32):
        super(CohenKappa, self).__init__(name=name, dtype=dtype)

        if weightage not in (None, 'linear', 'quadratic'):
            raise ValueError("Unknown kappa weighting type.")
        else:
            self.weightage = weightage

        self.num_classes = num_classes
        self.conf_mtx = self.add_weight(
            'conf_mtx',
            shape=(self.num_classes, self.num_classes),
            initializer=tf.keras.initializers.zeros,
            dtype=tf.int32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Args:
          y_true: array, shape = [n_samples]
            Labels assigned by the first annotator.
          y_pred: array, shape = [n_samples]
            Labels assigned by the second annotator. The kappa statistic
            is symmetric, so swapping ``y_true`` and ``y_pred`` doesn't
            change the value.
          sample_weight (optional): weights for the labels in the confusion
            matrix. Defaults to None. The dtype of the weights should match
            the dtype of the confusion matrix. For more details,
            please check tf.math.confusion_matrix.

        Returns:
          Update op.
        """
        y_true = tf.cast(y_true, dtype=tf.int32)
        y_pred = tf.cast(y_pred, dtype=tf.int32)

        if y_true.shape != y_pred.shape:
            raise ValueError(
                "Number of samples in y_true and y_pred are different")

        # compute the new values of the confusion matrix
        new_conf_mtx = tf.math.confusion_matrix(
            labels=y_true,
            predictions=y_pred,
            num_classes=self.num_classes,
            weights=sample_weight)

        # update the values in the original confusion matrix
        return self.conf_mtx.assign_add(new_conf_mtx)

    def result(self):
        # 1. Get the number of ratings and start from an all-ones matrix
        nb_ratings = tf.shape(self.conf_mtx)[0]
        weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32)

        # 2. Create the weight matrix
        if self.weightage is None:
            # unweighted kappa: zero weight on the diagonal, one elsewhere
            diagonal = tf.zeros([nb_ratings], dtype=tf.int32)
            weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)

        else:
            weight_mtx += tf.range(nb_ratings, dtype=tf.int32)
            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)

            if self.weightage == 'linear':
                # |i - j| penalty
                weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx))
            else:
                # (i - j)^2 penalty
                weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2)
            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)

        # 3. Get the marginal counts for the two raters
        actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1)
        pred_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=0)

        # 4. Get the outer product of the marginals (expected counts)
        out_prod = pred_ratings_hist[..., None] * \
                   actual_ratings_hist[None, ...]

        # 5. Normalize the confusion matrix and the outer product
        conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx)
        out_prod = out_prod / tf.reduce_sum(out_prod)

        conf_mtx = tf.cast(conf_mtx, dtype=tf.float32)
        out_prod = tf.cast(out_prod, dtype=tf.float32)

        # 6. Calculate the kappa score
        numerator = tf.reduce_sum(conf_mtx * weight_mtx)
        denominator = tf.reduce_sum(out_prod * weight_mtx)
        kp = 1 - (numerator / denominator)
        return kp

    def get_config(self):
        """Returns the serializable config of the metric."""
        config = {
            "num_classes": self.num_classes,
            "weightage": self.weightage,
        }
        base_config = super(CohenKappa, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def reset_states(self):
        """Resets all of the metric state variables."""
        for v in self.variables:
            K.set_value(
                v, np.zeros((self.num_classes, self.num_classes), np.int32))
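To make the arithmetic in `result()` concrete, here is a small NumPy sketch (not part of the commit) that mirrors steps 1-6 above for the quadratic case, using the example data from the docstring:

```python
import numpy as np

# Labels from the docstring example; 5 classes (0..4).
actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1])
preds = np.array([4, 4, 3, 4, 4, 2, 1, 1])
num_classes = 5

# Confusion matrix O, rows = actuals, cols = preds
# (matches tf.math.confusion_matrix(labels, predictions)).
conf = np.zeros((num_classes, num_classes))
for a, p in zip(actuals, preds):
    conf[a, p] += 1

# Quadratic weight matrix: w[i, j] = (i - j) ** 2.
idx = np.arange(num_classes)
weights = (idx[:, None] - idx[None, :]) ** 2

# Expected matrix E from the outer product of the marginals.
expected = conf.sum(axis=1)[:, None] * conf.sum(axis=0)[None, :]

# Normalize both, then kappa = 1 - sum(w * O) / sum(w * E).
conf = conf / conf.sum()
expected = expected / expected.sum()
kappa = 1 - (weights * conf).sum() / (weights * expected).sum()
print(round(kappa, 5))  # ~0.68932, matching the docstring and the tests
```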
tensorflow_addons/metrics/cohens_kappa_test.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Cohen's Kappa Metric."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.metrics import CohenKappa
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class CohenKappaTest(tf.test.TestCase):
    def test_config(self):
        kp_obj = CohenKappa(name='cohen_kappa', num_classes=5)
        self.assertEqual(kp_obj.name, 'cohen_kappa')
        self.assertEqual(kp_obj.dtype, tf.float32)
        self.assertEqual(kp_obj.num_classes, 5)

        # Check save and restore config
        kp_obj2 = CohenKappa.from_config(kp_obj.get_config())
        self.assertEqual(kp_obj2.name, 'cohen_kappa')
        self.assertEqual(kp_obj2.dtype, tf.float32)
        self.assertEqual(kp_obj2.num_classes, 5)

    def initialize_vars(self):
        kp_obj1 = CohenKappa(num_classes=5)
        kp_obj2 = CohenKappa(num_classes=5, weightage='linear')
        kp_obj3 = CohenKappa(num_classes=5, weightage='quadratic')

        self.evaluate(tf.compat.v1.variables_initializer(kp_obj1.variables))
        self.evaluate(tf.compat.v1.variables_initializer(kp_obj2.variables))
        self.evaluate(tf.compat.v1.variables_initializer(kp_obj3.variables))
        return kp_obj1, kp_obj2, kp_obj3

    def update_obj_states(self, obj1, obj2, obj3, actuals, preds, weights):
        update_op1 = obj1.update_state(actuals, preds, sample_weight=weights)
        update_op2 = obj2.update_state(actuals, preds, sample_weight=weights)
        update_op3 = obj3.update_state(actuals, preds, sample_weight=weights)

        self.evaluate(update_op1)
        self.evaluate(update_op2)
        self.evaluate(update_op3)

    def check_results(self, objs, values):
        obj1, obj2, obj3 = objs
        val1, val2, val3 = values

        self.assertAllClose(val1, self.evaluate(obj1.result()), atol=1e-5)
        self.assertAllClose(val2, self.evaluate(obj2.result()), atol=1e-5)
        self.assertAllClose(val3, self.evaluate(obj3.result()), atol=1e-5)

    def test_kappa_random_score(self):
        actuals = [4, 4, 3, 4, 2, 4, 1, 1]
        preds = [4, 4, 3, 4, 4, 2, 1, 1]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, None)

        # Check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3],
                           [0.61904761, 0.62790697, 0.68932038])

    def test_kappa_perfect_score(self):
        actuals = [4, 4, 3, 3, 2, 2, 1, 1]
        preds = [4, 4, 3, 3, 2, 2, 1, 1]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, None)

        # Check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3], [1.0, 1.0, 1.0])

    def test_kappa_worse_than_random(self):
        actuals = [4, 4, 3, 3, 2, 2, 1, 1]
        preds = [1, 2, 4, 1, 3, 3, 4, 4]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, None)

        # Check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3],
                           [-0.3333333, -0.52380952, -0.72727272])

    def test_kappa_with_sample_weights(self):
        actuals = [4, 4, 3, 3, 2, 2, 1, 1]
        preds = [1, 2, 4, 1, 3, 3, 4, 4]
        weights = [1, 1, 2, 5, 10, 2, 3, 3]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)
        weights = tf.constant(weights, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds,
                               weights)

        # Check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3],
                           [-0.25473321, -0.38992332, -0.60695344])


if __name__ == '__main__':
    tf.test.main()
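As a side note, the expected values in these tests can be cross-checked against scikit-learn's `cohen_kappa_score`, which supports the same linear/quadratic weighting. This is an optional sanity check assuming scikit-learn is installed; the commit itself does not depend on it.

```python
from sklearn.metrics import cohen_kappa_score

actuals = [4, 4, 3, 4, 2, 4, 1, 1]
preds = [4, 4, 3, 4, 4, 2, 1, 1]

# Unweighted, linear, and quadratic kappa for test_kappa_random_score.
print(cohen_kappa_score(actuals, preds))                       # ~0.61905
print(cohen_kappa_score(actuals, preds, weights='linear'))     # ~0.62791
print(cohen_kappa_score(actuals, preds, weights='quadratic'))  # ~0.68932
```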
