Skip to content

Commit 8cad46b

Browse files
davidstutzcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 466673111 Change-Id: If18d6db1bb8065c92c2cb81636b3629df623ad0b
1 parent a098422 commit 8cad46b

30 files changed

+370
-70
lines changed

colab_utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616
"""Utils for evaluation in Colabs or notebooks."""
1717
from typing import Tuple, Callable, Dict, Any, List
1818

19+
from absl import logging
1920
import jax
2021
import jax.numpy as jnp
2122
import numpy as np
2223
import pandas as pd
2324
import sklearn.metrics
2425

25-
import conformal_training.conformal_prediction as cp
26-
import conformal_training.evaluation as cpeval
27-
import conformal_training.open_source_utils as cpstaging
26+
import conformal_prediction as cp
27+
import evaluation as cpeval
28+
import open_source_utils as cpstaging
2829

2930

3031
_CalibrateFn = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], float]
@@ -504,11 +505,10 @@ def evaluate_conformal_prediction(
504505
test_results_t = pd.concat([tau_t] + test_results_t, axis=1)
505506

506507
test_results = pd.concat((test_results, test_results_t), axis=0)
507-
print(f'\t trial {t}: {tau}', flush=True)
508+
logging.info('Trial %d: %f', t, tau)
508509

509510
results = {
510511
'mean': {'val': val_results.mean(0), 'test': test_results.mean(0)},
511512
'std': {'val': val_results.std(0), 'test': test_results.std(0)},
512513
}
513-
print('\t reduced', flush=True)
514514
return results

colab_utils_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import ml_collections as collections
2323
import numpy as np
2424

25-
import conformal_training.colab_utils as cpcolab
26-
import conformal_training.data_utils as cpdatautils
27-
import conformal_training.test_utils as cptutils
25+
import colab_utils as cpcolab
26+
import data_utils as cpdatautils
27+
import test_utils as cptutils
2828

2929

3030
class ColabUtilsTest(parameterized.TestCase):

conformal_prediction_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import jax.numpy as jnp
2424
import numpy as np
2525

26-
import conformal_training.conformal_prediction as cp
27-
import conformal_training.test_utils as cptutils
26+
import conformal_prediction as cp
27+
import test_utils as cptutils
2828

2929

3030
class ConformalPredictionTest(parameterized.TestCase):

data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import tensorflow as tf
2121
import tensorflow_datasets as tfds
2222

23-
import conformal_training.auto_augment as augment
23+
import auto_augment as augment
2424

2525

2626
def load_data_split(

data_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
import tensorflow_datasets as tfds
2424

25-
import conformal_training.data as cpdata
25+
import data as cpdata
2626
DATA_DIR = './data/'
2727

2828

data_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import ml_collections as collections
2323
import tensorflow as tf
2424

25-
import conformal_training.data as cpdata
25+
import data as cpdata
2626

2727

2828
def apply_cifar_augmentation(

data_utils_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import chex
2121
import ml_collections as collections
2222

23-
import conformal_training.data as cpdata
24-
import conformal_training.data_utils as cpdatautils
23+
import data as cpdata
24+
import data_utils as cpdatautils
2525
DATA_DIR = './data/'
2626

2727

environment.yml

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
name: conformal_training
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- _libgcc_mutex=0.1=conda_forge
7+
- _openmp_mutex=4.5=2_gnu
8+
- _tflow_select=2.3.0=mkl
9+
- abseil-cpp=20211102.0=h27087fc_1
10+
- absl-py=0.15.0=pyhd3eb1b0_0
11+
- aiohttp=3.8.1=py39hb9d737c_1
12+
- aiosignal=1.2.0=pyhd8ed1ab_0
13+
- astor=0.8.1=pyh9f0ad1d_0
14+
- astunparse=1.6.3=pyhd8ed1ab_0
15+
- async-timeout=4.0.2=pyhd8ed1ab_0
16+
- attrs=21.4.0=pyhd8ed1ab_0
17+
- blas=1.0=openblas
18+
- blinker=1.4=py_1
19+
- bottleneck=1.3.5=py39h7deecbd_0
20+
- brotlipy=0.7.0=py39hb9d737c_1004
21+
- bzip2=1.0.8=h7f98852_4
22+
- c-ares=1.18.1=h7f98852_0
23+
- ca-certificates=2022.6.15=ha878542_0
24+
- cachetools=4.2.4=pyhd8ed1ab_0
25+
- certifi=2022.6.15=py39hf3d152e_0
26+
- cffi=1.15.1=py39he91dace_0
27+
- charset-normalizer=2.1.0=pyhd8ed1ab_0
28+
- click=8.1.3=py39hf3d152e_0
29+
- cryptography=37.0.1=py39h9ce1e76_0
30+
- dataclasses=0.8=pyhc8e2a94_3
31+
- dm-haiku=0.0.7=pyhd8ed1ab_0
32+
- etils=0.6.0=pyhd8ed1ab_0
33+
- frozenlist=1.3.0=py39hb9d737c_1
34+
- gast=0.4.0=pyh9f0ad1d_0
35+
- google-auth=1.35.0=pyh6c4a22f_0
36+
- google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
37+
- google-pasta=0.2.0=pyh8c360ce_0
38+
- grpc-cpp=1.46.3=h00ec82a_2
39+
- grpcio=1.46.3=py39h2edfe15_2
40+
- h5py=2.10.0=nompi_py39h98ba4bc_106
41+
- hdf5=1.10.6=h3ffc7dd_1
42+
- idna=3.3=pyhd8ed1ab_0
43+
- importlib-metadata=4.11.4=py39hf3d152e_0
44+
- importlib_resources=5.8.0=pyhd8ed1ab_0
45+
- jax=0.3.14=pyhd8ed1ab_1
46+
- jaxlib=0.3.14=cpu_py39h79d7c74_0
47+
- jmp=0.0.2=pyhd8ed1ab_0
48+
- joblib=1.1.0=pyhd3eb1b0_0
49+
- keras-preprocessing=1.1.2=pyhd8ed1ab_0
50+
- ld_impl_linux-64=2.36.1=hea4e1c9_2
51+
- libblas=3.9.0=15_linux64_openblas
52+
- libcblas=3.9.0=15_linux64_openblas
53+
- libffi=3.4.2=h7f98852_5
54+
- libgcc-ng=12.1.0=h8d9b700_16
55+
- libgfortran-ng=12.1.0=h69a702a_16
56+
- libgfortran5=12.1.0=hdcd56e2_16
57+
- libgomp=12.1.0=h8d9b700_16
58+
- liblapack=3.9.0=15_linux64_openblas
59+
- libnsl=2.0.0=h7f98852_0
60+
- libopenblas=0.3.20=pthreads_h78a6416_0
61+
- libprotobuf=3.20.1=h6239696_0
62+
- libstdcxx-ng=12.1.0=ha89aaad_16
63+
- libuuid=2.32.1=h7f98852_1000
64+
- libzlib=1.2.12=h166bdaf_1
65+
- markdown=3.3.7=pyhd8ed1ab_0
66+
- multidict=6.0.2=py39hb9d737c_1
67+
- ncurses=6.3=h27087fc_1
68+
- numexpr=2.8.3=py39hd2a5715_0
69+
- numpy=1.19.5=py39hd249d9e_3
70+
- oauthlib=3.2.0=pyhd8ed1ab_0
71+
- openssl=3.0.5=h166bdaf_0
72+
- opt_einsum=3.3.0=pyhd8ed1ab_1
73+
- packaging=21.3=pyhd3eb1b0_0
74+
- pandas=1.4.2=py39h295c915_0
75+
- pip=22.1.2=pyhd8ed1ab_0
76+
- protobuf=3.20.1=py39h5a03fae_0
77+
- pyasn1=0.4.8=py_0
78+
- pyasn1-modules=0.2.7=py_0
79+
- pycparser=2.21=pyhd8ed1ab_0
80+
- pyjwt=2.4.0=pyhd8ed1ab_0
81+
- pyopenssl=22.0.0=pyhd8ed1ab_0
82+
- pysocks=1.7.1=py39hf3d152e_5
83+
- python=3.9.13=h2660328_0_cpython
84+
- python-dateutil=2.8.2=pyhd3eb1b0_0
85+
- python-flatbuffers=2.0=pyhd8ed1ab_0
86+
- python_abi=3.9=2_cp39
87+
- pytz=2022.1=py39h06a4308_0
88+
- pyu2f=0.1.5=pyhd8ed1ab_0
89+
- re2=2022.06.01=h27087fc_0
90+
- readline=8.1.2=h0f457ee_0
91+
- requests=2.28.1=pyhd8ed1ab_0
92+
- requests-oauthlib=1.3.1=pyhd8ed1ab_0
93+
- rsa=4.8=pyhd8ed1ab_0
94+
- scikit-learn=1.0.2=py39h51133e4_1
95+
- scipy=1.8.1=py39he49c0e8_0
96+
- setuptools=63.1.0=py39hf3d152e_0
97+
- six=1.16.0=pyh6c4a22f_0
98+
- sqlite=3.39.0=h4ff8645_0
99+
- tabulate=0.8.10=pyhd8ed1ab_0
100+
- tensorboard=2.4.1=pyhd8ed1ab_1
101+
- tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
102+
- tensorflow=2.4.1=mkl_py39h4683426_0
103+
- tensorflow-base=2.4.1=mkl_py39h43e0292_0
104+
- tensorflow-estimator=2.6.0=py39he80948d_0
105+
- termcolor=1.1.0=pyhd8ed1ab_3
106+
- threadpoolctl=2.2.0=pyh0d69192_0
107+
- tk=8.6.12=h27826a3_0
108+
- typing-extensions=4.3.0=hd8ed1ab_0
109+
- typing_extensions=4.3.0=pyha770c72_0
110+
- tzdata=2022a=h191b570_0
111+
- urllib3=1.26.9=pyhd8ed1ab_0
112+
- werkzeug=2.1.2=pyhd8ed1ab_1
113+
- wheel=0.37.1=pyhd8ed1ab_0
114+
- wrapt=1.14.1=py39hb9d737c_0
115+
- xz=5.2.5=h516909a_1
116+
- yarl=1.7.2=py39hb9d737c_2
117+
- zipp=3.8.0=pyhd8ed1ab_0
118+
- zlib=1.2.12=h166bdaf_1
119+
- pip:
120+
- chex==0.1.3
121+
- contextlib2==21.6.0
122+
- dill==0.3.5.1
123+
- dm-tree==0.1.7
124+
- googleapis-common-protos==1.56.3
125+
- install==1.3.5
126+
- ml-collections==0.1.1
127+
- optax==0.1.2
128+
- promise==2.3
129+
- pyparsing==3.0.9
130+
- pyyaml==6.0
131+
- tensorflow-addons==0.17.1
132+
- tensorflow-datasets==4.6.0
133+
- tensorflow-metadata==1.9.0
134+
- toml==0.10.2
135+
- toolz==0.11.2
136+
- tqdm==4.64.0
137+
- typeguard==2.13.3

eval.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2022 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Evaluate experiment."""
17+
import os
18+
import sys
19+
20+
from absl import flags
21+
from absl import logging
22+
import jax
23+
24+
from absl import app
25+
import colab_utils as cbutils
26+
27+
FLAGS = flags.FLAGS
28+
flags.DEFINE_string('experiment_path', './', 'base path for experiments')
29+
flags.DEFINE_string('experiment_dataset', '', 'dataset to evaluate')
30+
flags.DEFINE_string(
31+
'experiment_method', 'thr', 'conformal predictor to use, thr or apr')
32+
flags.DEFINE_boolean('experiment_logfile', False,
33+
'log results to file in experiment_path')
34+
35+
36+
def main(argv):
37+
del argv
38+
39+
if FLAGS.experiment_logfile:
40+
logging.get_absl_handler().use_absl_log_file(
41+
f'eval_{FLAGS.experiment_method}', FLAGS.experiment_path)
42+
else:
43+
logging.get_absl_handler().python_handler.stream = sys.stdout
44+
45+
if not os.path.exists(FLAGS.experiment_path):
46+
logging.error('could not find experiment path %s', FLAGS.experiment_path)
47+
return
48+
49+
alpha = 0.01
50+
if FLAGS.experiment_method == 'thr':
51+
calibrate_fn, predict_fn = cbutils.get_threshold_fns(alpha)
52+
elif FLAGS.experiment_method == 'aps':
53+
calibrate_fn, predict_fn = cbutils.get_raps_fns(alpha, 0, 0)
54+
else:
55+
raise ValueError('Invalid conformal predictor, choose thr or aps.')
56+
57+
if FLAGS.experiment_dataset == 'mnist':
58+
num_classes = 10
59+
groups = ['singleton', 'groups']
60+
elif FLAGS.experiment_dataset == 'emnist_byclass':
61+
num_classes = 52
62+
groups = ['groups']
63+
elif FLAGS.experiment_dataset == 'fashion_mnist':
64+
num_classes = 10
65+
groups = ['singleton']
66+
elif FLAGS.experiment_dataset == 'cifar10':
67+
num_classes = 10
68+
groups = ['singleton', 'groups']
69+
elif FLAGS.experiment_dataset == 'cifar100':
70+
num_classes = 100
71+
groups = ['groups', 'hierarchy']
72+
else:
73+
raise ValueError('Invalid dataset %s.' % FLAGS.experiment_dataset)
74+
75+
model = cbutils.load_predictions(FLAGS.experiment_path, val_examples=5000)
76+
77+
for group in groups:
78+
model['data']['groups'][group] = cbutils.get_groups(
79+
FLAGS.experiment_dataset, group)
80+
81+
results = cbutils.evaluate_conformal_prediction(
82+
model, calibrate_fn, predict_fn, trials=10, rng=jax.random.PRNGKey(0))
83+
84+
logging.info('Accuracy: %f', results['mean']['test']['accuracy'])
85+
logging.info('Coverage: %f', results['mean']['test']['coverage'])
86+
logging.info('Size: %f', results['mean']['test']['size'])
87+
88+
for k in range(num_classes):
89+
logging.info(
90+
'Class size %d: %f', k, results['mean']['test'][f'class_size_{k}'])
91+
92+
for group in groups:
93+
k = 0
94+
key = f'{group}_size_{k}'
95+
while key in results['mean']['test'].keys():
96+
logging.info(
97+
'Group %s size %d: %f', group, k, results['mean']['test'][key])
98+
k += 1
99+
key = f'{group}_size_{k}'
100+
101+
logging.info(
102+
'Group %s miscoverage 0: %f',
103+
group, results['mean']['test'][f'{group}_miscoverage_0'])
104+
logging.info(
105+
'Group %s miscoverage 1: %f',
106+
group, results['mean']['test'][f'{group}_miscoverage_1'])
107+
108+
# Selected coverage confusion combinations:
109+
logging.info(
110+
'Coverage confusion 4-6: %f',
111+
results['mean']['test']['coverage_confusion_4_6'])
112+
logging.info(
113+
'Coverage confusion 6-4: %f',
114+
results['mean']['test']['coverage_confusion_6_4'])
115+
116+
117+
if __name__ == '__main__':
118+
app.run(main)

evaluation_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import jax.numpy as jnp
2121
import numpy as np
2222

23-
import conformal_training.evaluation as cpeval
24-
import conformal_training.test_utils as cptutils
23+
import evaluation as cpeval
24+
import test_utils as cptutils
2525

2626

2727
class EvaluationTest(parameterized.TestCase):

experiments/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2022 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Experiments configuration."""

experiments/run_cifar10.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import ml_collections as collections
2020

21-
import conformal_training.experiments.experiment_utils as cpeutils
21+
import experiments.experiment_utils as cpeutils
2222

2323

2424
def get_parameters(
@@ -52,7 +52,7 @@ def get_parameters(
5252
else:
5353
config.epochs = 50
5454
config.finetune.enabled = True
55-
config.finetune.path = './cifar10_models_seed0/'
55+
config.finetune.path = 'cifar10_models_seed0/'
5656
config.finetune.model_state = False
5757
config.finetune.layers = 'res_net/~/logits'
5858
config.finetune.reinitialize = True

0 commit comments

Comments
 (0)