
Commit 3c67c1d

Committed Nov 9, 2018
Running through pyformat to meet Google code standards
1 parent aefad12 · commit 3c67c1d
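The commit records only the message above; the exact pyformat invocation is not part of the diff. As a rough reproduction sketch, the same kind of Google-style re-wrapping can be applied with the open-source yapf formatter (a stand-in here, not the tool named in the commit):

# Rough stand-in for the reformatting step described in the commit message.
# "pyformat" itself is not shown in the diff, so this sketch uses the
# open-source yapf formatter with its built-in "google" style instead.
from yapf.yapflib.yapf_api import FormatCode

for path in ("extract_features.py", "run_classifier.py", "tokenization.py"):
  with open(path) as f:
    source = f.read()
  # FormatCode returns (reformatted_source, changed_flag).
  formatted, changed = FormatCode(source, style_config="google")
  if changed:
    with open(path, "w") as f:
      f.write(formatted)

Every hunk below is whitespace- or line-wrapping-only; none of them change behavior.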

File tree (3 files changed: +20 -17 lines):

  extract_features.py
  run_classifier.py
  tokenization.py
extract_features.py  (+3 -2)
@@ -170,8 +170,9 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
 
     tvars = tf.trainable_variables()
     scaffold_fn = None
-    (assignment_map, initialized_variable_names
-    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
+    (assignment_map,
+     initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
+         tvars, init_checkpoint)
     if use_tpu:
 
       def tpu_scaffold():
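For context, the re-wrapped tuple assignment sits inside BERT's model_fn checkpoint-initialization block. A minimal sketch of that surrounding pattern, assuming it matches the upstream extract_features.py (the free names tvars, init_checkpoint, use_tpu and the modeling module come from the enclosing model_fn and are not defined here):

# Checkpoint warm-start pattern around the reformatted lines (TF 1.x).
# This is context for the hunk above, not new behavior in this commit.
tvars = tf.trainable_variables()
scaffold_fn = None
(assignment_map,
 initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
     tvars, init_checkpoint)
if use_tpu:

  def tpu_scaffold():
    # On TPU, the restore runs inside a Scaffold built on the worker.
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    return tf.train.Scaffold()

  scaffold_fn = tpu_scaffold
else:
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)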

run_classifier.py  (+16 -14)
@@ -71,7 +71,9 @@
 
 flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
 
-flags.DEFINE_bool("do_predict", False, "Whether to run the model in inference mode on the test set.")
+flags.DEFINE_bool(
+    "do_predict", False,
+    "Whether to run the model in inference mode on the test set.")
 
 flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
 
@@ -248,8 +250,7 @@ def get_dev_examples(self, data_dir):
   def get_test_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
-        "test")
+        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -289,7 +290,7 @@ def get_dev_examples(self, data_dir):
   def get_test_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-      self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -329,7 +330,7 @@ def get_dev_examples(self, data_dir):
   def get_test_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-      self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -659,9 +660,7 @@ def metric_fn(per_example_loss, label_ids, logits):
           scaffold_fn=scaffold_fn)
     else:
       output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode,
-          predictions=probabilities,
-          scaffold_fn=scaffold_fn)
+          mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)
     return output_spec
 
   return model_fn
@@ -874,7 +873,8 @@ def main(_):
     predict_examples = processor.get_test_examples(FLAGS.data_dir)
     predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
     file_based_convert_examples_to_features(predict_examples, label_list,
-                                            FLAGS.max_seq_length, tokenizer, predict_file)
+                                            FLAGS.max_seq_length, tokenizer,
+                                            predict_file)
 
     tf.logging.info("***** Running prediction*****")
     tf.logging.info("  Num examples = %d", len(predict_examples))
@@ -887,20 +887,22 @@ def main(_):
 
     predict_drop_remainder = True if FLAGS.use_tpu else False
     predict_input_fn = file_based_input_fn_builder(
-      input_file=predict_file,
-      seq_length=FLAGS.max_seq_length,
-      is_training=False,
-      drop_remainder=predict_drop_remainder)
+        input_file=predict_file,
+        seq_length=FLAGS.max_seq_length,
+        is_training=False,
+        drop_remainder=predict_drop_remainder)
 
     result = estimator.predict(input_fn=predict_input_fn)
 
     output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
     with tf.gfile.GFile(output_predict_file, "w") as writer:
       tf.logging.info("***** Predict results *****")
       for prediction in result:
-        output_line = "\t".join(str(class_probability) for class_probability in prediction) + "\n"
+        output_line = "\t".join(
+            str(class_probability) for class_probability in prediction) + "\n"
         writer.write(output_line)
 
+
 if __name__ == "__main__":
   flags.mark_flag_as_required("data_dir")
   flags.mark_flag_as_required("task_name")
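The last two hunks above only re-wrap the prediction path, but they show its output format: run_classifier.py writes one line of tab-separated class probabilities per test example to test_results.tsv. A small, purely illustrative reader for that file (the output path and the argmax step are assumptions about how the results get consumed, not part of this commit):

# Toy consumer of the test_results.tsv produced by the prediction loop
# above: one tab-separated row of class probabilities per test example.
# The path below is an assumption about where --output_dir pointed.
import csv

with open("output/test_results.tsv") as f:
  for row in csv.reader(f, delimiter="\t"):
    probabilities = [float(p) for p in row]
    predicted_class = probabilities.index(max(probabilities))
    print(predicted_class, probabilities)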

tokenization.py  (+1 -1)
@@ -86,7 +86,7 @@ def convert_by_vocab(vocab, items):
   """Converts a sequence of [tokens|ids] using the vocab."""
   output = []
   for item in items:
-     output.append(vocab[item])
+    output.append(vocab[item])
   return output
 
 
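The tokenization.py hunk is an indentation-only fix inside convert_by_vocab, which simply maps each token or id through the vocab dict in order. A toy illustration of that behavior (the vocabulary below is made up, not BERT's real vocab.txt):

# Made-up vocabulary purely to show what convert_by_vocab does.
vocab = {"[CLS]": 0, "hello": 1, "world": 2, "[SEP]": 3}
tokens = ["[CLS]", "hello", "world", "[SEP]"]

output = []
for item in tokens:
  output.append(vocab[item])  # same loop as the fixed line above

print(output)  # [0, 1, 2, 3]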
0 commit comments

Comments
 (0)
Please sign in to comment.