Commit 1ee7a10

* Include graph_runner.h and shape_refiner.h for TensorFlow
1 parent 906a8ae commit 1ee7a10

File tree:
CHANGELOG.md
tensorflow/src/main/java/org/bytedeco/javacpp/presets/tensorflow.java
tensorflow/src/main/java/org/bytedeco/javacpp/tensorflow.java

3 files changed: +198 -9 lines changed

CHANGELOG.md (+1 -1)
@@ -8,7 +8,7 @@
 * Bundle native resources (header files and import libraries) of MKL-DNN
 * Make MSBuild compile more efficiently on multiple processors ([pull #599](https://github.com/bytedeco/javacpp-presets/pull/599))
 * Add samples for Clang ([pull #598](https://github.com/bytedeco/javacpp-presets/pull/598))
-* Include `python_api.h` and enable Python API for TensorFlow ([issue #602](https://github.com/bytedeco/javacpp-presets/issues/602))
+* Include `graph_runner.h`, `shape_refiner.h`, `python_api.h`, and enable Python API for TensorFlow ([issue #602](https://github.com/bytedeco/javacpp-presets/issues/602))
 * Add presets for Spinnaker 1.15.x ([pull #553](https://github.com/bytedeco/javacpp-presets/pull/553)), CPython 3.6.x, ONNX 1.2.2 ([pull #547](https://github.com/bytedeco/javacpp-presets/pull/547))
 * Define `std::vector<tensorflow::OpDef>` type to `OpDefVector` for TensorFlow
 * Link HDF5 with zlib on Windows also ([issue deeplearning4j/deeplearning4j#6017](https://github.com/deeplearning4j/deeplearning4j/issues/6017))

tensorflow/src/main/java/org/bytedeco/javacpp/presets/tensorflow.java (+5 -1)
@@ -149,6 +149,8 @@
 "tensorflow/core/common_runtime/process_function_library_runtime.h",
 "tensorflow/core/graph/graph.h",
 "tensorflow/core/graph/tensor_id.h",
+"tensorflow/core/common_runtime/graph_runner.h",
+"tensorflow/core/common_runtime/shape_refiner.h",
 "tensorflow/core/framework/node_def_builder.h",
 "tensorflow/core/framework/node_def_util.h",
 "tensorflow/core/framework/selective_registration.h",
@@ -315,6 +317,8 @@
 "tensorflow/core/common_runtime/process_function_library_runtime.h",
 "tensorflow/core/graph/graph.h",
 "tensorflow/core/graph/tensor_id.h",
+"tensorflow/core/common_runtime/graph_runner.h",
+"tensorflow/core/common_runtime/shape_refiner.h",
 "tensorflow/core/framework/node_def_builder.h",
 "tensorflow/core/framework/node_def_util.h",
 "tensorflow/core/framework/selective_registration.h",
@@ -600,7 +604,7 @@ public void map(InfoMap infoMap) {
 .put(new Info("tensorflow::gtl::FlatMap<TF_Session*,tensorflow::string>").pointerTypes("TF_SessionStringMap").define())

 // Skip composite op scopes bc: call to implicitly-deleted default constructor of '::tensorflow::CompositeOpScopes'
-.put(new Info("tensorflow::CompositeOpScopes").skip())
+.put(new Info("tensorflow::CompositeOpScopes", "tensorflow::ExtendedInferenceContext").skip())

 // Fixed shape inference
 .put(new Info("std::vector<const tensorflow::Tensor*>").pointerTypes("ConstTensorPtrVector").define())

tensorflow/src/main/java/org/bytedeco/javacpp/tensorflow.java (+192 -7)
@@ -21817,13 +21817,6 @@ private native void allocate(@Const KernelDef kernel_def, @StringPiece String ke
 // #include "tensorflow/core/lib/core/status.h"
 // #include "tensorflow/core/lib/gtl/inlined_vector.h"
 // #include "tensorflow/core/platform/macros.h"
-
-@Namespace("tensorflow") @Opaque public static class ShapeRefiner extends Pointer {
-/** Empty constructor. Calls {@code super((Pointer)null)}. */
-public ShapeRefiner() { super((Pointer)null); }
-/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
-public ShapeRefiner(Pointer p) { super(p); }
-}
 @Namespace("tensorflow") @Opaque public static class ShapeRefinerTest extends Pointer {
 /** Empty constructor. Calls {@code super((Pointer)null)}. */
 public ShapeRefinerTest() { super((Pointer)null); }
@@ -30606,6 +30599,198 @@ public static class Hasher extends Pointer {
 // #endif // TENSORFLOW_GRAPH_TENSOR_ID_H_


+// Parsed from tensorflow/core/common_runtime/graph_runner.h
+
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_
+// #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_
+
+// #include <memory>
+// #include <string>
+// #include <vector>
+
+// #include "tensorflow/core/common_runtime/device.h"
+// #include "tensorflow/core/framework/function.h"
+// #include "tensorflow/core/framework/tensor.h"
+// #include "tensorflow/core/graph/graph.h"
+// #include "tensorflow/core/lib/core/status.h"
+// #include "tensorflow/core/platform/env.h"
+
+// GraphRunner takes a Graph, some inputs to feed, and some outputs
+// to fetch and executes the graph required to feed and fetch the
+// inputs and outputs.
+//
+// This class is only meant for internal use where one needs to
+// partially evaluate inexpensive nodes in a graph, such as for shape
+// inference or for constant folding. Because of its limited, simple
+// use-cases, it executes all computation on the given device (CPU by default)
+// and is not meant to be particularly lightweight, fast, or efficient.
+@Namespace("tensorflow") @NoOffset public static class GraphRunner extends Pointer {
+static { Loader.load(); }
+/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+public GraphRunner(Pointer p) { super(p); }
+
+// REQUIRES: `env` is not nullptr.
+public GraphRunner(Env env) { super((Pointer)null); allocate(env); }
+private native void allocate(Env env);
+// REQUIRES: 'device' is not nullptr. Not owned.
+public GraphRunner(Device device) { super((Pointer)null); allocate(device); }
+private native void allocate(Device device);
+
+// Function semantics for `inputs`, `output_names` and `outputs`
+// matches those from Session::Run().
+//
+// NOTE: The output tensors share lifetime with the GraphRunner, and could
+// be destroyed once the GraphRunner is destroyed.
+//
+// REQUIRES: `graph`, `env`, and `outputs` are not nullptr.
+// `function_library` may be nullptr.
+public native @ByVal Status Run(Graph graph, FunctionLibraryRuntime function_library,
+@Cast("const tensorflow::GraphRunner::NamedTensorList*") @ByRef StringTensorPairVector inputs,
+@Const @ByRef StringVector output_names,
+TensorVector outputs);
+}
+
+// namespace tensorflow
+
+// #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_
+
+
+// Parsed from tensorflow/core/common_runtime/shape_refiner.h
+
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
+// #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
+
+// #include <vector>
+
+// #include "tensorflow/core/common_runtime/graph_runner.h"
+// #include "tensorflow/core/framework/function.pb.h"
+// #include "tensorflow/core/framework/shape_inference.h"
+// #include "tensorflow/core/graph/graph.h"
+// #include "tensorflow/core/lib/core/status.h"
+// #include "tensorflow/core/platform/macros.h"
+
+
+// This class stores extra inference information in addition to
+// InferenceContext, such as inference tree for user-defined functions and node
+// input and output types.
+
+// ShapeRefiner performs shape inference for TensorFlow Graphs. It is
+// responsible for instantiating InferenceContext objects for each
+// Node in the Graph, and providing/storing the 'input_tensor' Tensors
+// used by Shape Inference functions, when available at graph
+// construction time.
+@Namespace("tensorflow") @NoOffset public static class ShapeRefiner extends Pointer {
+static { Loader.load(); }
+/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+public ShapeRefiner(Pointer p) { super(p); }
+
+public ShapeRefiner(int graph_def_version, @Const OpRegistryInterface ops) { super((Pointer)null); allocate(graph_def_version, ops); }
+private native void allocate(int graph_def_version, @Const OpRegistryInterface ops);
+
+// Same as ShapeRefiner(versions.producer(), ops)
+public ShapeRefiner(@Const @ByRef VersionDef versions, @Const OpRegistryInterface ops) { super((Pointer)null); allocate(versions, ops); }
+private native void allocate(@Const @ByRef VersionDef versions, @Const OpRegistryInterface ops);
+
+// Performs validation of 'node' and runs 'node's shape function,
+// storing its shape outputs.
+//
+// All inputs of 'node' must be added to ShapeRefiner prior to
+// adding 'node'.
+//
+// Returns an error if:
+// - the shape function for 'node' was not registered.
+// - 'node' was added before its inputs.
+// - The shape inference function returns an error.
+public native @ByVal Status AddNode(@Const Node node);
+
+// Sets 'node's 'output_port' output to have shape 'shape'.
+//
+// Returns an error if 'node' was not previously added to this
+// object, if 'output_port' is invalid, or if 'shape' is
+// not compatible with the existing shape of the output.
+public native @ByVal Status SetShape(@Const Node node, int output_port,
+@ByVal ShapeHandle shape);
+
+// Update the input shapes of node in case the shapes of the fan-ins of 'node'
+// have themselves been modified (For example, in case of incremental shape
+// refinement). If 'relax' is true, a new shape with the broadest set of
+// information will be set as the new input (see InferenceContext::RelaxInput
+// for full details and examples). Sets refined to true if any shapes have
+// changed (in their string representations). Note that shapes may have been
+// updated to newer versions (but with identical string representations) even
+// if <*refined> is set to false.
+public native @ByVal Status UpdateNode(@Const Node node, @Cast("bool") boolean relax, @Cast("bool*") BoolPointer refined);
+public native @ByVal Status UpdateNode(@Const Node node, @Cast("bool") boolean relax, @Cast("bool*") boolean... refined);
+
+// Returns the InferenceContext for 'node', if present.
+public native InferenceContext GetContext(@Const Node node);
+
+// Returns the ExtendedInferenceContext for 'node', if present.
+
+// Getters and setters for graph_def_version_.
+public native int graph_def_version();
+public native void set_graph_def_version(int version);
+
+public native void set_require_shape_inference_fns(@Cast("bool") boolean require_shape_inference_fns);
+public native void set_disable_constant_propagation(@Cast("bool") boolean disable);
+
+// Set function library to enable function shape inference.
+// Without function library, function inference always yields unknown shapes.
+// With this enabled, shape inference can take more time since it descends
+// into all function calls. It doesn't do inference once for each function
+// definition, but once for each function call.
+// The function library must outlive the shape refiner.
+public native void set_function_library_for_shape_inference(
+@Const FunctionLibraryDefinition lib);
+
+public native @Cast("bool") boolean function_shape_inference_supported();
+
+// Call this to keep nested shapes information for user-defined functions:
+// nested inferences will be available on the ExtendedInferenceContext for
+// each function node, forming a tree of shape inferences corresponding to the
+// tree of nested function calls. By default this setting is disabled, and
+// only the shapes for the top-level function node will be reported on the
+// InferenceContext for each function node, to reduce memory usage.
+//
+// This flag has no effect when the function inference is not enabled via
+// set_function_library_for_shape_inference.
+public native void set_keep_nested_shape_inferences();
+}
+
+// namespace tensorflow
+
+// #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
+
+
 // Parsed from tensorflow/core/framework/node_def_builder.h

 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
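With graph_runner.h now parsed, the generated GraphRunner class can be driven from Java using helper types the presets already define (StringTensorPairVector, StringVector, TensorVector), following the Session::Run()-style feed/fetch semantics described in the header comments above. A rough, untested sketch: the runMul helper and the node names "a", "b", and "mul" are invented for illustration, and Env.Default() plus the error-handling idiom follow the patterns used in the presets' existing samples.

import org.bytedeco.javacpp.*;
import static org.bytedeco.javacpp.tensorflow.*;

public class GraphRunnerSketch {
    // Hypothetical helper: partially evaluates 'graph' by feeding tensors into
    // the nodes named "a" and "b" and fetching the output of the "mul" node.
    static Tensor runMul(Graph graph, Tensor a, Tensor b) {
        GraphRunner runner = new GraphRunner(Env.Default());   // executes on CPU by default
        StringTensorPairVector feeds =
                new StringTensorPairVector(new String[] {"a", "b"}, new Tensor[] {a, b});
        StringVector fetches = new StringVector("mul");        // Session::Run()-style fetch names
        TensorVector outputs = new TensorVector();
        // function_library may be null, as noted in the header comment above.
        Status s = runner.Run(graph, null, feeds, fetches, outputs);
        if (!s.ok()) {
            throw new RuntimeException(s.error_message().getString());
        }
        // Note: the fetched tensors share their lifetime with 'runner'.
        return outputs.get(0);
    }
}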

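Likewise, the new ShapeRefiner binding makes it possible to run C++ shape inference from Java while a Graph is being constructed. Another hedged, untested sketch: the newRefiner and addAndInspect helpers are invented, and OpRegistry.Global() plus the graphDefVersion argument are assumptions about the calling code, not part of this commit.

import org.bytedeco.javacpp.*;
import static org.bytedeco.javacpp.tensorflow.*;

public class ShapeRefinerSketch {
    // Hypothetical: create a refiner for graphs at the given GraphDef version,
    // using the global op registry to look up shape functions.
    static ShapeRefiner newRefiner(int graphDefVersion) {
        return new ShapeRefiner(graphDefVersion, OpRegistry.Global());
    }

    // Hypothetical: runs the shape function of 'node' and reports the inferred
    // rank of its first output. All inputs of 'node' must already have been
    // passed to AddNode(), i.e. nodes are added in topological order.
    static void addAndInspect(ShapeRefiner refiner, Node node) {
        Status s = refiner.AddNode(node);
        if (!s.ok()) {
            throw new RuntimeException(s.error_message().getString());
        }
        InferenceContext c = refiner.GetContext(node); // null if 'node' was never added
        ShapeHandle out0 = c.output(0);
        if (c.RankKnown(out0)) {
            System.out.println("output 0 rank: " + c.Rank(out0));
        } else {
            System.out.println("output 0 rank: unknown");
        }
    }
}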