@@ -288,13 +288,39 @@ def _ExtractOutput( # pylint: disable=invalid-name
288288 main = _ExtractOutputDoFn .OUTPUT_TAG_METRICS )
289289
290290
def PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
                     shared_handle, desired_batch_size):
  """Builds the Extractor for the 'Predict' stage.

  The wrapped transform loads and runs the eval_saved_model against every
  example, yielding a types.ExampleAndExtracts containing a
  FeaturesPredictionsLabels value (where key is 'fpl').

  Args:
    eval_saved_model_path: Path to EvalSavedModel, forwarded unchanged to
      predict_extractor.TFMAPredict.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph, forwarded unchanged to TFMAPredict.
    shared_handle: Handle forwarded unchanged to TFMAPredict (callers in this
      file pass a shared.Shared() instance).
    desired_batch_size: Optional batch size, forwarded unchanged to
      TFMAPredict.

  Returns:
    A types.Extractor whose stage_name is 'Predict' and whose ptransform is
    the configured predict_extractor.TFMAPredict.
  """
  predict_ptransform = predict_extractor.TFMAPredict(
      eval_saved_model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      shared_handle=shared_handle,
      desired_batch_size=desired_batch_size)
  return types.Extractor(stage_name='Predict', ptransform=predict_ptransform)
303+
304+
@beam.ptransform_fn
def Extract(examples, extractors):
  """Performs Extractions serially in provided order.

  Args:
    examples: Input collection that the first extractor is applied to.
    extractors: Iterable of extractors; each must expose a `stage_name` (used
      as the step label) and a `ptransform` (the transform to apply).

  Returns:
    The result of piping `examples` through every extractor's ptransform, in
    order. With no extractors, `examples` is returned unchanged.
  """
  pcoll = examples
  for stage in extractors:
    # Label the transform with the extractor's stage name, then chain it onto
    # the output of the previous stage.
    labelled_transform = stage.stage_name >> stage.ptransform
    pcoll = pcoll | labelled_transform
  return pcoll
314+
315+
291316@beam .ptransform_fn
292317# No typehint for output type, since it's a multi-output DoFn result that
293318# Beam doesn't support typehints for yet (BEAM-3280).
294319def Evaluate (
295320 # pylint: disable=invalid-name
296321 examples ,
297322 eval_saved_model_path ,
323+ extractors = None ,
298324 add_metrics_callbacks = None ,
299325 slice_spec = None ,
300326 desired_batch_size = None ,
@@ -309,6 +335,8 @@ def Evaluate(
309335 (e.g. string containing CSV row, TensorFlow.Example, etc).
310336 eval_saved_model_path: Path to EvalSavedModel. This directory should contain
311337 the saved_model.pb file.
338+ extractors: Optional list of Extractors to execute prior to slicing and
339+ aggregating the metrics. If not provided, a default set will be run.
312340 add_metrics_callbacks: Optional list of callbacks for adding additional
313341 metrics to the graph. The names of the metrics added by the callbacks
314342 should not conflict with existing metrics, or metrics added by other
@@ -349,24 +377,22 @@ def add_metrics_callback(features_dict, predictions_dict, labels):
349377
350378 shared_handle = shared .Shared ()
351379
380+ if not extractors :
381+ extractors = [
382+ PredictExtractor (eval_saved_model_path , add_metrics_callbacks ,
383+ shared_handle , desired_batch_size ),
384+ ]
385+
352386 # pylint: disable=no-value-for-parameter
353387 return (
354388 examples
355389 # Our diagnostic outputs, pass types.ExampleAndExtracts throughout,
356390 # however our aggregating functions do not use this interface.
357391 | 'ToExampleAndExtracts' >>
358392 beam .Map (lambda x : types .ExampleAndExtracts (example = x , extracts = {}))
393+ | Extract (extractors = extractors )
359394
360- # Map function which loads and runs the eval_saved_model against every
361- # example, yielding an types.ExampleAndExtracts containing a
362- # FeaturesPredictionsLabels value (where key is 'fpl').
363- | 'Predict' >> predict_extractor .TFMAPredict (
364- eval_saved_model_path = eval_saved_model_path ,
365- add_metrics_callbacks = add_metrics_callbacks ,
366- shared_handle = shared_handle ,
367- desired_batch_size = desired_batch_size )
368-
369- # Input: one example fpl at a time
395+ # Input: one example at a time
370396 # Output: one fpl example per slice key (notice that the example turns
371397 # into n, replicated once per applicable slice key)
372398 | 'Slice' >> slice_api .Slice (slice_spec )
@@ -395,6 +421,7 @@ def BuildDiagnosticTable(
395421 # pylint: disable=invalid-name
396422 examples ,
397423 eval_saved_model_path ,
424+ extractors = None ,
398425 desired_batch_size = None ):
399426 """Build diagnostics for the specified EvalSavedModel and example collection.
400427
@@ -403,18 +430,24 @@ def BuildDiagnosticTable(
403430 (e.g. string containing CSV row, TensorFlow.Example, etc).
404431 eval_saved_model_path: Path to EvalSavedModel. This directory should contain
405432 the saved_model.pb file.
433+ extractors: Optional list of Extractors to execute prior to slicing and
434+ aggregating the metrics. If not provided, a default set will be run.
406435 desired_batch_size: Optional batch size for batching in Predict and
407436 Aggregate.
408437
409438 Returns:
410439 PCollection of ExampleAndExtracts
411440 """
441+
442+ if not extractors :
443+ extractors = [
444+ PredictExtractor (eval_saved_model_path , None , shared .Shared (),
445+ desired_batch_size ),
446+ types .Extractor (
447+ stage_name = 'ExtractFeatures' ,
448+ ptransform = feature_extractor .ExtractFeatures ()),
449+ ]
412450 return (examples
413451 | 'ToExampleAndExtracts' >>
414452 beam .Map (lambda x : types .ExampleAndExtracts (example = x , extracts = {}))
415- | 'Predict' >> predict_extractor .TFMAPredict (
416- eval_saved_model_path ,
417- add_metrics_callbacks = None ,
418- shared_handle = shared .Shared (),
419- desired_batch_size = desired_batch_size )
420- | 'ExtractFeatures' >> feature_extractor .ExtractFeatures ())
453+ | Extract (extractors = extractors ))
0 commit comments