
Commit e8a8b64

[SW-51190] Enable HPU accelerator on pytorch lightning 1.4.0
Change-Id: Ia8128a48ae0239cfc371c7cd983339219d29b0e6
1 parent c7f8c8c commit e8a8b64

18 files changed: +286 -6 lines changed
+44
@@ -0,0 +1,44 @@
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import sys

from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

# Initialize a trainer
trainer = pl.Trainer(hpus=1, max_epochs=3, progress_bar_refresh_rate=20)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
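
A quick way to sanity-check the device before launching the example above (a sketch, not part of this commit): it assumes the Habana PyTorch bridge is installed and that load_habana_module() registers the "hpu" device type with torch, which is what the script relies on.

# Sketch: verify the Gaudi device is reachable before training.
# Assumes habana_frameworks is installed and registers the "hpu" device type.
import torch
from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()

x = torch.ones(2, 3).to(torch.device("hpu"))  # move a small tensor to the HPU
print(x.device)                               # expected to report an hpu device if the module loaded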

pytorch_lightning/accelerators/__init__.py

+1
@@ -15,3 +15,4 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401
+from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa F401

pytorch_lightning/accelerators/accelerator.py

+1
@@ -44,6 +44,7 @@ class Accelerator:
     - CPU
     - GPU
     - TPU
+    - HPU

     Each Accelerator gets two plugins upon initialization:
     One to handle differences from the training routine and one to handle different precisions.

pytorch_lightning/accelerators/hpu.py

+50
@@ -0,0 +1,50 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.plugins.training_type.single_hpu import SingleHPUPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)


class HPUAccelerator(Accelerator):
    """Accelerator for HPU devices."""

    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
        """
        Raises:
            MisconfigurationException:
                If the selected device is not HPU.
        """
        if "hpu" not in str(self.root_device):
            raise MisconfigurationException(f"Device should be HPU, got {self.root_device} instead")
        return super().setup(trainer, model)

    def to_device(self, batch: Any) -> Any:
        # no need to transfer batch to device in DP mode
        # TODO: Add support to allow batch transfer to device in Lightning for DP mode.
        # if isinstance(self.training_type_plugin, SingleHPUPlugin):
        if not isinstance(self.training_type_plugin, DataParallelPlugin):
            batch = super().to_device(batch)

        return batch
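
For reference, a minimal sketch of wiring this accelerator up by hand instead of through the hpus Trainer flag. It assumes the PyTorch Lightning 1.4 Accelerator(precision_plugin=..., training_type_plugin=...) constructor and that the Trainer accepts a pre-built accelerator instance via its accelerator argument; neither detail is shown in this diff.

# Sketch (assumptions noted above): build the HPU accelerator explicitly.
import pytorch_lightning as pl
from pytorch_lightning.accelerators.hpu import HPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.plugins.training_type.single_hpu import SingleHPUPlugin

accelerator = HPUAccelerator(
    precision_plugin=PrecisionPlugin(),              # default full-precision plugin
    training_type_plugin=SingleHPUPlugin(device=0),  # single-device HPU plugin added later in this commit
)
trainer = pl.Trainer(accelerator=accelerator, max_epochs=1)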

pytorch_lightning/core/lightning.py

+9
@@ -242,6 +242,14 @@ def on_gpu(self):
         """
         return self.device.type == "cuda"

+    @property
+    def on_hpu(self):
+        """
+        True if your model is currently running on HPUs.
+        Useful to set flags around the LightningModule for different CPU vs GPU vs HPU behavior.
+        """
+        return self.device.type == "hpu"
+
     @property
     def automatic_optimization(self) -> bool:
         """

@@ -1525,6 +1533,7 @@ def optimizer_step(
         optimizer_idx: int = None,
         optimizer_closure: Optional[Callable] = None,
         on_tpu: bool = None,
+        on_hpu: bool = None,
         using_native_amp: bool = None,
         using_lbfgs: bool = None,
     ) -> None:
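
Together, the new on_hpu property and the extra optimizer_step argument let a LightningModule branch on the device type. A minimal sketch, assuming the 1.4 optimizer_step signature shown in the hunk above:

# Sketch: using the new hooks in a user LightningModule.
import torch
import pytorch_lightning as pl
from torch.nn import functional as F


class MyHPUAwareModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        if self.on_hpu:   # new property added in this commit
            pass          # HPU-specific behaviour could go here
        return F.cross_entropy(self.l1(x.view(x.size(0), -1)), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure=None, on_tpu=None, on_hpu=None,
                       using_native_amp=None, using_lbfgs=None):
        # `on_hpu` mirrors the existing `on_tpu` flag from the hunk above
        optimizer.step(closure=optimizer_closure)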

pytorch_lightning/plugins/__init__.py

+2
@@ -27,6 +27,7 @@
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
+from pytorch_lightning.plugins.training_type.single_hpu import SingleHPUPlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401

@@ -49,6 +50,7 @@
     "FullyShardedNativeMixedPrecisionPlugin",
     "SingleDevicePlugin",
     "SingleTPUPlugin",
+    "SingleHPUPlugin",
     "TPUHalfPrecisionPlugin",
     "TPUSpawnPlugin",
     "TrainingTypePlugin",

pytorch_lightning/plugins/training_type/__init__.py

+1
@@ -10,5 +10,6 @@
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
+from pytorch_lightning.plugins.training_type.single_hpu import SingleHPUPlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401

pytorch_lightning/plugins/training_type/ddp.py

+6-1
@@ -308,7 +308,7 @@ def configure_ddp(self):
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
-        if self.root_device.type == "cpu":
+        if self.root_device.type == "cpu" or self.root_device.type == "hpu":
             return None
         return [self.root_device.index]

@@ -317,6 +317,11 @@ def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
         world_size = world_size if world_size is not None else self.cluster_environment.world_size()
         os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+
+        if self.torch_distributed_backend == 'hcl':
+            os.environ["ID"] = str(self.local_rank)
+            import habana_torch_hcl
+
         if torch.distributed.is_available() and not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
             torch.distributed.init_process_group(
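
How the 'hcl' branch gets reached is not shown in this diff; below is a sketch under the assumption that Lightning 1.4 resolves torch_distributed_backend from the PL_TORCH_DISTRIBUTED_BACKEND environment variable (that variable name is an assumption, not part of this commit).

# Sketch: select the Habana collective backend before building the Trainer.
# Assumption: PL_TORCH_DISTRIBUTED_BACKEND is how torch_distributed_backend is
# overridden in Lightning 1.4; "hcl" would then route init_ddp_connection() into
# the branch added above, which exports ID and imports habana_torch_hcl per rank.
import os

os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "hcl"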

pytorch_lightning/plugins/training_type/single_device.py

+4
@@ -33,6 +33,10 @@ def __init__(self, device: torch.device):
     def on_tpu(self) -> bool:
         return self.root_device.type == "xla" and _XLA_AVAILABLE

+    @property
+    def on_hpu(self) -> bool:
+        return self.device.type == "hpu"
+
     @property
     def on_gpu(self) -> bool:
         return self.root_device.type == "cuda" and torch.cuda.is_available()
pytorch_lightning/plugins/training_type/single_hpu.py

+61
@@ -0,0 +1,61 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _HPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device


class SingleHPUPlugin(SingleDevicePlugin):

    def __init__(self, device: int):

        device = torch.device("hpu")
        super().__init__(device)

        self.hpu_local_core_rank = 0
        self.hpu_global_core_rank = 0

    @property
    def on_hpu(self) -> bool:
        return True

    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
        self._model = model
        self.model_to_device()
        return self._model

    @property
    def is_distributed(self) -> bool:
        return False

    def model_to_device(self) -> None:
        self._model.to(self.root_device)

    def pre_dispatch(self) -> None:
        if isinstance(self.device, int):
            self.device = torch.device(self.device)

        self.hpu_local_core_rank = 0
        self.hpu_global_core_rank = 0

    def on_save(self, checkpoint: dict) -> dict:
        """
        Move checkpoint tensors to CPU before saving (behaviour on HPU still to be verified).
        Follows the approach recommended in the XLA guide for saving tensors:
        https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
        """
        return move_data_to_device(checkpoint, torch.device("cpu"))
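
As an alternative to the hpus=1 shortcut used in the example at the top, the plugin can also be selected explicitly. A sketch, assuming the 1.4 Trainer accepts training type plugins via its plugins argument and that the accelerator connector (modified elsewhere among the 18 files in this commit) picks the HPU accelerator accordingly:

# Sketch (assumptions noted above): select the single-HPU plugin explicitly.
import pytorch_lightning as pl
from pytorch_lightning.plugins import SingleHPUPlugin

trainer = pl.Trainer(plugins=[SingleHPUPlugin(device=0)], max_epochs=1)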
