@@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
66from contextlib import contextmanager
77from enum import Enum
88from types import TracebackType
9- from typing import Any , Generic , NoReturn , TypeVar , overload
9+ from typing import Any , Generic , Literal , NoReturn , TypeVar , overload
1010from typing_extensions import ParamSpec , Self
1111
1212from google .protobuf .message import Message
@@ -20,7 +20,17 @@ from tensorflow import (
2020 math as math ,
2121 types as types ,
2222)
23- from tensorflow ._aliases import AnyArray , DTypeLike , ShapeLike , Slice , TensorCompatible
23+ from tensorflow ._aliases import (
24+ AnyArray ,
25+ DTypeLike ,
26+ IntArray ,
27+ ScalarTensorCompatible ,
28+ ShapeLike ,
29+ Slice ,
30+ SparseTensorCompatible ,
31+ TensorCompatible ,
32+ UIntTensorCompatible ,
33+ )
2434from tensorflow .autodiff import GradientTape as GradientTape
2535from tensorflow .core .protobuf import struct_pb2
2636from tensorflow .dtypes import *
@@ -56,6 +66,7 @@ from tensorflow.math import (
5666 reduce_min as reduce_min ,
5767 reduce_prod as reduce_prod ,
5868 reduce_sum as reduce_sum ,
69+ round as round ,
5970 sigmoid as sigmoid ,
6071 sign as sign ,
6172 sin as sin ,
@@ -403,4 +414,22 @@ def ones_like(
403414 input : RaggedTensor , dtype : DTypeLike | None = None , name : str | None = None , layout : Layout | None = None
404415) -> RaggedTensor : ...
405416def reshape (tensor : TensorCompatible , shape : ShapeLike | Tensor , name : str | None = None ) -> Tensor : ...
417+ def pad (
418+ tensor : TensorCompatible ,
419+ paddings : Tensor | IntArray | Iterable [Iterable [int ]],
420+ mode : Literal ["CONSTANT" , "constant" , "REFLECT" , "reflect" , "SYMMETRIC" , "symmectric" ] = "CONSTANT" ,
421+ constant_values : ScalarTensorCompatible = 0 ,
422+ name : str | None = None ,
423+ ) -> Tensor : ...
424+ def shape (input : SparseTensorCompatible , out_type : DTypeLike | None = None , name : str | None = None ) -> Tensor : ...
425+ def where (
426+ condition : TensorCompatible , x : TensorCompatible | None = None , y : TensorCompatible | None = None , name : str | None = None
427+ ) -> Tensor : ...
428+ def gather_nd (
429+ params : TensorCompatible ,
430+ indices : UIntTensorCompatible ,
431+ batch_dims : UIntTensorCompatible = 0 ,
432+ name : str | None = None ,
433+ bad_indices_policy : Literal ["" , "DEFAULT" , "ERROR" , "IGNORE" ] = "" ,
434+ ) -> Tensor : ...
406435def __getattr__ (name : str ) -> Incomplete : ...
0 commit comments