Formatted docstrings in files to pass with pydocstyle's criteria. #72

Merged · 19 commits · Mar 28, 2025
Commits (19)
a362c33
feat: added docstrings to functions in the metnet folder and split co…
andrewTheCommitter Mar 24, 2025
04de50b
feat: added more placeholder doc strings
andrewTheCommitter Mar 24, 2025
6385d2b
feat: added docstrings to the 'layers/ConvGRU.py' file
andrewTheCommitter Mar 25, 2025
411a9dd
feat: added docstrings to the 'layers/ConvGRU.py' file
andrewTheCommitter Mar 25, 2025
59be919
feat: added doc strings to the remaining functions to adhere to ruff …
andrewTheCommitter Mar 25, 2025
60f444d
feat: in 'layers/ConvGRU.py' changed 'l' variable name to 'layer_inde…
andrewTheCommitter Mar 25, 2025
be258e0
feat: added args to several docstrings
andrewTheCommitter Mar 25, 2025
dd99215
feat: adding periods to docstrings in the file
andrewTheCommitter Mar 25, 2025
f760349
feat: added periods to docstrings to comply with pydocstyle
andrewTheCommitter Mar 25, 2025
c00ff4d
feat: added the periods and adjusted single line docs in layers files…
andrewTheCommitter Mar 25, 2025
252e8d3
feat: changed the files in the repository to meet the requirements of…
andrewTheCommitter Mar 25, 2025
642202d
feat: added a period to docstring in 'setup.py' and reformated the fi…
andrewTheCommitter Mar 25, 2025
e98b053
feat: changed the summary text to be shorter and changed the import o…
andrewTheCommitter Mar 26, 2025
7b895f7
feat: explicitly added the remaning imports packages to 'metnet/__ini…
andrewTheCommitter Mar 26, 2025
9953384
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2025
4e52b1e
feat: added periods to the docstrings to match pydocstyle requirement…
andrewTheCommitter Mar 26, 2025
3615594
Applied the suggestions from code review for the docstrings
Averagenormaljoe Mar 27, 2025
8864293
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2025
13f9929
feat: reduced the summary text in the init function of the class 'Met…
andrewTheCommitter Mar 27, 2025
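
For reference, a minimal sketch of the convention these commits enforce: pydocstyle expects a triple-quoted docstring whose summary line is written in the imperative mood and ends with a period, with a blank line before any further detail. The function below is hypothetical and only illustrates the style; it is not taken from the repository.

def scale_values(values, factor=2.0):
    """Scale each value by a constant factor.

    Args:
        values: Iterable of numbers to scale.
        factor: Multiplicative factor applied to every element.
    """
    return [v * factor for v in values]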
24 changes: 23 additions & 1 deletion metnet/__init__.py
@@ -1,5 +1,27 @@
"""Modules for the MetNet package."""

from metnet.models.metnet import MetNet
from metnet.models.metnet2 import MetNet2
from metnet.models.metnet_pv import MetNetPV

from .layers import *
from .layers import (
ConditionTime,
ConditionWithTimeMetNet2,
ConvGRU,
ConvLSTM,
CoordConv,
DilatedCondConv,
DownSampler,
LeadTimeConditioner,
MaxViT,
MBConv,
MetNetPreprocessor,
MultiheadSelfAttention2D,
PartitionAttention,
Preprocessor,
RelativePositionBias,
SqueezeExcitation,
StochasticDepth,
TimeDistributed,
utils,
)
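
With the wildcard import replaced by the explicit list above, the layer classes remain re-exported from the package namespace while the public surface becomes visible to linters and readers. A quick sketch of what this enables, assuming metnet is installed and importable:

from metnet import ConditionTime, ConvGRU, MetNet

# These names resolve through the explicit re-exports in metnet/__init__.py.
print(ConditionTime, ConvGRU, MetNet)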
9 changes: 6 additions & 3 deletions metnet/layers/ConditionTime.py
@@ -1,26 +1,29 @@
"""Condition time module."""

import torch
from torch import nn as nn


def condition_time(x, i=0, size=(12, 16), seq_len=15):
"Create one hot encoded time image-layers, i in [1, seq_len]"
"""Create one hot encoded time image-layers, i in [1, seq_len]."""
assert i < seq_len
times = (torch.eye(seq_len, dtype=torch.long, device=x.device)[i]).unsqueeze(-1).unsqueeze(-1)
ones = torch.ones(1, *size, dtype=x.dtype, device=x.device)
return times * ones


class ConditionTime(nn.Module):
"Condition Time on a stack of images, adds `horizon` channels to image"
"""Condition Time on a stack of images, adds `horizon` channels to image."""

def __init__(self, horizon, ch_dim=2, num_dims=5):
"""Set the horizontal channels and the dimensions for the channels and numbers."""
super().__init__()
self.horizon = horizon
self.ch_dim = ch_dim
self.num_dims = num_dims

def forward(self, x, fstep=0):
"X stack of images, fsteps"
"""X stack of images, fsteps."""
if self.num_dims == 5:
bs, seq_len, ch, h, w = x.shape
ct = condition_time(x, fstep, (h, w), seq_len=self.horizon).repeat(bs, seq_len, 1, 1, 1)
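
As a side note on what condition_time builds, the following self-contained sketch reproduces the same one-hot construction outside the module; the shapes are invented for illustration only.

import torch

seq_len, size, i = 5, (4, 4), 2
# One-hot vector for timestep i, broadcast to an image-sized layer per step.
times = torch.eye(seq_len, dtype=torch.long)[i].unsqueeze(-1).unsqueeze(-1)
layers = times * torch.ones(1, *size)
print(layers.shape)  # torch.Size([5, 4, 4]); only layers[2] is all ones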
21 changes: 13 additions & 8 deletions metnet/layers/ConditionWithTimeMetNet2.py
@@ -1,20 +1,22 @@
"""Condition with time how MetNet-22 does it, with FiLM layers"""
"""Condition with time how MetNet-22 does it, with FiLM layers."""

import einops
import torch
from torch import nn as nn


class ConditionWithTimeMetNet2(nn.Module):
"""Compute Scale and bias for conditioning on time"""
"""Compute Scale and bias for conditioning on time."""

def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
"""
Compute the scale and bias factors for conditioning convolutional blocks on the forecast time
Compute the scale and bias factors for conditioning convolutional blocks on forecast time.

Args:
forecast_steps: Number of forecast steps
hidden_dim: Hidden dimension size
num_feature_maps: Max number of channels in the blocks, to generate enough scale+bias values
num_feature_maps: Max number of channels in the blocks, to generate enough
scale+bias values
This means extra values will be generated, but keeps implementation simpler
"""
super().__init__()
@@ -29,10 +31,10 @@ def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):

def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor]:
"""
Get the scale and bias for the conditioning layers
Get the scale and bias for the conditioning layers.

From the FiLM paper, each feature map (i.e. channel) has its own scale and bias layer, so needs
a scale and bias for each feature map to be generated
From the FiLM paper, each feature map (i.e. channel) has its own scale and bias layer,
so needs a scale and bias for each feature map to be generated

Args:
x: The Tensor that is used
@@ -49,6 +51,9 @@ def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor
timesteps = layer(timesteps)
scales_and_biases = timesteps
scales_and_biases = einops.rearrange(
scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2
scales_and_biases,
"b (block sb) -> b block sb",
block=self.num_feature_maps,
sb=2,
)
return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1]
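
To make the FiLM reference in the docstring concrete, here is a minimal sketch of how a per-channel scale and bias would be applied to a feature map. The shapes are placeholders and this is not code from the repository.

import torch

scale = torch.randn(2, 64)             # [batch, num_feature_maps]
bias = torch.randn(2, 64)
features = torch.randn(2, 64, 32, 32)  # [batch, channels, height, width]
# FiLM: each channel gets its own affine transform, as the docstring describes.
conditioned = features * scale[:, :, None, None] + bias[:, :, None, None]
print(conditioned.shape)               # torch.Size([2, 64, 32, 32])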
46 changes: 36 additions & 10 deletions metnet/layers/ConvGRU.py
@@ -1,9 +1,13 @@
"""Implementation of Conv GRU and cell module."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvGRUCell(nn.Module):
"""The Conv GRU Cell."""

def __init__(
self,
input_dim,
@@ -15,6 +19,7 @@ def __init__(
):
"""
Initialize ConvGRU cell.

Parameters
----------
input_dim: int
@@ -64,7 +69,8 @@ def __init__(

self.reset_parameters()

def forward(self, input, h_prev=None):
def forward(self, input: torch.Tensor, h_prev=None):
"""Get the current hidden layer of the input layer."""
# init hidden on forward
if h_prev is None:
h_prev = self.init_hidden(input)
@@ -81,11 +87,13 @@ def forward(self, input: torch.Tensor, h_prev=None):

return h_cur

def init_hidden(self, input):
def init_hidden(self, input: torch.Tensor):
"""Create and return a hidden layer."""
bs, ch, h, w = input.shape
return one_param(self).new_zeros(bs, self.hidden_dim, h, w)

def reset_parameters(self):
"""Reset the weights and bias of the ConvGRU cell."""
# self.conv.reset_parameters()
nn.init.xavier_uniform_(self.conv_zr.weight, gain=nn.init.calculate_gain("tanh"))
self.conv_zr.bias.data.zero_()
@@ -100,29 +108,38 @@ def reset_parameters(self):


def one_param(m):
"First parameter in `m`"
"""First parameter in `m`."""
return next(m.parameters())


def dropout_mask(x, sz, p):
"Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element."
"""
Get the dropout mask of x.

Return a dropout mask of the same type as `x`, size `sz`, with probability
`p` to cancel an element.
"""
return x.new_empty(*sz).bernoulli_(1 - p).div_(1 - p)


class RNNDropout(nn.Module):
"Dropout with probability `p` that is consistent on the seq_len dimension."
"""Dropout with probability `p` that is consistent on the seq_len dimension."""

def __init__(self, p=0.5):
"""Initialize the RNN dropout layer."""
super().__init__()
self.p = p

def forward(self, x):
"""Calculate the dropout mask."""
if not self.training or self.p == 0.0:
return x
return x * dropout_mask(x.data, (x.size(0), 1, *x.shape[2:]), self.p)


class ConvGRU(nn.Module):
"""Conv GRU."""

def __init__(
self,
input_dim,
@@ -136,6 +153,7 @@ def __init__(
hidden_p=0.1,
batchnorm=False,
):
"""Initialize the configurations of the conv GRU."""
super(ConvGRU, self).__init__()

self._check_kernel_size_consistency(kernel_size)
@@ -174,16 +192,20 @@ def __init__(

self.cell_list = nn.ModuleList(cell_list)
self.input_dp = RNNDropout(input_p)
self.hidden_dps = nn.ModuleList([nn.Dropout(hidden_p) for l in range(n_layers)])
self.hidden_dps = nn.ModuleList([nn.Dropout(hidden_p) for layer_index in range(n_layers)])
self.reset_parameters()

def __repr__(self):
"""Return a string representation of the configuration options of the conv gru."""
s = f"ConvGru(in={self.input_dim}, out={self.hidden_dim[0]}, ks={self.kernel_size[0]}, "
s += f"n_layers={self.n_layers}, input_p={self.input_p}, hidden_p={self.hidden_p})"
return s

def forward(self, input, hidden_state=None):
"""

Pass the input tensor into a sequence of models.

Parameters
----------
input_tensor:
@@ -203,15 +225,15 @@ def forward(self, input, hidden_state=None):

last_state_list = []

for l, (gru_cell, hid_dp) in enumerate(zip(self.cell_list, self.hidden_dps)):
h = hidden_state[l]
for layer_index, (gru_cell, hid_dp) in enumerate(zip(self.cell_list, self.hidden_dps)):
h = hidden_state[layer_index]
output_inner = []
for t in range(seq_len):
h = gru_cell(input=cur_layer_input[t], h_prev=h)
output_inner.append(h)

cur_layer_input = torch.stack(output_inner) # list to array
if l != self.n_layers:
if layer_index != self.n_layers:
cur_layer_input = hid_dp(cur_layer_input)
last_state_list.append(h)

@@ -220,17 +242,20 @@ def reset_parameters(self):
return layer_output, last_state_list

def reset_parameters(self):
"""Reset the parameters of each of the conv gru cells in the list."""
for c in self.cell_list:
c.reset_parameters()

def get_init_states(self, input):
"""Collect the init states from the cell list."""
init_states = []
for gru_cell in self.cell_list:
init_states.append(gru_cell.init_hidden(input))
return init_states

@staticmethod
def _check_kernel_size_consistency(kernel_size):
"""Check if kernel size is a tuple or is a list of tuples."""
if not (
isinstance(kernel_size, tuple)
or (
@@ -241,7 +266,8 @@ def _check_kernel_size_consistency(kernel_size):
raise ValueError("`kernel_size` must be tuple or list of tuples")

@staticmethod
def _extend_for_multilayer(param, num_layers):
def _extend_for_multilayer(param, num_layers: int):
"""Convert the param into a list."""
if not isinstance(param, list):
param = [param] * num_layers
return param
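
The newly documented _extend_for_multilayer helper is small enough to illustrate standalone; this snippet reproduces it as shown in the diff, purely to demonstrate the broadcast behaviour the docstring describes.

def _extend_for_multilayer(param, num_layers: int):
    """Convert the param into a list."""
    if not isinstance(param, list):
        param = [param] * num_layers
    return param

print(_extend_for_multilayer((3, 3), 4))    # [(3, 3), (3, 3), (3, 3), (3, 3)]
print(_extend_for_multilayer([16, 32], 2))  # already a list, returned as-is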
28 changes: 15 additions & 13 deletions metnet/layers/ConvLSTM.py
@@ -1,4 +1,5 @@
"""Originally adapted from https://github.com/aserdega/convlstmgru, MIT License Andriy Serdega"""
"""Originally adapted from https://github.com/aserdega/convlstmgru, MIT License Andriy Serdega."""

from typing import Any, List, Optional

import torch
@@ -7,7 +8,7 @@


class ConvLSTMCell(nn.Module):
"""ConvLSTM Cell"""
"""ConvLSTM Cell."""

def __init__(
self,
@@ -19,7 +20,7 @@ def __init__(
batchnorm=False,
):
"""
ConLSTM Cell
ConvLSTM Cell.

Args:
input_dim: Number of input channels
@@ -52,7 +53,7 @@ def __init__(

def forward(self, x: torch.Tensor, prev_state: list) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute forward pass
Compute forward pass.

Args:
x: Input tensor of [Batch, Channel, Height, Width]
@@ -82,7 +83,8 @@ def forward(self, x: torch.Tensor, prev_state: list) -> tuple[torch.Tensor, torc

def init_hidden(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Initializes the hidden state
Initialize the hidden state.

Args:
x: Input tensor to initialize for

@@ -97,7 +99,7 @@ def init_hidden(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return state

def reset_parameters(self) -> None:
"""Resets parameters"""
"""Reset parameters."""
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain("tanh"))
self.conv.bias.data.zero_()

@@ -107,6 +109,8 @@ def reset_parameters(self) -> None:


class ConvLSTM(nn.Module):
"""Creates a convolution LSTM layer."""

def __init__(
self,
input_dim: int,
@@ -118,7 +122,7 @@ def __init__(
batchnorm=False,
):
"""
ConvLSTM module
ConvLSTM module.

Args:
input_dim: Input dimension size
@@ -169,7 +173,7 @@ def forward(
self, x: torch.Tensor, hidden_state: Optional[list] = None
) -> tuple[Tensor, list[tuple[Any, Any]]]:
"""
Computes the output of the ConvLSTM
Compute the output of the ConvLSTM.

Args:
x: Input Tensor of shape [Batch, Time, Channel, Width, Height]
@@ -202,15 +206,13 @@ def forward(
return layer_output, last_state_list

def reset_parameters(self) -> None:
"""
Reset parameters
"""
"""Reset parameters."""
for c in self.cell_list:
c.reset_parameters()

def get_init_states(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Constructs the initial hidden states
Construct the initial hidden states.

Args:
x: Tensor to use for constructing state
@@ -226,7 +228,7 @@ def get_init_states(self, x: torch.Tensor) -> List[torch.Tensor]:
@staticmethod
def _extend_for_multilayer(param, num_layers):
"""
Extends a parameter for multiple layers
Extend a parameter for multiple layers.

Args:
param: Parameter to copy
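
Both cells above reset their weights with Xavier-uniform initialization and a tanh gain. The standalone sketch below shows that idiom on a placeholder convolution; the channel counts are invented and are not the ones ConvLSTMCell actually uses.

from torch import nn

conv = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=3, padding=1)
# Same initialization pattern as reset_parameters in the cells above.
nn.init.xavier_uniform_(conv.weight, gain=nn.init.calculate_gain("tanh"))
conv.bias.data.zero_()
print(float(conv.weight.std()))  # spread set by the Xavier bound and the tanh gain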