Skip to content
7 changes: 0 additions & 7 deletions lightning_action/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,6 @@ def split_sizes_from_probabilities(
train_number = int(np.ceil(train_probability * total_number))
val_number = total_number - train_number

# make sure that we have at least one validation sample
if val_number == 0:
train_number -= 1
val_number += 1
if train_number < 1:
raise ValueError('Must have at least two sequences, one train and one validation')

# assert that we're using all datapoints
assert train_number + val_number == total_number

Expand Down
8 changes: 4 additions & 4 deletions lightning_action/models/backbones/resnet_beast.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def __init__(
hidden_channels=hidden_channels,
downsample=False,
)
self.add_module(f'{i} Layer', layer)
self.add_module(f'{i} EncoderLayer', layer)

elif downsample_method == 'pool':
maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
Expand All @@ -342,7 +342,7 @@ def __init__(
hidden_channels=hidden_channels,
downsample=False,
)
self.add_module(f'{i + 1} Layer', layer)
self.add_module(f'{i + 1} EncoderLayer', layer)

def forward(self, x: torch.Tensor) -> torch.Tensor:
for name, layer in self.named_children():
Expand Down Expand Up @@ -375,7 +375,7 @@ def __init__(
in_channels=up_channels, hidden_channels=hidden_channels,
up_channels=up_channels, downsample=False,
)
self.add_module(f'{i} Layer', layer)
self.add_module(f'{i} EncoderLayer', layer)

elif downsample_method == 'pool':
maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
Expand All @@ -392,7 +392,7 @@ def __init__(
in_channels=up_channels, hidden_channels=hidden_channels,
up_channels=up_channels, downsample=False,
)
self.add_module(f'{i + 1} Layer', layer)
self.add_module(f'{i + 1} EncoderLayer', layer)

def forward(self, x: torch.Tensor) -> torch.Tensor:
for name, layer in self.named_children():
Expand Down
6 changes: 3 additions & 3 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_val_dataloader(self, create_test_feature_csv):
with tempfile.TemporaryDirectory() as tmpdir:
# create test files
feature_file = Path(tmpdir) / 'features.csv'
create_test_feature_csv(feature_file, n_frames=30, n_features=4)
create_test_feature_csv(feature_file, n_frames=100, n_features=4)

# create data config
data_config = {
Expand Down Expand Up @@ -139,8 +139,8 @@ def test_multiple_datasets(self, create_test_marker_csv):
marker_file1 = Path(tmpdir) / 'markers1.csv'
marker_file2 = Path(tmpdir) / 'markers2.csv'

create_test_marker_csv(marker_file1, n_frames=20, n_markers=2)
create_test_marker_csv(marker_file2, n_frames=25, n_markers=2)
create_test_marker_csv(marker_file1, n_frames=50, n_markers=2)
create_test_marker_csv(marker_file2, n_frames=50, n_markers=2)

# create data config
data_config = {
Expand Down
14 changes: 7 additions & 7 deletions tests/data/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ def test_probabilities_sum_to_one(self):
split_sizes_from_probabilities(100, 0.6, 0.5)

def test_minimum_validation_samples(self):
    """Test that 0 validation samples is valid when val_probability is 0."""
    # a zero val_probability routes every sample into the training split
    splits = split_sizes_from_probabilities(10, 1.0, 0.0)
    assert splits == [10, 0]

def test_too_few_total_samples(self):
    """Test that a single sample with 0 val probability returns [1, 0]."""
    # the lone sample lands in the training split; validation stays empty
    splits = split_sizes_from_probabilities(1, 1.0, 0.0)
    assert splits == [1, 0]

def test_fractional_results(self):
"""Test handling of fractional results."""
Expand All @@ -140,7 +140,7 @@ def test_edge_case_small_numbers(self):
# slightly larger
result = split_sizes_from_probabilities(3, 0.67, 0.33)
assert sum(result) == 3
assert result[1] >= 1 # ensure at least 1 val
assert result[1] >= 0 # val can be 0 when val_probability rounds down


class TestComputeClassWeights:
Expand Down
115 changes: 115 additions & 0 deletions tests/models/backbones/test_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for ResNet backbone."""

import pytest
import torch

from lightning_action.models.backbones.resnet import ResNetBackbone, RESNET_HIDDEN_SIZES


class TestResNetBackbone:
    """Unit tests covering construction, properties, and forward passes of ResNetBackbone."""

    @pytest.fixture
    def default_config(self):
        """Config selecting resnet18, the smallest and fastest variant."""
        return {'backbone': 'resnet18'}

    def test_default_initialization(self):
        """Constructing with no config falls back to resnet50."""
        model = ResNetBackbone()
        assert model.backbone_name == 'resnet50'
        assert model.hidden_size == 2048

    def test_custom_initialization(self, default_config):
        """A custom config selects the requested variant."""
        model = ResNetBackbone(default_config)
        assert model.backbone_name == 'resnet18'
        assert model.hidden_size == 512

    def test_all_supported_variants(self):
        """Every variant listed in RESNET_HIDDEN_SIZES can be instantiated."""
        for variant, size in RESNET_HIDDEN_SIZES.items():
            model = ResNetBackbone({'backbone': variant})
            assert model.backbone_name == variant
            assert model.hidden_size == size

    def test_invalid_backbone(self):
        """An unknown variant name raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported backbone"):
            ResNetBackbone({'backbone': 'resnet999'})

    def test_properties(self, default_config):
        """Public properties report the expected resnet18 configuration."""
        model = ResNetBackbone(default_config)
        assert model.hidden_size == 512
        assert model.num_channels == 3
        assert model.image_size == 224
        assert model.patch_size == 32
        assert model.backbone_name == 'resnet18'
        assert model.backbone_type == 'resnet'

    def test_hidden_size_mapping(self):
        """Basic-block variants map to 512 features, bottleneck variants to 2048."""
        for variant in ('resnet18', 'resnet34'):
            assert ResNetBackbone({'backbone': variant}).hidden_size == 512
        for variant in ('resnet50', 'resnet101', 'resnet152'):
            assert ResNetBackbone({'backbone': variant}).hidden_size == 2048

    def test_forward_pass_shape(self, default_config):
        """A forward pass yields a finite (batch, 512, 7, 7) feature map."""
        model = ResNetBackbone(default_config)
        feats = model(torch.randn(2, 3, 224, 224))
        assert feats.shape == (2, 512, 7, 7)
        assert torch.isfinite(feats).all()

    def test_forward_channel_validation(self, default_config):
        """Feeding the wrong channel count raises ValueError."""
        model = ResNetBackbone(default_config)
        bad_input = torch.randn(2, 1, 224, 224)  # single-channel input is rejected
        with pytest.raises(ValueError, match="Input has 1 channels"):
            model(bad_input)

    def test_forward_resnet50(self):
        """The bottleneck variant (resnet50) produces a 2048-channel map."""
        model = ResNetBackbone({'backbone': 'resnet50'})
        feats = model(torch.randn(2, 3, 224, 224))
        assert feats.shape == (2, 2048, 7, 7)
        assert torch.isfinite(feats).all()

    def test_gradient_flow(self, default_config):
        """Backprop through the backbone reaches the input tensor."""
        model = ResNetBackbone(default_config)
        inputs = torch.randn(1, 3, 224, 224, requires_grad=True)
        model(inputs).sum().backward()
        assert inputs.grad is not None
        assert not torch.isnan(inputs.grad).any()

    def test_load_pretrained_weights_file_not_found(self, default_config):
        """A missing checkpoint path raises FileNotFoundError."""
        model = ResNetBackbone(default_config)
        with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
            model.load_pretrained_weights('/nonexistent/checkpoint.ckpt')

    def test_get_last_layer_params(self, default_config):
        """get_last_layer_params yields a non-empty sequence of nn.Parameters."""
        model = ResNetBackbone(default_config)
        last_params = list(model.get_last_layer_params())
        assert last_params
        assert all(isinstance(p, torch.nn.Parameter) for p in last_params)
Loading