Skip to content

Commit

Permalink
{set,use}_ops: switch PyTorch Tensor type (#553)
Browse files Browse the repository at this point in the history
{set,use}_ops did not change the default PyTorch Tensor type, whereas
require_{cpu, gpu} do. As a result, these set/use functions require more
work to use correctly with PyTorch

This change rectifies this by also changing the default PyTorch Tensor
type when {set,use}_ops are used.
  • Loading branch information
danieldk authored Oct 28, 2021
1 parent 5f7fb42 commit 6115392
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
3 changes: 2 additions & 1 deletion thinc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array
from ..util import is_cupy_array, set_torch_tensor_type_for_ops
from .. import registry


Expand Down Expand Up @@ -129,6 +129,7 @@ def set_current_ops(ops: Ops) -> None:
"""Change the current backend object."""
context_ops.set(ops)
_get_thread_state().ops = ops
set_torch_tensor_type_for_ops(ops)


def contextvars_eq_thread_ops() -> bool:
Expand Down
28 changes: 20 additions & 8 deletions thinc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,9 @@ def require_cpu() -> bool: # pragma: no cover
"""Use CPU through best available backend."""
from .backends import set_current_ops, get_ops

try:
import torch

torch.set_default_tensor_type("torch.FloatTensor")
except ImportError:
pass

set_current_ops(get_ops("cpu"))
ops = get_ops("cpu")
set_current_ops(ops)
set_torch_tensor_type_for_ops(ops)

return True

Expand Down Expand Up @@ -480,6 +475,22 @@ def use_nvtx_range(message: int, id_color: int = -1):
yield


def set_torch_tensor_type_for_ops(ops):
"""Set the PyTorch default tensor type for the given ops. This is a
no-op if PyTorch is not available."""
from .backends.cupy_ops import CupyOps

try:
import torch

if CupyOps.xp is not None and isinstance(ops, CupyOps):
torch.set_default_tensor_type("torch.cuda.FloatTensor")
else:
torch.set_default_tensor_type("torch.FloatTensor")
except ImportError:
pass


__all__ = [
"get_array_module",
"fix_random_seed",
Expand All @@ -499,4 +510,5 @@ def use_nvtx_range(message: int, id_color: int = -1):
"DataValidationError",
"make_tempfile",
"use_nvtx_range",
"set_torch_tensor_type_for_ops",
]

0 comments on commit 6115392

Please sign in to comment.