15 changes: 8 additions & 7 deletions .circleci/config.yml
@@ -1,7 +1,7 @@
 version: 2.1

 orbs:
-  python: circleci/python@0.2.1
+  python: circleci/python@1.5.0

 jobs:
   unit-tests:
@@ -11,12 +11,12 @@ jobs:
       - run:
           name: setup
           command: |
-            virtualenv -p python3.7 .venv
+            virtualenv -p python3.8 .venv
             source .venv/bin/activate
             # pip install -q torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
             # temporary workaround for https://github.com/pytorch/pytorch/issues/49560
-            wget https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
-            pip install -q torch-1.6.0+cpu-cp37-cp37m-linux_x86_64.whl
+            wget https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp38-cp38-linux_x86_64.whl
+            pip install -q torch-1.6.0+cpu-cp38-cp38-linux_x86_64.whl
             pip install -q .
             pip install -q -r requirements-test.txt
       - run:
@@ -33,13 +33,14 @@ jobs:
       - run:
           name: setup
           command: |
-            virtualenv -p python3.7 .venv
+            virtualenv -p python3.8 .venv
             source .venv/bin/activate
             # pip install -q torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
             # temporary workaround for https://github.com/pytorch/pytorch/issues/49560
-            wget https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
-            pip install -q torch-1.6.0+cpu-cp37-cp37m-linux_x86_64.whl
+            wget https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp38-cp38-linux_x86_64.whl
+            pip install -q torch-1.6.0+cpu-cp38-cp38-linux_x86_64.whl
             pip install -q .
+            sudo apt-get update -y
             sudo apt-get -y install cmake
       - run:
           name: Test
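Note: a pinned wheel URL like the one above must match the CPython tag of the interpreter inside the virtualenv (cp38 for Python 3.8). A minimal sketch, not part of this PR, of how one could confirm the expected tag locally:

    # Not part of this PR: prints the CPython tag (e.g. "cp38") that the pinned
    # torch wheel filename above has to match for pip to accept it.
    import sys

    print(f"cp{sys.version_info.major}{sys.version_info.minor}")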
73 changes: 50 additions & 23 deletions sru/modules.py
@@ -1,3 +1,11 @@
"""
This module implements core classes SRU and SRUCell.

Implementation note 1: We have postponed the import of sru.ops to the first SRUCell
instantiation in order to ensure CUDA init takes place in the process that will be
running the model. Please see the class method init_elementwise_recurrence_funcs.
"""

import copy
import warnings
import math
@@ -8,10 +16,6 @@
 from torch import Tensor
 from torch.nn.utils.rnn import PackedSequence

-from sru.ops import (elementwise_recurrence_inference,
-                     elementwise_recurrence_gpu,
-                     elementwise_recurrence_naive)
-

 class SRUCell(nn.Module):
     """
@@ -27,6 +31,11 @@ class SRUCell(nn.Module):
     scale_x: Tensor
     weight_proj: Optional[Tensor]

+    initialized = False
+    elementwise_recurrence_inference = None
[Review comment] Please add a note about this function initialization to the SRUCell docstring or the docstring for sru/modules.py (it lacks a module docstring now; it probably should have one)
[Author @dkasapp, Dec 27, 2021] Added
+    elementwise_recurrence_gpu = None
+    elementwise_recurrence_naive = None
+
     def __init__(self,
                  input_size: int,
                  hidden_size: int,
@@ -160,6 +169,24 @@ def __init__(self,
             self.layer_norm = nn.LayerNorm(self.input_size)

         self.reset_parameters()
+        SRUCell.init_elementwise_recurrence_funcs()
+
+    @classmethod
+    def init_elementwise_recurrence_funcs(cls):
[Review comment] add a docstring to this method please
[Author @dkasapp] Added
"""
Initializes the elementwise recurrence functions. This is postponed to the creation
of the first SRUCell instance because we want to avoid eager CUDA initialization and
ensure it takes place in the process running the model.
"""
if cls.initialized:
return
from sru.ops import (elementwise_recurrence_inference,
elementwise_recurrence_gpu,
elementwise_recurrence_naive)
cls.elementwise_recurrence_inference = elementwise_recurrence_inference
cls.elementwise_recurrence_gpu = elementwise_recurrence_gpu
cls.elementwise_recurrence_naive = elementwise_recurrence_naive
cls.initialized = True

def reset_parameters(self):
"""Properly initialize the weights of SRU, following the same
@@ -295,27 +322,27 @@ def apply_recurrence(self,
"""
if not torch.jit.is_scripting():
if self.bias.is_cuda:
return elementwise_recurrence_gpu(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad,
self.amp_recurrence_fp16)
return SRUCell.elementwise_recurrence_gpu(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad,
self.amp_recurrence_fp16)
else:
return elementwise_recurrence_naive(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)
return SRUCell.elementwise_recurrence_naive(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)
else:
return elementwise_recurrence_inference(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)
return SRUCell.elementwise_recurrence_inference(U, residual, V, self.bias, c0,
self.activation_type,
self.hidden_size,
self.bidirectional,
self.has_skip_term,
scale_val, mask_c, mask_pad)


def compute_UV(self,
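Note: the change above is an instance of the deferred-import (lazy initialization) pattern. A minimal standalone sketch of the same idea, using hypothetical names rather than the real sru API, showing how class-level lazy init keeps a heavy import off the module import path:

    # Minimal sketch of the deferred-import pattern; LazyCell and init_funcs
    # are hypothetical stand-ins, not the actual sru code.
    class LazyCell:
        initialized = False
        recurrence_fn = None

        def __init__(self):
            # The heavy import runs on first instantiation, not at import
            # time, so it happens in the process that builds the model.
            LazyCell.init_funcs()

        @classmethod
        def init_funcs(cls):
            if cls.initialized:
                return
            import math  # stand-in for the CUDA-touching sru.ops import
            cls.recurrence_fn = math.tanh
            cls.initialized = True

    cell = LazyCell()                   # triggers the import exactly once
    print(LazyCell.recurrence_fn(0.0))  # 0.0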
8 changes: 8 additions & 0 deletions test/test_1.py
@@ -0,0 +1,8 @@
+import torch
+
+
+# Run at the beginning of the test suite to ensure no previous use of SRUCell
+def test_no_eager_cuda_init():
+    # Note: the test is expected to pass both with and without a GPU available
+    import sru
+    assert not torch.cuda.is_initialized()
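For context, a hedged sketch (not code from this repository) of the kind of import-time side effect the test above guards against; if sru executed something like this at module level when imported, the assertion would fail:

    # Hypothetical counter-example: module-level GPU work like this would
    # initialize CUDA as a side effect of a bare `import`.
    import torch

    if torch.cuda.is_available():      # checking availability does not init CUDA
        torch.zeros(1, device="cuda")  # allocating a GPU tensor does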