|
13 | 13 | # limitations under the License. |
14 | 14 | """Tests for tflite predict extractor.""" |
15 | 15 |
|
| 16 | +import itertools |
16 | 17 | import os |
17 | 18 | import tempfile |
18 | 19 |
|
|
31 | 32 | from google.protobuf import text_format |
32 | 33 | from tensorflow_metadata.proto.v0 import schema_pb2 |
33 | 34 |
|
| 35 | +_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) |
| 36 | + |
| 37 | +_MULTI_MODEL_CASES = [False, True] |
| 38 | +_MULTI_OUTPUT_CASES = [False, True] |
| 39 | +# Equality op not supported in TF1. See b/242088810 |
| 40 | +_BYTES_FEATURE_CASES = [False] if _TF_MAJOR_VERSION < 2 else [False, True] |
| 41 | + |
34 | 42 |
|
35 | 43 | class TFLitePredictExtractorTest(testutil.TensorflowModelAnalysisTest, |
36 | 44 | parameterized.TestCase): |
37 | 45 |
|
38 | | - @parameterized.named_parameters(('single_model_single_output', False, False), |
39 | | - ('single_model_multi_output', False, True), |
40 | | - ('multi_model_single_output', True, False), |
41 | | - ('multi_model_multi_output', True, True)) |
42 | | - def testTFlitePredictExtractorWithKerasModel(self, multi_model, multi_output): |
| 46 | + @parameterized.parameters( |
| 47 | + itertools.product(_MULTI_MODEL_CASES, _MULTI_OUTPUT_CASES, |
| 48 | + _BYTES_FEATURE_CASES)) |
| 49 | + def testTFlitePredictExtractorWithKerasModel(self, multi_model, multi_output, |
| 50 | + use_bytes_feature): |
43 | 51 | input1 = tf.keras.layers.Input(shape=(1,), name='input1') |
44 | 52 | input2 = tf.keras.layers.Input(shape=(1,), name='input2') |
45 | 53 | input3 = tf.keras.layers.Input(shape=(1,), name='input3', dtype=tf.string) |
46 | 54 | inputs = [input1, input2, input3] |
47 | | - input_layer = tf.keras.layers.concatenate( |
48 | | - [inputs[0], inputs[1], |
49 | | - tf.cast(inputs[2] == 'a', tf.float32)]) |
| 55 | + if use_bytes_feature: |
| 56 | + input_layer = tf.keras.layers.concatenate( |
| 57 | + [inputs[0], inputs[1], |
| 58 | + tf.cast(inputs[2] == 'a', tf.float32)]) |
| 59 | + else: |
| 60 | + input_layer = tf.keras.layers.concatenate([inputs[0], inputs[1]]) |
50 | 61 | output_layers = {} |
51 | 62 | output_layers['output1'] = ( |
52 | 63 | tf.keras.layers.Dense(1, activation=tf.nn.sigmoid, |
|
0 commit comments