Change Python versions tested #22

Open · wants to merge 11 commits into base: `dev`
8 changes: 1 addition & 7 deletions .github/workflows/ci.yml
@@ -60,14 +60,8 @@ jobs:
strategy:
fail-fast: false
matrix:
python: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
platform: [ubuntu-latest, macos-latest, windows-latest]
-exclude: # Python < v3.8 does not support Apple Silicon ARM64.
-- python: "3.7"
-platform: macos-latest
-include: # So run those legacy versions on Intel CPUs.
-- python: "3.7"
-platform: macos-13
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v3
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -40,7 +40,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
-rev: 24.4.2
+rev: 24.10.0
hooks:
- id: black
language_version: python3
@@ -53,7 +53,7 @@ repos:
# additional_dependencies: [black]

- repo: https://github.com/PyCQA/flake8
-rev: 7.0.0
+rev: 7.1.1
hooks:
- id: flake8
additional_dependencies: [flake8-docstrings]
@@ -66,7 +66,7 @@

# Check for type errors with mypy:
- repo: https://github.com/pre-commit/mirrors-mypy
-rev: 'v1.10.0'
+rev: 'v1.13.0'
hooks:
- id: mypy
args: [--disallow-untyped-defs, --ignore-missing-imports]
2 changes: 1 addition & 1 deletion AUTHORS.md
@@ -2,6 +2,6 @@

* Arnab Mondal [[email protected]](mailto:[email protected])
* [Siba Smarak Panigrahi](https://sibasmarak.github.io/) [[email protected]](mailto:[email protected])
-* [Danielle Benesch](https://github.com/danibene) [[email protected]](mailto:[email protected])
+* [Danielle Benesch](https://github.com/danibene) [[email protected]](mailto:[email protected])
* [Jikael Gagnon](https://github.com/jikaelgagnon) [[email protected]](mailto:[email protected])
* [Sékou-Oumar Kaba](https://oumarkaba.github.io) [[email protected]](mailto:[email protected])
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [Unreleased]
+
+### Added
+
+### Fixed
+- Initialization of padding parameters in `DiscreteGroupImageCanonicalization` class, allowing for multiple types of `resize_shape`.
+
+### Changed
+- Increased minimum Python version to 3.8.
+- Specified maximum NumPy version as <2.0.
+
+### Removed
+
## [0.1.2] - 2024-05-29

### Added
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -224,7 +224,7 @@ package:
```

3. Make sure to have a reliable [tox] installation that uses the correct
-Python version (e.g., 3.7+). When in doubt you can run:
+Python version (e.g., 3.8+). When in doubt you can run:

```
tox --version
9 changes: 5 additions & 4 deletions equiadapt/common/basecanonicalization.py
@@ -300,8 +300,10 @@ def invert_canonicalization(
# self.device
# )
# return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior)

-def get_prior_regularization_loss(self, dataset_prior: Optional[torch.Tensor] = None) -> torch.Tensor:
+def get_prior_regularization_loss(
+self, dataset_prior: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Gets the prior regularization loss.

@@ -322,7 +324,7 @@ def get_prior_regularization_loss(self, dataset_prior: Optional[torch.Tensor] =
log_group_activations = F.log_softmax(group_activations, dim=1)

# KL Divergence
-return F.kl_div(log_group_activations, dataset_prior, reduction='batchmean')
+return F.kl_div(log_group_activations, dataset_prior, reduction="batchmean")

def get_identity_metric(self) -> torch.Tensor:
"""
@@ -430,7 +432,6 @@ def get_prior_regularization_loss(self) -> torch.Tensor:
.to(self.device)
)
return torch.nn.MSELoss()(group_elements_rep, dataset_prior)
-
def get_identity_metric(self) -> torch.Tensor:
"""
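For reference, the reformatted `return` in `get_prior_regularization_loss` above is a KL divergence between log-softmaxed group activations and a dataset prior. A minimal sketch with made-up shapes (batch of 2, 4 group elements; not taken from the repository's tests):

```python
import torch
import torch.nn.functional as F

group_activations = torch.randn(2, 4)     # unnormalized scores, one column per group element
dataset_prior = torch.full((2, 4), 0.25)  # uniform prior over a 4-element group

log_probs = F.log_softmax(group_activations, dim=1)
# reduction="batchmean" sums the pointwise KL terms and divides by the batch size.
loss = F.kl_div(log_probs, dataset_prior, reduction="batchmean")
```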
84 changes: 52 additions & 32 deletions equiadapt/images/canonicalization/discrete_group.py
@@ -3,7 +3,7 @@

import kornia as K
import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, ListConfig
from torch.nn import functional as F
from torchvision import transforms

@@ -90,22 +90,31 @@ def __init__(
if is_grayscale
else transforms.Resize(size=canonicalization_hyperparams.resize_shape)
)

# group augment specific cropping and padding (required for group_augment())
group_augment_in_shape = canonicalization_hyperparams.resize_shape
self.crop_group_augment = (
torch.nn.Identity()
if in_shape[0] == 1
else transforms.CenterCrop(group_augment_in_shape)
)
-self.pad_group_augment = (
-torch.nn.Identity()
-if in_shape[0] == 1
-else transforms.Pad(
-math.ceil(group_augment_in_shape * 0.5), padding_mode="edge"
-)
-)

+self._set_pad_group_augment(in_shape, group_augment_in_shape)

+def _set_pad_group_augment(
+self, in_shape: tuple, group_augment_in_shape: Union[ListConfig, float]
+) -> None:
+if in_shape[0] == 1:
+self.pad_group_augment = torch.nn.Identity()
+else:
+padding = []
+if isinstance(group_augment_in_shape, ListConfig):
+for i in range(len(group_augment_in_shape)):
+padding.append(math.ceil(group_augment_in_shape[i] * 0.5))
+else:
+padding.append(math.ceil(group_augment_in_shape * 0.5))
+
+self.pad_group_augment = transforms.Pad(padding, padding_mode="edge")
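The new helper accepts `resize_shape` either as a single number or as a per-dimension `ListConfig`, which is the fix noted in the changelog. A standalone sketch of the two forms with illustrative values (assumes a torchvision version whose `Pad` accepts length-1 and length-2 padding sequences):

```python
import math

from omegaconf import ListConfig, OmegaConf
from torchvision import transforms

for resize_shape in (64, OmegaConf.create([64, 32])):
    if isinstance(resize_shape, ListConfig):
        # One entry per configured dimension; torchvision reads a length-2
        # sequence as (left/right, top/bottom) padding.
        padding = [math.ceil(s * 0.5) for s in resize_shape]
    else:
        # A single entry pads every border by the same amount.
        padding = [math.ceil(resize_shape * 0.5)]
    pad_group_augment = transforms.Pad(padding, padding_mode="edge")
```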

def rotate_and_maybe_reflect(
self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False
) -> List[torch.Tensor]:
@@ -133,7 +142,8 @@
def group_augment(self, x: torch.Tensor) -> torch.Tensor:
"""
Augment the input images by applying group transformations (rotations and reflections).
-This function is used both for the energy based optimization method for the discrete rotation
+
+This function is used both for the energy based optimization method for the discrete rotation.

Args:
x (torch.Tensor): The input image.
@@ -315,15 +325,15 @@ def invert_canonicalization(
group_element_dict=self.canonicalization_info_dict["group_element"], # type: ignore
induced_rep_type=induced_rep_type,
)

def get_prior(
-self,
-x: torch.Tensor,
+self,
+x: torch.Tensor,
model: torch.nn.Module,
targets: torch.Tensor,
metric_function: torch.nn.Module,
tau: float = 1.0,
-) -> torch.Tensor:
+) -> torch.Tensor:
"""
Get the prior for the input images.

@@ -339,30 +349,36 @@
"""
with torch.no_grad():
batch_size = x.shape[0]
-x_augmented = self.group_augment(x) # size (group_size * batch_size, in_channels, height, width)
+x_augmented = self.group_augment(
+x
+) # size (group_size * batch_size, in_channels, height, width)
# If a self.group_augment_target is defined, apply the same transformation to the targets
# Or else just repeat the targets for each group element in the first dimension
if hasattr(self, "group_augment_target"):
targets_augmented = self.group_augment_target(targets)
else:
-targets_augmented = targets.repeat(self.num_group, 1).flatten() # size (group_size * batch_size)
+targets_augmented = targets.repeat(
+self.num_group, 1
+).flatten() # size (group_size * batch_size)

# Get the output of the model for the augmented images
-model_output = model(x_augmented) # size eg (group_size * batch_size, num_classes)
+model_output = model(
+x_augmented
+) # size eg (group_size * batch_size, num_classes)

# Get the unnormalized probability masses for each group element
-unnormalized_prob_masses = metric_function(
-model_output, targets_augmented
-).reshape(self.num_group, batch_size).transpose(0, 1) # size (batch_size, group_size)
+unnormalized_prob_masses = (
+metric_function(model_output, targets_augmented)
+.reshape(self.num_group, batch_size)
+.transpose(0, 1)
+) # size (batch_size, group_size)

# Get the prior for the input images
-prior = F.softmax(unnormalized_prob_masses / tau, dim=-1) # size (batch_size, group_size)
+prior = F.softmax(
+unnormalized_prob_masses / tau, dim=-1
+) # size (batch_size, group_size)

return prior
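A shape walk-through of the tail of `get_prior`, with hypothetical sizes (`num_group=4`, `batch_size=2`): the per-sample metric values are folded into a `(batch_size, group_size)` table, and a temperature-`tau` softmax turns each row into a prior (the classification example in this PR passes `tau=0.01`, which sharpens the distribution):

```python
import torch
import torch.nn.functional as F

num_group, batch_size, tau = 4, 2, 0.01
# One metric value per augmented example, e.g. -cross_entropy per sample.
per_sample_metric = torch.randn(num_group * batch_size)

unnormalized = per_sample_metric.reshape(num_group, batch_size).transpose(0, 1)
prior = F.softmax(unnormalized / tau, dim=-1)  # shape (batch_size, group_size)
assert torch.allclose(prior.sum(dim=-1), torch.ones(batch_size))
```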
class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization):
@@ -486,8 +502,12 @@ def get_group_activations(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor: The group activations.
"""
x = self.transformations_before_canonicalization_network_forward(x)
-x_augmented = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width)
-vector_out = self.canonicalization_network(x_augmented) # size (batch_size * group_size, reference_vector_size)
+x_augmented = self.group_augment(
+x
+) # size (batch_size * group_size, in_channels, height, width)
+vector_out = self.canonicalization_network(
+x_augmented
+) # size (batch_size * group_size, reference_vector_size)
self.canonicalization_info_dict = {"vector_out": vector_out}

if self.artifact_err_wt:
@@ -1,4 +1,4 @@
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
@@ -51,8 +51,10 @@ def __init__(
) -> None:
super().__init__()
self.device: str = device
-self.learning_rate: float = (
-hyperparams.learning_rate if hasattr(hyperparams, "learning_rate") else None
+self.learning_rate: Optional[float] = (
+float(hyperparams.learning_rate)
+if hasattr(hyperparams, "learning_rate")
+else None
)
self.weight_decay: float = (
hyperparams.weight_decay if hasattr(hyperparams, "weight_decay") else 0.0
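Because the fallback is `None`, the annotation above had to widen from `float` to `Optional[float]`. A minimal sketch of the guarded read using a hand-built struct-mode config (illustrative; not the project's actual Hydra setup):

```python
from typing import Optional

from omegaconf import OmegaConf

hyperparams = OmegaConf.create({"weight_decay": 0.0})  # no learning_rate key
OmegaConf.set_struct(hyperparams, True)  # missing keys now raise, so hasattr() is False

learning_rate: Optional[float] = (
    float(hyperparams.learning_rate) if hasattr(hyperparams, "learning_rate") else None
)
assert learning_rate is None
```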
17 changes: 12 additions & 5 deletions examples/images/classification/model.py
@@ -1,11 +1,12 @@
from math import tau

import pytorch_lightning as pl
import torch
from inference_utils import get_inference_method
from model_utils import get_dataset_specific_info, get_prediction_network
from omegaconf import DictConfig
-from torch.optim.lr_scheduler import MultiStepLR
+from torch.nn import functional as F
+from torch.optim.lr_scheduler import MultiStepLR

from examples.images.common.utils import get_canonicalization_network, get_canonicalizer

@@ -104,13 +105,19 @@ def training_step(self, batch: torch.Tensor):
# Add prior regularization loss if the prior weight is non-zero
if self.hyperparams.experiment.training.loss.prior_weight:
if self.hyperparams.experiment.training.loss.automated_prior:

def metric_function(model_predictions, targets):
-return -F.cross_entropy(model_predictions, targets, reduction='none')
-prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.01)
-prior_loss = self.canonicalizer.get_prior_regularization_loss(prior) # type: ignore
+return -F.cross_entropy(
+model_predictions, targets, reduction="none"
+)
+
+prior = self.canonicalizer.get_prior(
+x, self.prediction_network, y, metric_function, tau=0.01
+)
+prior_loss = self.canonicalizer.get_prior_regularization_loss(prior) # type: ignore
else:
prior_loss = self.canonicalizer.get_prior_regularization_loss()

loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight
metric_identity = self.canonicalizer.get_identity_metric()
training_metrics.update(
@@ -8,4 +8,4 @@ network_hyperparams:
num_rotations: 4 # Number of rotations for the canonization network
beta: 1.0 # Beta parameter for the canonization network
input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization
-resize_shape: 64 # Resize shape for the input
+resize_shape: 64 # Resize shape for the input
@@ -1 +1 @@
-canonicalization_type: identity
+canonicalization_type: identity
2 changes: 1 addition & 1 deletion examples/images/reinforcementlearning/configs/default.yaml
@@ -20,4 +20,4 @@ defaults:
- env: default
- experiment: default
- canonicalization: identity
-- wandb: default
+- wandb: default
@@ -9,4 +9,4 @@ replay_memory_size: 100000
end_score: 200
training_stop: 142
num_episodes: 50000
-last_episodes_num: 20
+last_episodes_num: 20
15 changes: 8 additions & 7 deletions examples/images/reinforcementlearning/network.py
@@ -1,7 +1,9 @@
+import random
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-import random

class DQN(nn.Module):
def __init__(self, input_shape, num_actions, dueling_DQN=False):
@@ -20,7 +22,7 @@ def __init__(self, input_shape, num_actions, dueling_DQN=False):
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=5, stride=2),
nn.BatchNorm2d(64),
-nn.ReLU()
+nn.ReLU(),
)

feature_size = self._get_feature_size()
@@ -30,20 +32,20 @@ def __init__(self, input_shape, num_actions, dueling_DQN=False):
nn.Linear(feature_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
-nn.Linear(512, self.num_actions)
+nn.Linear(512, self.num_actions),
)
self.value = nn.Sequential(
nn.Linear(feature_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
-nn.Linear(512, 1)
+nn.Linear(512, 1),
)
else:
self.action_value = nn.Sequential(
nn.Linear(feature_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
-nn.Linear(512, self.num_actions)
+nn.Linear(512, self.num_actions),
)

def forward(self, x):
@@ -57,11 +59,10 @@ def forward(self, x):
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
else:
q_values = self.action_value(x)
-
return q_values

def _get_feature_size(self):
self.features.eval()
with torch.no_grad():
return self.features(torch.zeros(1, *self.input_shape)).view(1, -1).size(1)
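For reference, the dueling branch of `forward` combines the two heads as Q(s, a) = V(s) + (A(s, a) - mean of A(s, ·)). A toy check with made-up sizes:

```python
import torch

batch, num_actions = 3, 2
advantage = torch.randn(batch, num_actions)  # output of the action_value head
value = torch.randn(batch, 1)                # output of the value head

q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
# Mean-centring keeps the decomposition identifiable: the centred
# advantages in each row sum to zero.
assert torch.allclose((q_values - value).sum(dim=1), torch.zeros(batch), atol=1e-5)
```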
