2323
2424def  test_config ():
2525    # mcc object 
26-     mcc1  =  MatthewsCorrelationCoefficient (num_classes = 1 )
27-     assert  mcc1 .num_classes  ==  1 
26+     mcc1  =  MatthewsCorrelationCoefficient (num_classes = 2 )
27+     assert  mcc1 .num_classes  ==  2 
2828    assert  mcc1 .dtype  ==  tf .float32 
2929    # check configure 
3030    mcc2  =  MatthewsCorrelationCoefficient .from_config (mcc1 .get_config ())
31-     assert  mcc2 .num_classes  ==  1 
31+     assert  mcc2 .num_classes  ==  2 
3232    assert  mcc2 .dtype  ==  tf .float32 
3333
3434
3535def  check_results (obj , value ):
3636    np .testing .assert_allclose (value , obj .result ().numpy (), atol = 1e-6 )
3737
3838
39+ def  test_binary_classes_sparse ():
40+     gt_label  =  tf .constant ([[1.0 ], [1.0 ], [1.0 ], [0.0 ]], dtype = tf .float32 )
41+     preds  =  tf .constant ([[1.0 ], [0.0 ], [1.0 ], [1.0 ]], dtype = tf .float32 )
42+     # Initialize 
43+     mcc  =  MatthewsCorrelationCoefficient (1 )
44+     # Update 
45+     mcc .update_state (gt_label , preds )
46+     # Check results 
47+     check_results (mcc , [- 0.33333334 ])
48+ 
49+ 
3950def  test_binary_classes ():
4051    gt_label  =  tf .constant (
4152        [[0.0 , 1.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = tf .float32 
@@ -91,6 +102,16 @@ def test_multiple_classes():
91102    sklearn_result  =  sklearn_matthew (gt_label .argmax (axis = 1 ), preds .argmax (axis = 1 ))
92103    check_results (mcc , sklearn_result )
93104
105+     gt_label_sparse  =  tf .constant (
106+         [[0.0 ], [2.0 ], [0.0 ], [2.0 ], [1.0 ], [1.0 ], [0.0 ], [0.0 ], [2.0 ], [1.0 ]]
107+     )
108+     preds_sparse  =  tf .constant (
109+         [[2.0 ], [0.0 ], [2.0 ], [2.0 ], [2.0 ], [2.0 ], [2.0 ], [0.0 ], [2.0 ], [2.0 ]]
110+     )
111+     mcc  =  MatthewsCorrelationCoefficient (3 )
112+     mcc .update_state (gt_label_sparse , preds_sparse )
113+     check_results (mcc , sklearn_result )
114+ 
94115
95116# Keras model API check 
96117def  test_keras_model ():
@@ -110,13 +131,9 @@ def test_keras_model():
110131
111132
112133def  test_reset_states_graph ():
113-     gt_label  =  tf .constant (
114-         [[0.0 , 1.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = tf .float32 
115-     )
116-     preds  =  tf .constant (
117-         [[0.0 , 1.0 ], [1.0 , 0.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ]], dtype = tf .float32 
118-     )
119-     mcc  =  MatthewsCorrelationCoefficient (2 )
134+     gt_label  =  tf .constant ([[1.0 ], [1.0 ], [1.0 ], [0.0 ]], dtype = tf .float32 )
135+     preds  =  tf .constant ([[1.0 ], [0.0 ], [1.0 ], [1.0 ]], dtype = tf .float32 )
136+     mcc  =  MatthewsCorrelationCoefficient (1 )
120137    mcc .update_state (gt_label , preds )
121138
122139    @tf .function  
0 commit comments