Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/hyrax/hyrax_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ projection_dimension = 128
# The scalar temperature parameter for its loss function, NTXentLoss, for SimCLR
temperature = 0.5

# The number of input channels in the images; used as in_channels of the first conv layer in SimCLRv2
input_channels = 3

# The probability of applying horizontal flip augmentation for SimCLR
horizontal_flip_probability = 0.5

Expand All @@ -120,6 +123,15 @@ gaussian_blur_kernel_size = 9
# The sigma range used in Gaussian blur augmentation for SimCLR
gaussian_blur_sigma_range = [0.1, 2.0]

# The maximum rotation angle (in degrees) for the random rotation augmentation for SimCLR
rotation_range = 180

# The mean of the distribution for Gaussian noise augmentation for SimCLR
gaussian_noise_mean = 0

# The standard deviation of the distribution for Gaussian noise augmentation for SimCLR
gaussian_noise_sigma = 0.2


[criterion]
# The name of the built-in criterion to use or the import path to an external criterion
Expand Down
3 changes: 2 additions & 1 deletion src/hyrax/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .hyrax_cnn import HyraxCNN
from .hyrax_loopback import HyraxLoopback
from .model_registry import hyrax_model
from .simclr import SimCLR
from .simclr import SimCLR, SimCLRv2

__all__ = [
"hyrax_model",
Expand All @@ -21,4 +21,5 @@
"HSCDCAE",
"ImageDCAE",
"SimCLR",
"SimCLRv2",
]
82 changes: 78 additions & 4 deletions src/hyrax/models/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import torch.nn as nn
import torch.nn.functional as F # noqa N812
import torchvision.models as models
import torchvision.transforms as T # noqa N812
import torchvision.transforms.v2 as T # noqa N812

from hyrax.models.model_registry import hyrax_model


class NTXentLoss(nn.Module):
"""Normalized Temperature-scaled Cross Entropy Loss. Based on Chen, 2020"""

def __init__(self, temperature=0.1):
def __init__(self, temperature):
super().__init__()
self.temperature = temperature
self.criterion = nn.CrossEntropyLoss(reduction="sum")
Expand All @@ -22,6 +22,7 @@ def forward(self, z_i, z_j):
"""Forward function of NTXentLoss. Based on Chen, 2020.
Loss is calculated from representations from two augmented views of the same batch.
"""

batch_size = z_i.shape[0]
device = z_i.device

Expand Down Expand Up @@ -82,7 +83,8 @@ def __init__(self, config, shape):
nn.ReLU(inplace=True),
nn.Linear(512, proj_dim),
)
self.criterion = NTXentLoss(temperature)
# TODO: Make sure to revisit this and properly implement custom criterion
self.the_criterion = NTXentLoss(temperature)

def forward(self, x):
feats = self.backbone(x)
Expand Down Expand Up @@ -111,7 +113,79 @@ def train_step(self, x):
z1 = self.forward(x1)
z2 = self.forward(x2)

loss = self.criterion(z1, z2)
loss = self.the_criterion(z1, z2)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return {"loss": loss.item()}


@hyrax_model
class SimCLRv2(nn.Module):
"""Modified SimCLR model with new augmentation routine and compatible
with an arbitrary number of input channels"""

def __init__(self, config, shape):
    """Build the ResNet-18 backbone, the projection head and the NT-Xent criterion.

    Parameters
    ----------
    config
        Nested mapping of runtime settings. ``projection_dimension``,
        ``temperature`` and ``input_channels`` are read from
        ``config["model"]["SimCLR"]``.
    shape
        Stored on the instance as ``self.shape``; not otherwise used here.
    """
    super().__init__()
    self.config = config
    self.shape = shape

    simclr_config = config["model"]["SimCLR"]
    proj_dim = simclr_config["projection_dimension"]
    temperature = simclr_config["temperature"]
    input_channels = simclr_config["input_channels"]

    # ``pretrained=False`` is deprecated in torchvision; ``weights=None`` is the
    # equivalent way to request randomly initialised weights.
    backbone = models.resnet18(weights=None)

    # Create a new conv layer with same out_channels, kernel size, stride, padding, etc
    # But with the new in_channels
    old_conv = backbone.conv1
    backbone.conv1 = nn.Conv2d(
        in_channels=input_channels,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )

    # Read the feature width before the classifier is dropped so the projection
    # head always matches the backbone output (512 for ResNet-18).
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Identity()
    self.backbone = backbone

    self.projection_head = nn.Sequential(
        nn.Linear(feature_dim, feature_dim),
        nn.ReLU(inplace=True),
        nn.Linear(feature_dim, proj_dim),
    )
    # TODO: Make sure to revisit this and properly implement custom criterion
    self.the_criterion = NTXentLoss(temperature)

def forward(self, x):
    """Run ``x`` through the backbone, then through the projection head."""
    return self.projection_head(self.backbone(x))

def train_step(self, x):
aug = T.Compose(
[
T.RandomResizedCrop(size=x.shape[-1]),
T.RandomHorizontalFlip(self.config["model"]["SimCLR"]["horizontal_flip_probability"]),
T.RandomRotation(self.config["model"]["SimCLR"]["rotation_range"]),
T.GaussianBlur(
kernel_size=self.config["model"]["SimCLR"]["gaussian_blur_kernel_size"],
sigma=self.config["model"]["SimCLR"]["gaussian_blur_sigma_range"],
),
T.GaussianNoise(
mean=self.config["model"]["SimCLR"]["gaussian_noise_mean"],
sigma=self.config["model"]["SimCLR"]["gaussian_noise_sigma"],
),
]
)

x1 = torch.stack([aug(img) for img in x])
x2 = torch.stack([aug(img) for img in x])

z1 = self.forward(x1)
z2 = self.forward(x2)

loss = self.the_criterion(z1, z2)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
Expand Down