71
71
72
72
flags .DEFINE_bool ("do_eval" , False , "Whether to run eval on the dev set." )
73
73
74
- flags .DEFINE_bool ("do_predict" , False , "Whether to run the model in inference mode on the test set." )
74
+ flags .DEFINE_bool (
75
+ "do_predict" , False ,
76
+ "Whether to run the model in inference mode on the test set." )
75
77
76
78
flags .DEFINE_integer ("train_batch_size" , 32 , "Total batch size for training." )
77
79
@@ -248,8 +250,7 @@ def get_dev_examples(self, data_dir):
248
250
def get_test_examples (self , data_dir ):
249
251
"""See base class."""
250
252
return self ._create_examples (
251
- self ._read_tsv (os .path .join (data_dir , "test_matched.tsv" )),
252
- "test" )
253
+ self ._read_tsv (os .path .join (data_dir , "test_matched.tsv" )), "test" )
253
254
254
255
def get_labels (self ):
255
256
"""See base class."""
@@ -289,7 +290,7 @@ def get_dev_examples(self, data_dir):
289
290
def get_test_examples (self , data_dir ):
290
291
"""See base class."""
291
292
return self ._create_examples (
292
- self ._read_tsv (os .path .join (data_dir , "test.tsv" )), "test" )
293
+ self ._read_tsv (os .path .join (data_dir , "test.tsv" )), "test" )
293
294
294
295
def get_labels (self ):
295
296
"""See base class."""
@@ -329,7 +330,7 @@ def get_dev_examples(self, data_dir):
329
330
def get_test_examples (self , data_dir ):
330
331
"""See base class."""
331
332
return self ._create_examples (
332
- self ._read_tsv (os .path .join (data_dir , "test.tsv" )), "test" )
333
+ self ._read_tsv (os .path .join (data_dir , "test.tsv" )), "test" )
333
334
334
335
def get_labels (self ):
335
336
"""See base class."""
@@ -659,9 +660,7 @@ def metric_fn(per_example_loss, label_ids, logits):
659
660
scaffold_fn = scaffold_fn )
660
661
else :
661
662
output_spec = tf .contrib .tpu .TPUEstimatorSpec (
662
- mode = mode ,
663
- predictions = probabilities ,
664
- scaffold_fn = scaffold_fn )
663
+ mode = mode , predictions = probabilities , scaffold_fn = scaffold_fn )
665
664
return output_spec
666
665
667
666
return model_fn
@@ -874,7 +873,8 @@ def main(_):
874
873
predict_examples = processor .get_test_examples (FLAGS .data_dir )
875
874
predict_file = os .path .join (FLAGS .output_dir , "predict.tf_record" )
876
875
file_based_convert_examples_to_features (predict_examples , label_list ,
877
- FLAGS .max_seq_length , tokenizer , predict_file )
876
+ FLAGS .max_seq_length , tokenizer ,
877
+ predict_file )
878
878
879
879
tf .logging .info ("***** Running prediction*****" )
880
880
tf .logging .info (" Num examples = %d" , len (predict_examples ))
@@ -887,20 +887,22 @@ def main(_):
887
887
888
888
predict_drop_remainder = True if FLAGS .use_tpu else False
889
889
predict_input_fn = file_based_input_fn_builder (
890
- input_file = predict_file ,
891
- seq_length = FLAGS .max_seq_length ,
892
- is_training = False ,
893
- drop_remainder = predict_drop_remainder )
890
+ input_file = predict_file ,
891
+ seq_length = FLAGS .max_seq_length ,
892
+ is_training = False ,
893
+ drop_remainder = predict_drop_remainder )
894
894
895
895
result = estimator .predict (input_fn = predict_input_fn )
896
896
897
897
output_predict_file = os .path .join (FLAGS .output_dir , "test_results.tsv" )
898
898
with tf .gfile .GFile (output_predict_file , "w" ) as writer :
899
899
tf .logging .info ("***** Predict results *****" )
900
900
for prediction in result :
901
- output_line = "\t " .join (str (class_probability ) for class_probability in prediction ) + "\n "
901
+ output_line = "\t " .join (
902
+ str (class_probability ) for class_probability in prediction ) + "\n "
902
903
writer .write (output_line )
903
904
905
+
904
906
if __name__ == "__main__" :
905
907
flags .mark_flag_as_required ("data_dir" )
906
908
flags .mark_flag_as_required ("task_name" )
0 commit comments