Merge branch 'master' of https://github.com/explosion/thinc
honnibal committed Jan 21, 2021
2 parents 0528e13 + 1233d0b commit fa80c06
Showing 7 changed files with 107 additions and 9 deletions.
1 change: 1 addition & 0 deletions thinc/api.py
@@ -22,6 +22,7 @@
 from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
 from .layers import Maxout, Mish, MultiSoftmax, Relu, softmax_activation, Softmax, LSTM
 from .layers import CauchySimilarity, ParametricAttention, Logistic
+from .layers import sigmoid_activation, Sigmoid
 from .layers import SparseLinear
 from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
 from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
13 changes: 7 additions & 6 deletions thinc/backends/ops.py
@@ -12,6 +12,7 @@
 
 ArrayT = TypeVar("ArrayT", bound=ArrayXd)
 FloatsT = TypeVar("FloatsT", bound=_Floats)
+FloatsType = TypeVar("FloatsType", bound=FloatsXd)
 
 
 class Ops:
@@ -558,16 +559,16 @@ def as_contig(self, data: ArrayT, dtype: Optional[DTypes] = None) -> ArrayT:
         kwargs = {"dtype": dtype} if dtype is not None else {}
         return self.xp.ascontiguousarray(data, **kwargs)
 
-    def sigmoid(self, X: FloatsT, *, inplace: bool = False) -> FloatsT:
+    def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:
         if inplace:
             self.xp.exp(-X, out=X)
-            X += 1.0
-            X **= -1.0
-            return X
+            X += 1.0  # type: ignore
+            X **= -1.0  # type: ignore
+            return cast(FloatsType, X)
         else:
-            return 1.0 / (1.0 + self.xp.exp(-X))
+            return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))
 
-    def dsigmoid(self, Y: FloatsT, *, inplace: bool = False) -> FloatsT:
+    def dsigmoid(self, Y: FloatsType, *, inplace: bool = False) -> FloatsType:
         if inplace:
             Y *= 1 - Y
             return Y
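A minimal usage sketch of the retyped sigmoid/dsigmoid methods on a concrete backend; it is not part of this diff, and the NumpyOps backend and float32 input here are illustrative assumptions.

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([[0.0, 2.0, -2.0]], dtype="f")
Y = ops.sigmoid(X, inplace=False)    # elementwise 1 / (1 + exp(-X))
dY = ops.dsigmoid(Y, inplace=False)  # elementwise Y * (1 - Y)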
10 changes: 7 additions & 3 deletions thinc/layers/__init__.py
@@ -6,17 +6,19 @@
 from .hashembed import HashEmbed
 from .layernorm import LayerNorm
 from .linear import Linear
-from .lstm import LSTM, PyTorchLSTM
 from .logistic import Logistic
 from .maxout import Maxout
 from .mish import Mish
 from .multisoftmax import MultiSoftmax
 from .parametricattention import ParametricAttention
 from .pytorchwrapper import PyTorchWrapper, PyTorchRNNWrapper
 from .relu import Relu
+from .sigmoid_activation import sigmoid_activation
+from .sigmoid import Sigmoid
 from .softmax_activation import softmax_activation
 from .softmax import Softmax
 from .sparselinear import SparseLinear
+from .lstm import LSTM, PyTorchLSTM
 from .tensorflowwrapper import TensorFlowWrapper, keras_subclass
 from .mxnetwrapper import MXNetWrapper
@@ -69,18 +71,20 @@
     "expand_window",
     "HashEmbed",
     "LayerNorm",
+    "LSTM",
     "Maxout",
     "Mish",
     "MultiSoftmax",
     "ParametricAttention",
+    "PyTorchLSTM",
     "PyTorchWrapper",
     "PyTorchRNNWrapper",
     "Relu",
+    "sigmoid_activation",
+    "Sigmoid",
     "softmax_activation",
     "Softmax",
     "SparseLinear",
-    "LSTM",
-    "PyTorchLSTM",
     "TensorFlowWrapper",
     "add",
     "bidirectional",
3 changes: 3 additions & 0 deletions thinc/layers/logistic.py
@@ -11,6 +11,9 @@
 
 @registry.layers("Logistic.v1")
 def Logistic() -> Model[InT, OutT]:
+    """Deprecated in favor of the `sigmoid_activation` layer, for more
+    consistent naming.
+    """
     return Model("logistic", forward)
 
 
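A brief migration sketch (assumed usage, not in the diff): both layers apply the same elementwise logistic function, so the rename is a drop-in swap.

from thinc.api import Logistic, sigmoid_activation

old_layer = Logistic()            # deprecated spelling
new_layer = sigmoid_activation()  # preferred, consistently named replacement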
63 changes: 63 additions & 0 deletions thinc/layers/sigmoid.py
@@ -0,0 +1,63 @@
+from typing import Tuple, Callable, Optional, cast
+
+from ..model import Model
+from ..config import registry
+from ..types import Floats2d, Floats1d
+from ..initializers import zero_init
+from ..util import get_width, partial
+
+
+InT = Floats2d
+OutT = Floats2d
+
+
+@registry.layers("Sigmoid.v1")
+def Sigmoid(
+    nO: Optional[int] = None,
+    nI: Optional[int] = None,
+    *,
+    init_W: Callable = zero_init,
+    init_b: Callable = zero_init
+) -> Model[InT, OutT]:
+    """A dense layer, followed by a sigmoid (logistic) activation function. This
+    is usually used instead of the Softmax layer as an output for multi-label
+    classification.
+    """
+    return Model(
+        "sigmoid",
+        forward,
+        init=partial(init, init_W, init_b),
+        dims={"nO": nO, "nI": nI},
+        params={"W": None, "b": None},
+    )
+
+
+def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
+    W = cast(Floats2d, model.get_param("W"))
+    b = cast(Floats1d, model.get_param("b"))
+    Y = model.ops.affine(X, W, b)
+    Y = model.ops.sigmoid(Y)
+
+    def backprop(dY: InT) -> OutT:
+        dY = dY * model.ops.dsigmoid(Y, inplace=False)
+        model.inc_grad("b", dY.sum(axis=0))
+        model.inc_grad("W", model.ops.gemm(dY, X, trans1=True))
+        return model.ops.gemm(dY, W)
+
+    return Y, backprop
+
+
+def init(
+    init_W: Callable,
+    init_b: Callable,
+    model: Model[InT, OutT],
+    X: Optional[InT] = None,
+    Y: Optional[OutT] = None,
+) -> Model[InT, OutT]:
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+    if Y is not None and model.has_dim("nO") is None:
+        model.set_dim("nO", get_width(Y))
+    model.set_param("W", init_W(model.ops, (model.get_dim("nO"), model.get_dim("nI"))))
+    model.set_param("b", init_b(model.ops, (model.get_dim("nO"),)))
+    return model
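A hedged usage sketch for the new Sigmoid layer; the shapes and data below are illustrative assumptions, not part of the commit.

import numpy
from thinc.api import Sigmoid

model = Sigmoid(nO=3, nI=4)         # 4 input features, 3 independent output labels
X = numpy.zeros((8, 4), dtype="f")
model.initialize(X=X)               # allocates W and b (zero_init by default)
Y = model.predict(X)                # shape (8, 3), each value in (0, 1)

Unlike Softmax, each output unit is squashed independently, which is why the docstring suggests this layer for multi-label classification.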
22 changes: 22 additions & 0 deletions thinc/layers/sigmoid_activation.py
@@ -0,0 +1,22 @@
+from typing import TypeVar, Tuple, Callable, cast
+
+from ..model import Model
+from ..config import registry
+from ..types import FloatsXd
+
+
+InT = TypeVar("InT", bound=FloatsXd)
+
+
+@registry.layers("sigmoid_activation.v1")
+def sigmoid_activation() -> Model[InT, InT]:
+    return Model("sigmoid_activation", forward)
+
+
+def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
+    Y = model.ops.sigmoid(X, inplace=False)
+
+    def backprop(dY: InT) -> InT:
+        return dY * model.ops.dsigmoid(Y, inplace=False)  # type: ignore
+
+    return Y, backprop
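A short sketch (assumed usage, not in the diff) of the forward pass and the backprop callback defined above; the input array is illustrative.

import numpy
from thinc.api import sigmoid_activation

model = sigmoid_activation()
X = numpy.asarray([[0.5, -0.5], [2.0, -2.0]], dtype="f")
Y, backprop = model.begin_update(X)   # Y = sigmoid(X)
dX = backprop(numpy.ones_like(Y))     # dX = dY * Y * (1 - Y)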
4 changes: 4 additions & 0 deletions thinc/tests/layers/test_layers_api.py
@@ -69,6 +69,10 @@ def assert_data_match(Y, out_data):
         ("Mish.v1", {"normalize": True, "dropout": 0.2}, array2d, array2d),
         ("Relu.v1", {}, array2d, array2d),
         ("Relu.v1", {"normalize": True, "dropout": 0.2}, array2d, array2d),
+        ("Sigmoid.v1", {}, array2d, array2d),
+        ("Sigmoid.v1", {"nO": 4, "nI": 4}, array2d, array2d),
+        ("sigmoid_activation.v1", {}, array2d, array2d),
+        ("softmax_activation.v1", {}, array2d, array2d),
         ("Softmax.v1", {}, array2d, array2d),
         ("Softmax.v1", {"nO": 4, "nI": 4}, array2d, array2d),
         # fmt: off
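For reference, a hedged sketch of how these registered names resolve to layers outside the test harness, assuming the catalogue-backed registry exposed as thinc.api.registry:

from thinc.api import registry

# Look up the factory registered as "Sigmoid.v1" and build a layer from it.
Sigmoid_v1 = registry.layers.get("Sigmoid.v1")
model = Sigmoid_v1(nO=4, nI=4)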
