sru/modules.py: 37 additions, 23 deletions
@@ -8,10 +8,6 @@
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence

-from sru.ops import (elementwise_recurrence_inference,
-                     elementwise_recurrence_gpu,
-                     elementwise_recurrence_naive)


class SRUCell(nn.Module):
"""
Expand All @@ -27,6 +23,11 @@ class SRUCell(nn.Module):
    scale_x: Tensor
    weight_proj: Optional[Tensor]

+    initialized = False
+    elementwise_recurrence_inference = None
+    elementwise_recurrence_gpu = None
+    elementwise_recurrence_naive = None

Review comment: Please add a note about this function initialization to the SRUCell docstring or the docstring for sru/modules.py (it lacks a module docstring now; it probably should have one).

@dkasapp (author), Dec 27, 2021: Added

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
@@ -160,6 +161,19 @@ def __init__(self,
        self.layer_norm = nn.LayerNorm(self.input_size)

        self.reset_parameters()
+        SRUCell.init_elementwise_recurrence_funcs()

+    @classmethod
+    def init_elementwise_recurrence_funcs(cls):
+        if cls.initialized:
+            return
+        from sru.ops import (elementwise_recurrence_inference,
+                             elementwise_recurrence_gpu,
+                             elementwise_recurrence_naive)
+        cls.elementwise_recurrence_inference = elementwise_recurrence_inference
+        cls.elementwise_recurrence_gpu = elementwise_recurrence_gpu
+        cls.elementwise_recurrence_naive = elementwise_recurrence_naive
+        cls.initialized = True

Review comment: Add a docstring to this method, please.

@dkasapp (author): Added

    def reset_parameters(self):
        """Properly initialize the weights of SRU, following the same
@@ -295,27 +309,27 @@ def apply_recurrence(self,
"""
if not torch.jit.is_scripting():
if self.bias.is_cuda:
return elementwise_recurrence_gpu(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad,
self.amp_recurrence_fp16)
return SRUCell.elementwise_recurrence_gpu(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad,
self.amp_recurrence_fp16)
else:
return elementwise_recurrence_naive(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)
return SRUCell.elementwise_recurrence_naive(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)
else:
return elementwise_recurrence_inference(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)
return SRUCell.elementwise_recurrence_inference(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)


    def compute_UV(self,
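For orientation, the branch structure in `apply_recurrence` selects one of three recurrence implementations. A condensed sketch of the same conditions, with a hypothetical helper name (`select_recurrence_impl`):

    import torch


    def select_recurrence_impl(bias: torch.Tensor) -> str:
        """Mirrors the dispatch above: scripting mode and device placement decide."""
        if not torch.jit.is_scripting():
            if bias.is_cuda:
                return "gpu"        # fused CUDA kernel
            return "naive"          # pure-PyTorch fallback for CPU tensors
        return "inference"          # path usable under TorchScript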
test/test_1.py: 8 additions, 0 deletions
@@ -0,0 +1,8 @@
+import torch
+
+
+# Run at the beginning of the test suite to ensure no previous use of SRUCell
+def test_no_eager_cuda_init():
+    # Note: this test is expected to pass both with and without a GPU available
+    import sru
+    assert not torch.cuda.is_initialized()
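This test only proves anything when it runs before any other test constructs an `SRUCell`, which the author gets from pytest's default sorted collection order (`test_1.py` collates first). An ordering-independent variant would perform the check in a fresh interpreter; a sketch with a hypothetical test name, assuming `sru` is importable in the current environment:

    import subprocess
    import sys


    def test_no_eager_cuda_init_subprocess():
        # A fresh interpreter guarantees no prior SRUCell use, regardless of
        # where pytest schedules this test within the suite.
        snippet = "import torch; import sru; assert not torch.cuda.is_initialized()"
        subprocess.run([sys.executable, "-c", snippet], check=True)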