[WIP] add mypy & isort #45

Open · wants to merge 3 commits into base: main
39 changes: 27 additions & 12 deletions .pre-commit-config.yaml
@@ -1,16 +1,31 @@
 fail_fast: true
 default_stages: [commit]
 exclude: ".git"

 repos:
-- repo: https://github.com/pausan/cblack
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.4.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-yaml
+  - id: check-added-large-files
+  - id: check-merge-conflict
+
+- repo: https://github.com/pausan/cblack
   rev: release-22.3.0
   hooks:
-  - id: cblack
-    name: cblack
-    description: "Black: The uncompromising Python code formatter - 2 space indent fork"
-    entry: cblack . -l 100
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.3.0
+  - id: cblack
+    name: cblack
+    description: "Black: The uncompromising Python code formatter - 2 space indent fork"
+    entry: cblack . -l 100
+
+- repo: https://github.com/psf/black
+  rev: 23.1.0
+  hooks:
+  - id: black
+
-  - id: trailing-whitespace
-  - id: end-of-file-fixer
-  - id: check-yaml
-  - id: check-added-large-files
-  - id: check-merge-conflict
+- repo: https://github.com/PyCQA/isort
+  rev: 5.12.0
+  hooks:
+  - id: isort
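For reference, the new isort hook rewrites import blocks the same way isort's Python API does. A minimal sketch of that behaviour, using default settings (the repo's actual isort options, which would produce the parenthesized one-import-per-line style seen in the diffs below, live in a config file not shown in this hunk):

import isort

messy = "from typing import Dict\nimport abc\nfrom dataclasses import dataclass\n"

# isort.code() returns the source with its imports grouped (stdlib first,
# then third-party) and alphabetized, which is what the hook applies to
# each staged file.
print(isort.code(messy))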
46 changes: 32 additions & 14 deletions common/batch.py
@@ -1,22 +1,40 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
from __future__ import (
annotations,
)

import abc
from dataclasses import dataclass
import dataclasses
from collections import (
UserDict,
)
from dataclasses import (
dataclass,
)
from typing import (
Any,
Dict,
List,
TypeVar,
)

import torch
from torchrec.streamable import Pipelineable
from torchrec.streamable import (
Pipelineable,
)

_KT = TypeVar("_KT") # key type
_VT = TypeVar("_VT") # value type


class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict:
def as_dict(self) -> Dict[str, Any]:
raise NotImplementedError

def to(self, device: torch.device, non_blocking: bool = False):
def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@@ -26,14 +44,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
     for feature_value in self.as_dict().values():
       feature_value.record_stream(stream)

-  def pin_memory(self):
+  def pin_memory(self) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.pin_memory()
     return self.__class__(**args)

   def __repr__(self) -> str:
-    def obj2str(v):
+    def obj2str(v: Any) -> str:
       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

     return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@@ -52,18 +70,18 @@ def batch_size(self) -> int:
 @dataclass
 class DataclassBatch(BatchBase):
   @classmethod
-  def feature_names(cls):
+  def feature_names(cls) -> List[str]:
     return list(cls.__dataclass_fields__.keys())

-  def as_dict(self):
+  def as_dict(self) -> Dict[str, Any]:
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
       if hasattr(self, feature_name)
     }

   @staticmethod
-  def from_schema(name: str, schema):
+  def from_schema(name: str, schema: Any) -> type:
     """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
     return dataclasses.make_dataclass(
       cls_name=name,
@@ -72,14 +90,14 @@ def from_schema(name: str, schema):
     )

   @staticmethod
-  def from_fields(name: str, fields: dict):
+  def from_fields(name: str, fields: Dict[str, Any]) -> type:
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
       bases=(DataclassBatch,),
     )


-class DictionaryBatch(BatchBase, dict):
-  def as_dict(self) -> Dict:
+class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
+  def as_dict(self) -> Dict[str, Any]:
     return self
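As a quick usage sketch of the typed helpers above (the module path and field names here are illustrative, not part of this PR):

import torch

from tml.common.batch import DataclassBatch

# from_fields() builds a DataclassBatch subclass; every field defaults to None.
ExampleBatch = DataclassBatch.from_fields("ExampleBatch", {"user_ids": torch.Tensor, "labels": torch.Tensor})

batch = ExampleBatch(user_ids=torch.arange(4), labels=torch.zeros(4))

# as_dict() now advertises Dict[str, Any], and to()/pin_memory() return the
# same concrete subclass, which the new `-> BatchBase` annotations describe.
print(batch.as_dict().keys())
moved = batch.to(torch.device("cpu"), non_blocking=False)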
5 changes: 4 additions & 1 deletion common/checkpointing/__init__.py
@@ -1 +1,4 @@
-from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
+from tml.common.checkpointing.snapshot import (
+  Snapshot,
+  get_checkpoint,
+)
42 changes: 28 additions & 14 deletions common/checkpointing/snapshot.py
@@ -1,12 +1,24 @@
 import os
 import time
-from typing import Any, Dict, List, Optional
-
-from tml.ml_logging.torch_logging import logging
-from tml.common.filesystem import infer_fs, is_gcs_fs
+from typing import (
+  Any,
+  Dict,
+  Generator,
+  List,
+  Optional,
+)

 import torchsnapshot

+from tml.common.filesystem import (
+  infer_fs,
+  is_gcs_fs,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import (
+  FloatTensor,
+)

 DONE_EVAL_SUBDIR = "evaled_by"
 GCS_PREFIX = "gs://"
@@ -25,22 +37,22 @@ def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

   @property
-  def step(self):
+  def step(self) -> int:
     return self.state["extra_state"]["step"]

   @step.setter
   def step(self, step: int) -> None:
     self.state["extra_state"]["step"] = step

   @property
-  def walltime(self):
+  def walltime(self) -> float:
     return self.state["extra_state"]["walltime"]

   @walltime.setter
   def walltime(self, walltime: float) -> None:
     self.state["extra_state"]["walltime"] = walltime

-  def save(self, global_step: int) -> "PendingSnapshot":
+  def save(self, global_step: int) -> "PendingSnapshot":  # type: ignore
     """Saves checkpoint with given global_step."""
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@@ -98,7 +110,7 @@ def load_snapshot_to_weight(
     cls,
     embedding_snapshot: torchsnapshot.Snapshot,
     snapshot_emb_name: str,
-    weight_tensor,
+    weight_tensor: FloatTensor,
   ) -> None:
     """Loads pretrained embedding from the snapshot to the model.
     Utilises the partial loading mechanism from torchsnapshot.
@@ -128,19 +140,21 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
   return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")


-def is_done_eval(checkpoint_path: str, eval_partition: str):
-  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
+def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
+  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))  # type: ignore[attr-defined]


-def mark_done_eval(checkpoint_path: str, eval_partition: str):
+def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
   infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))


 def step_from_checkpoint(checkpoint: str) -> int:
   return int(os.path.basename(checkpoint))


-def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
+def checkpoints_iterator(
+  save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
+) -> Generator[str, None, None]:
   """Simplified equivalent of tf.train.checkpoints_iterator.

   Args:
@@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int

"""

def _poll(last_checkpoint: Optional[str] = None):
def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
stop_time = time.time() + timeout
while True:
_checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
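A minimal consumption sketch for the now-annotated generator (the directory path is hypothetical):

from tml.common.checkpointing.snapshot import checkpoints_iterator

# Yields each new checkpoint path as it lands in save_dir, polling every
# seconds_to_sleep seconds and giving up after timeout seconds, so the
# Generator[str, None, None] annotation matches what callers iterate over.
for checkpoint_path in checkpoints_iterator("/tmp/run/checkpoints", seconds_to_sleep=5, timeout=60):
  print("new checkpoint:", checkpoint_path)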
2 changes: 1 addition & 1 deletion common/device.py
@@ -4,7 +4,7 @@
 import torch.distributed as dist


-def maybe_setup_tensorflow():
+def maybe_setup_tensorflow() -> None:
   try:
     import tensorflow as tf
   except ImportError:
6 changes: 5 additions & 1 deletion common/filesystem/__init__.py
@@ -1 +1,5 @@
-from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
+from tml.common.filesystem.util import (
+  infer_fs,
+  is_gcs_fs,
+  is_local_fs,
+)
4 changes: 3 additions & 1 deletion common/filesystem/test_infer_fs.py
@@ -2,7 +2,9 @@

 Mostly a test that it returns an object
 """
-from tml.common.filesystem import infer_fs
+from tml.common.filesystem import (
+  infer_fs,
+)


 def test_infer_fs():
15 changes: 10 additions & 5 deletions common/filesystem/util.py
@@ -1,13 +1,18 @@
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs
from typing import (
Union,
)

import gcsfs
from fsspec.implementations.local import (
LocalFileSystem,
)

GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()


def infer_fs(path: str):
def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
@@ -17,9 +22,9 @@ def infer_fs(path: str):
     return LOCAL_FS


-def is_local_fs(fs):
+def is_local_fs(fs: LocalFileSystem) -> bool:
   return fs == LOCAL_FS


-def is_gcs_fs(fs):
+def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
   return fs == GCS_FS
29 changes: 20 additions & 9 deletions common/log_weights.py
@@ -1,17 +1,28 @@
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Union,
)

from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from tml.ml_logging.torch_logging import (
logging,
)
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
)


def weights_to_log(
model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
) -> Optional[Dict[str, Any]]:
"""Creates dict of reduced weights to log to give sense of training.

Args:
@@ -21,7 +32,7 @@ def weights_to_log(

"""
if not how_to_log:
return
return None

to_log = dict()
named_parameters = model.named_parameters()
@@ -38,14 +49,14 @@ def weights_to_log(
       how = how_to_log
     else:
       how = how_to_log.get(param_name)  # type: ignore[assignment]
-    if not how:
-      continue  # type: ignore
+    if how is None:
+      continue
     to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
   return to_log


 def log_ebc_norms(
-  model_state_dict,
+  model_state_dict: Dict[str, Any],
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
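A hedged sketch of calling weights_to_log with the tightened signature (the model and reductions are hypothetical, and this assumes the elided middle of the function treats a plain nn.Module the way the visible lines suggest):

import torch

from tml.common.log_weights import weights_to_log

model = torch.nn.Linear(4, 2)

# A single callable applies the same reduction to every parameter...
logged = weights_to_log(model, how_to_log=lambda t: t.norm())

# ...while a dict maps parameter names to reductions; names without an entry
# are skipped, which is what the `if how is None: continue` branch implements.
logged = weights_to_log(model, how_to_log={"weight": lambda t: t.abs().mean()})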
11 changes: 7 additions & 4 deletions common/modules/embedding/config.py
@@ -1,10 +1,13 @@
-from typing import List
 from enum import Enum

-import tml.core.config as base_config
-from tml.optimizers.config import OptimizerConfig
+from typing import (
+  List,
+)

 import pydantic
+import tml.core.config as base_config
+from tml.optimizers.config import (
+  OptimizerConfig,
+)


 class DataType(str, Enum):
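The diff is truncated at DataType, but as a sketch of how a str-valued Enum typically behaves in a pydantic config (the member values and the config class here are hypothetical, not taken from this PR):

from enum import Enum

import pydantic


class DataType(str, Enum):
  FP32 = "fp32"
  FP16 = "fp16"


class ExampleEmbeddingConfig(pydantic.BaseModel):
  data_type: DataType = DataType.FP32


# pydantic coerces the raw string into the enum member.
cfg = ExampleEmbeddingConfig(data_type="fp16")
assert cfg.data_type is DataType.FP16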