7 changes: 0 additions & 7 deletions lightning_action/data/utils.py
@@ -301,13 +301,6 @@ def split_sizes_from_probabilities(
    train_number = int(np.ceil(train_probability * total_number))
    val_number = total_number - train_number

-    # make sure that we have at least one validation sample
-    if val_number == 0:
-        train_number -= 1
-        val_number += 1
-    if train_number < 1:
-        raise ValueError('Must have at least two sequences, one train and one validation')
-
    # assert that we're using all datapoints
    assert train_number + val_number == total_number

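
For context, a minimal sketch of the simplified function after this deletion, reconstructed from the surrounding hunk and the updated tests in tests/data/test_utils.py below; the real implementation in lightning_action/data/utils.py may differ in details such as the exact probability check.

import numpy as np

def split_sizes_from_probabilities(
    total_number: int,
    train_probability: float,
    val_probability: float,
) -> list[int]:
    # assumed guard: test_probabilities_sum_to_one expects a ValueError
    # when train + val probabilities do not sum to one
    if not np.isclose(train_probability + val_probability, 1.0):
        raise ValueError('train and val probabilities must sum to 1')

    train_number = int(np.ceil(train_probability * total_number))
    val_number = total_number - train_number

    # assert that we're using all datapoints
    assert train_number + val_number == total_number
    return [train_number, val_number]

# with the minimum-validation guarantee removed, an empty val split is valid:
# split_sizes_from_probabilities(10, 1.0, 0.0) -> [10, 0]
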
8 changes: 4 additions & 4 deletions lightning_action/models/backbones/resnet_beast.py
@@ -323,7 +323,7 @@ def __init__(
                    hidden_channels=hidden_channels,
                    downsample=False,
                )
-                self.add_module(f'{i} Layer', layer)
+                self.add_module(f'{i} EncoderLayer', layer)

        elif downsample_method == 'pool':
            maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -342,7 +342,7 @@ def __init__(
                    hidden_channels=hidden_channels,
                    downsample=False,
                )
-                self.add_module(f'{i + 1} Layer', layer)
+                self.add_module(f'{i + 1} EncoderLayer', layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for name, layer in self.named_children():
@@ -375,7 +375,7 @@ def __init__(
                    in_channels=up_channels, hidden_channels=hidden_channels,
                    up_channels=up_channels, downsample=False,
                )
-                self.add_module(f'{i} Layer', layer)
+                self.add_module(f'{i} EncoderLayer', layer)

        elif downsample_method == 'pool':
            maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -392,7 +392,7 @@ def __init__(
                    in_channels=up_channels, hidden_channels=hidden_channels,
                    up_channels=up_channels, downsample=False,
                )
-                self.add_module(f'{i + 1} Layer', layer)
+                self.add_module(f'{i + 1} EncoderLayer', layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for name, layer in self.named_children():
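
A note on the renames above: the string passed to add_module is just the key under which the child module is registered, and forward iterates named_children() in registration order, so switching '{i} Layer' to '{i} EncoderLayer' does not change the computation (it does change state-dict keys, so checkpoints saved under the old names would need remapping). A toy illustration in plain PyTorch, not project code:

import torch
from torch import nn

class Stack(nn.Module):
    def __init__(self):
        super().__init__()
        for i in range(3):
            # register children under explicit names, as the blocks above do
            self.add_module(f'{i} EncoderLayer', nn.Linear(8, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # named_children() yields (name, module) pairs in registration order
        for name, layer in self.named_children():
            x = layer(x)
        return x

print([name for name, _ in Stack().named_children()])
# ['0 EncoderLayer', '1 EncoderLayer', '2 EncoderLayer']
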
6 changes: 3 additions & 3 deletions tests/data/test_datamodule.py
@@ -99,7 +99,7 @@ def test_val_dataloader(self, create_test_feature_csv):
        with tempfile.TemporaryDirectory() as tmpdir:
            # create test files
            feature_file = Path(tmpdir) / 'features.csv'
-            create_test_feature_csv(feature_file, n_frames=30, n_features=4)
+            create_test_feature_csv(feature_file, n_frames=100, n_features=4)

            # create data config
            data_config = {
@@ -139,8 +139,8 @@ def test_multiple_datasets(self, create_test_marker_csv):
            marker_file1 = Path(tmpdir) / 'markers1.csv'
            marker_file2 = Path(tmpdir) / 'markers2.csv'

-            create_test_marker_csv(marker_file1, n_frames=20, n_markers=2)
-            create_test_marker_csv(marker_file2, n_frames=25, n_markers=2)
+            create_test_marker_csv(marker_file1, n_frames=50, n_markers=2)
+            create_test_marker_csv(marker_file2, n_frames=50, n_markers=2)

            # create data config
            data_config = {
14 changes: 7 additions & 7 deletions tests/data/test_utils.py
@@ -113,15 +113,15 @@ def test_probabilities_sum_to_one(self):
            split_sizes_from_probabilities(100, 0.6, 0.5)

    def test_minimum_validation_samples(self):
-        """Test that at least one validation sample is guaranteed."""
-        # case where val_probability would give 0 samples
+        """Test that 0 validation samples is valid when val_probability is 0."""
+        # case where val_probability is 0: all samples go to training
        result = split_sizes_from_probabilities(10, 1.0, 0.0)
-        assert result == [9, 1]  # should adjust to ensure 1 val sample
+        assert result == [10, 0]

    def test_too_few_total_samples(self):
-        """Test error when not enough samples for train and val."""
-        with pytest.raises(ValueError, match='Must have at least two sequences'):
-            split_sizes_from_probabilities(1, 1.0, 0.0)
+        """Test that a single sample with 0 val probability returns [1, 0]."""
+        result = split_sizes_from_probabilities(1, 1.0, 0.0)
+        assert result == [1, 0]

    def test_fractional_results(self):
        """Test handling of fractional results."""
@@ -140,7 +140,7 @@ def test_edge_case_small_numbers(self):
        # slightly larger
        result = split_sizes_from_probabilities(3, 0.67, 0.33)
        assert sum(result) == 3
-        assert result[1] >= 1  # ensure at least 1 val
+        assert result[1] >= 0  # val can be 0 when val_probability rounds down


class TestComputeClassWeights:
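
Taken together, the updated tests pin down the new split semantics. Assuming the function is imported from lightning_action.data.utils as elsewhere in this test module:

from lightning_action.data.utils import split_sizes_from_probabilities

split_sizes_from_probabilities(10, 1.0, 0.0)        # [10, 0]: an empty val split is now valid
split_sizes_from_probabilities(1, 1.0, 0.0)         # [1, 0]: no 'at least two sequences' error
sum(split_sizes_from_probabilities(3, 0.67, 0.33))  # 3: every sample lands in some split
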
109 changes: 109 additions & 0 deletions tests/models/backbones/test_resnet.py
@@ -0,0 +1,109 @@
"""Tests for ResNet backbone."""

import pytest
import torch

from lightning_action.models.backbones.resnet import ResNetBackbone, RESNET_HIDDEN_SIZES


class TestResNetBackbone:
    """Test the ResNetBackbone class."""

    @pytest.fixture
    def default_config(self):
        """Default config using resnet18 (smallest/fastest)."""
        return {'backbone': 'resnet18'}

    def test_default_initialization(self):
        """Test initialization with default config (resnet50)."""
        backbone = ResNetBackbone()
        assert backbone.backbone_name == 'resnet50'
        assert backbone.hidden_size == 2048

    def test_custom_initialization(self, default_config):
        """Test initialization with custom config."""
        backbone = ResNetBackbone(default_config)
        assert backbone.backbone_name == 'resnet18'
        assert backbone.hidden_size == 512

    def test_all_supported_variants(self):
        """Test that all supported variants can be initialized."""
        for name in RESNET_HIDDEN_SIZES:
            backbone = ResNetBackbone({'backbone': name})
            assert backbone.backbone_name == name
            assert backbone.hidden_size == RESNET_HIDDEN_SIZES[name]

    def test_invalid_backbone(self):
        """Test that unsupported backbone raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported backbone"):
            ResNetBackbone({'backbone': 'resnet999'})

    @pytest.mark.parametrize("backbone_name,expected_hidden_size", RESNET_HIDDEN_SIZES.items())
    def test_properties(self, backbone_name, expected_hidden_size):
        """Test backbone properties for all variants."""
        backbone = ResNetBackbone({'backbone': backbone_name})
        assert backbone.hidden_size == expected_hidden_size
        assert backbone.num_channels == 3
        assert backbone.image_size == 224
        assert backbone.patch_size == 32
        assert backbone.backbone_name == backbone_name
        assert backbone.backbone_type == 'resnet'

    def test_hidden_size_mapping(self):
        """Test hidden size mapping for different variants."""
        # resnet18/34 -> 512
        for name in ['resnet18', 'resnet34']:
            backbone = ResNetBackbone({'backbone': name})
            assert backbone.hidden_size == 512

        # resnet50/101/152 -> 2048
        for name in ['resnet50', 'resnet101', 'resnet152']:
            backbone = ResNetBackbone({'backbone': name})
            assert backbone.hidden_size == 2048

    @pytest.mark.parametrize("backbone_name,expected_hidden_size", RESNET_HIDDEN_SIZES.items())
    def test_forward_pass_shape(self, backbone_name, expected_hidden_size):
        """Test forward pass produces correct output shape for all variants."""
        backbone = ResNetBackbone({'backbone': backbone_name})
        x = torch.randn(2, 3, 224, 224)
        output = backbone(x)

        assert output.shape == (2, expected_hidden_size, 7, 7)
        assert torch.isfinite(output).all()

    def test_forward_channel_validation(self, default_config):
        """Test that wrong input channels raises ValueError."""
        backbone = ResNetBackbone(default_config)
        x = torch.randn(2, 1, 224, 224)  # wrong channels

        with pytest.raises(ValueError, match="Input has 1 channels"):
            backbone(x)

    @pytest.mark.parametrize("backbone_name", RESNET_HIDDEN_SIZES.keys())
    def test_gradient_flow(self, backbone_name):
        """Test that gradients flow through all variants."""
        backbone = ResNetBackbone({'backbone': backbone_name})
        x = torch.randn(1, 3, 224, 224, requires_grad=True)
        output = backbone(x)

        loss = output.sum()
        loss.backward()

        assert x.grad is not None
        assert not torch.isnan(x.grad).any()

    def test_load_pretrained_weights_file_not_found(self, default_config):
        """Test that missing checkpoint raises FileNotFoundError."""
        backbone = ResNetBackbone(default_config)
        with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
            backbone.load_pretrained_weights('/nonexistent/checkpoint.ckpt')

    def test_get_last_layer_params(self, default_config):
        """Test that get_last_layer_params returns parameters from layer4."""
        backbone = ResNetBackbone(default_config)
        params = list(backbone.get_last_layer_params())

        assert len(params) > 0
        # All should be Parameters
        for p in params:
            assert isinstance(p, torch.nn.Parameter)
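
For reference, the mapping these tests parametrize over, reconstructed from the assertions in test_hidden_size_mapping; the canonical definition lives in lightning_action/models/backbones/resnet.py:

RESNET_HIDDEN_SIZES = {
    'resnet18': 512,
    'resnet34': 512,
    'resnet50': 2048,
    'resnet101': 2048,
    'resnet152': 2048,
}
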
163 changes: 163 additions & 0 deletions tests/models/backbones/test_resnet_beast.py
@@ -0,0 +1,163 @@
"""Tests for ResNet Beast backbone."""

import pytest
import torch

from lightning_action.models.backbones.resnet_beast import (
    BEAST_RESNET_HIDDEN_SIZES,
    BottleneckBlock,
    ResidualBlock,
    ResNetBeast,
    ResNetBeastBackbone,
    get_configs,
)


EXPECTED_CONFIGS = {
    'resnet18': ([2, 2, 2, 2], False),
    'resnet34': ([3, 4, 6, 3], False),
    'resnet50': ([3, 4, 6, 3], True),
    'resnet101': ([3, 4, 23, 3], True),
    'resnet152': ([3, 8, 36, 3], True),
}


class TestGetConfigs:
    """Test the get_configs function."""

    @pytest.mark.parametrize("arch,expected", EXPECTED_CONFIGS.items())
    def test_configs(self, arch, expected):
        """Test config for each architecture variant."""
        layers, bottleneck = get_configs(arch)
        assert layers == expected[0]
        assert bottleneck == expected[1]

    def test_invalid_arch(self):
        """Test that invalid architecture raises ValueError."""
        with pytest.raises(ValueError, match="not a valid ResNet architecture"):
            get_configs('resnet999')


class TestResNetBeastBackbone:
    """Test the ResNetBeastBackbone class."""

    @pytest.fixture
    def default_config(self):
        """Default config using resnet18 (smallest/fastest)."""
        return {'backbone': 'resnet18'}

    def test_default_initialization(self):
        """Test initialization with default config (resnet50)."""
        backbone = ResNetBeastBackbone()
        assert backbone._backbone_name == 'resnet50'
        assert backbone.hidden_size == 2048

    def test_custom_initialization(self, default_config):
        """Test initialization with custom config."""
        backbone = ResNetBeastBackbone(default_config)
        assert backbone._backbone_name == 'resnet18'
        assert backbone.hidden_size == 512

    def test_invalid_backbone(self):
        """Test that unsupported backbone raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported backbone"):
            ResNetBeastBackbone({'backbone': 'resnet999'})

    @pytest.mark.parametrize(
        "backbone_name,expected_hidden_size", BEAST_RESNET_HIDDEN_SIZES.items(),
    )
    def test_properties(self, backbone_name, expected_hidden_size):
        """Test backbone properties for each variant."""
        backbone = ResNetBeastBackbone({'backbone': backbone_name})
        assert backbone.hidden_size == expected_hidden_size
        assert backbone.num_channels == 3
        assert backbone.image_size == 224
        assert backbone.patch_size == 32
        assert backbone.backbone_type == 'resnet-beast'

    @pytest.mark.parametrize(
        "backbone_name,expected_hidden_size", BEAST_RESNET_HIDDEN_SIZES.items(),
    )
    def test_forward_pass_shape(self, backbone_name, expected_hidden_size):
        """Test forward pass produces correct output shape for each variant."""
        backbone = ResNetBeastBackbone({'backbone': backbone_name})
        x = torch.randn(2, 3, 224, 224)
        output = backbone(x)

        assert output.shape == (2, expected_hidden_size, 7, 7)
        assert torch.isfinite(output).all()

    @pytest.mark.parametrize("backbone_name", BEAST_RESNET_HIDDEN_SIZES.keys())
    def test_gradient_flow(self, backbone_name):
        """Test that gradients flow through the model for each variant."""
        backbone = ResNetBeastBackbone({'backbone': backbone_name})
        x = torch.randn(1, 3, 224, 224, requires_grad=True)
        output = backbone(x)

        loss = output.sum()
        loss.backward()

        assert x.grad is not None
        assert not torch.isnan(x.grad).any()

    def test_load_pretrained_weights_file_not_found(self, default_config):
        """Test that missing checkpoint raises FileNotFoundError."""
        backbone = ResNetBeastBackbone(default_config)
        with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
            backbone.load_pretrained_weights('/nonexistent/checkpoint.ckpt')

    def test_get_last_layer_params(self, default_config):
        """Test that get_last_layer_params returns parameters from conv5."""
        backbone = ResNetBeastBackbone(default_config)
        params = list(backbone.get_last_layer_params())

        assert len(params) > 0
        for p in params:
            assert isinstance(p, torch.nn.Parameter)


class TestResNetBeastComponents:
    """Test internal ResNetBeast components."""

    def test_resnet_beast_forward_no_bottleneck(self):
        """Test ResNetBeast forward pass without bottleneck (resnet18-style)."""
        model = ResNetBeast(configs=[2, 2, 2, 2], bottleneck=False)
        x = torch.randn(1, 3, 224, 224)
        output = model(x)

        assert output.shape == (1, 512, 7, 7)

    def test_resnet_beast_forward_bottleneck(self):
        """Test ResNetBeast forward pass with bottleneck (resnet50-style)."""
        model = ResNetBeast(configs=[3, 4, 6, 3], bottleneck=True)
        x = torch.randn(1, 3, 224, 224)
        output = model(x)

        assert output.shape == (1, 2048, 7, 7)

    def test_resnet_beast_invalid_configs(self):
        """Test that invalid config length raises ValueError."""
        with pytest.raises(ValueError, match="Only 4 layers can be configured"):
            ResNetBeast(configs=[2, 2, 2], bottleneck=False)

    def test_residual_block_forward(self):
        """Test ResidualBlock forward pass."""
        block = ResidualBlock(
            in_channels=64, hidden_channels=128, layers=2, downsample_method='conv',
        )
        x = torch.randn(1, 64, 28, 28)
        output = block(x)

        assert output.shape == (1, 128, 14, 14)

    def test_bottleneck_block_forward(self):
        """Test BottleneckBlock forward pass."""
        block = BottleneckBlock(
            in_channels=64, hidden_channels=64, up_channels=256,
            layers=3, downsample_method='pool',
        )
        x = torch.randn(1, 64, 56, 56)
        output = block(x)

        assert output.shape == (1, 256, 28, 28)

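Why the forward-shape assertions in both new test files expect 7x7 feature maps: a standard ResNet-style stack halves a 224x224 input five times in total (stride-2 stem conv, stride-2 maxpool, then stride-2 transitions into the last three stages), so 224 / 2**5 = 7. A sketch of that arithmetic, assumed to match ResNetBeast's layout:

size = 224
for _ in range(5):  # stem conv s2, maxpool s2, three downsampling stages
    size //= 2
assert size == 7  # matches the (N, C, 7, 7) shape assertions above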