Open
Labels: enhancement, good first issue, help wanted
Description
We would like to forward a 'key' column, which is part of the features, so that it appears alongside the predictions. This would let us identify which set of features a given prediction belongs to. Here is an example of the prediction output when using tensorflow.contrib.estimator.multi_class_head:
{"classes": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"scores": [0.06819603592157364, 0.0864366963505745, 0.12838752567768097, 0.046013250946998596, 0.03129083290696144, 0.1518409103155136, 0.1248951405286789, 0.15043732523918152, 0.0821763351559639, 0.13032598793506622]}
We would therefore like to add a key attribute to this prediction.
estimator = tf.contrib.estimator.forward_features(estimator, ['key'])
gives the following error:
The adanet.Estimator's model_fn should not be called directly in TRAIN mode, because its behavior is undefined outside the context of its train method.
The current workaround is to subclass the head.
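To illustrate the shape of that workaround, here is a minimal, framework-free sketch. It does not use TensorFlow or AdaNet: `StandInHead`, `KeyForwardingHead`, and the `predictions` method are hypothetical stand-ins invented for this example, and the real head classes have a different interface. The idea shown is only that a subclass can copy the 'key' feature into the predictions dictionary produced by its parent.

```python
from typing import Any, Dict, List


class StandInHead:
    """Hypothetical stand-in for a multi-class head (NOT the TensorFlow class)."""

    def __init__(self, n_classes: int) -> None:
        self.n_classes = n_classes

    def predictions(self, features: Dict[str, Any], scores: List[float]) -> Dict[str, Any]:
        # Mirrors the output shape quoted in this issue: class labels plus scores.
        return {
            "classes": [str(i) for i in range(self.n_classes)],
            "scores": list(scores),
        }


class KeyForwardingHead(StandInHead):
    """Subclass that copies one feature column into the predictions."""

    def __init__(self, n_classes: int, key_name: str = "key") -> None:
        super().__init__(n_classes)
        self.key_name = key_name

    def predictions(self, features: Dict[str, Any], scores: List[float]) -> Dict[str, Any]:
        if self.key_name not in features:
            raise KeyError(f"feature {self.key_name!r} is missing, cannot forward it")
        output = super().predictions(features, scores)
        # Attach the key so each prediction can be traced back to its features.
        output[self.key_name] = features[self.key_name]
        return output


# Usage with made-up values: the key travels with the prediction.
head = KeyForwardingHead(n_classes=3)
result = head.predictions({"key": "example-42", "pixels": [0.1, 0.2]}, [0.2, 0.5, 0.3])
```

In a real setup the same override would presumably live in whatever method of the head builds the predictions, which is an assumption about the workaround rather than something stated in this issue.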