-
Notifications
You must be signed in to change notification settings - Fork 1
Develop final #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Develop final #33
Changes from 13 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
4119333
Update resnet_beast.py
rowem5 dbe8416
Update utils.py
rowem5 e333b0f
Create test_rnn.py
rowem5 5beda8c
Create test_tcn.py
rowem5 73a5052
Create test_temporalmlp.py
rowem5 fcbb644
Create test_resnet_beast.py
rowem5 bf52a3e
Create test_resnet.py
rowem5 59e3f3b
Create test_vitmae.py
rowem5 95f982c
Delete tests/models/backbones/test_temporalmlp.py
rowem5 1ba1d1c
Delete tests/models/backbones/test_rnn.py
rowem5 f5ebb12
Delete tests/models/backbones/test_tcn.py
rowem5 bcb8c5d
Update test_utils.py
rowem5 dadb6fa
Update test_datamodule.py
rowem5 6efac5e
Update test_resnet_beast.py
rowem5 86e98cb
Update test_resnet.py
rowem5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| """Tests for ResNet backbone.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from lightning_action.models.backbones.resnet import ResNetBackbone, RESNET_HIDDEN_SIZES | ||
|
|
||
|
|
||
| class TestResNetBackbone: | ||
| """Test the ResNetBackbone class.""" | ||
|
|
||
| @pytest.fixture | ||
| def default_config(self): | ||
| """Default config using resnet18 (smallest/fastest).""" | ||
| return {'backbone': 'resnet18'} | ||
|
|
||
| def test_default_initialization(self): | ||
| """Test initialization with default config (resnet50).""" | ||
| backbone = ResNetBackbone() | ||
| assert backbone.backbone_name == 'resnet50' | ||
| assert backbone.hidden_size == 2048 | ||
|
|
||
| def test_custom_initialization(self, default_config): | ||
| """Test initialization with custom config.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| assert backbone.backbone_name == 'resnet18' | ||
| assert backbone.hidden_size == 512 | ||
|
|
||
| def test_all_supported_variants(self): | ||
| """Test that all supported variants can be initialized.""" | ||
| for name in RESNET_HIDDEN_SIZES: | ||
| backbone = ResNetBackbone({'backbone': name}) | ||
| assert backbone.backbone_name == name | ||
| assert backbone.hidden_size == RESNET_HIDDEN_SIZES[name] | ||
|
|
||
| def test_invalid_backbone(self): | ||
| """Test that unsupported backbone raises ValueError.""" | ||
| with pytest.raises(ValueError, match="Unsupported backbone"): | ||
| ResNetBackbone({'backbone': 'resnet999'}) | ||
|
|
||
| def test_properties(self, default_config): | ||
| """Test backbone properties.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| assert backbone.hidden_size == 512 | ||
| assert backbone.num_channels == 3 | ||
| assert backbone.image_size == 224 | ||
| assert backbone.patch_size == 32 | ||
| assert backbone.backbone_name == 'resnet18' | ||
| assert backbone.backbone_type == 'resnet' | ||
|
|
||
| def test_hidden_size_mapping(self): | ||
| """Test hidden size mapping for different variants.""" | ||
| # resnet18/34 -> 512 | ||
| for name in ['resnet18', 'resnet34']: | ||
| backbone = ResNetBackbone({'backbone': name}) | ||
| assert backbone.hidden_size == 512 | ||
|
|
||
| # resnet50/101/152 -> 2048 | ||
| for name in ['resnet50', 'resnet101', 'resnet152']: | ||
| backbone = ResNetBackbone({'backbone': name}) | ||
| assert backbone.hidden_size == 2048 | ||
|
|
||
| def test_forward_pass_shape(self, default_config): | ||
themattinthehatt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Test forward pass produces correct output shape.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| x = torch.randn(2, 3, 224, 224) | ||
| output = backbone(x) | ||
|
|
||
| assert output.shape == (2, 512, 7, 7) | ||
| assert torch.isfinite(output).all() | ||
|
|
||
| def test_forward_channel_validation(self, default_config): | ||
| """Test that wrong input channels raises ValueError.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| x = torch.randn(2, 1, 224, 224) # wrong channels | ||
|
|
||
| with pytest.raises(ValueError, match="Input has 1 channels"): | ||
| backbone(x) | ||
|
|
||
| def test_forward_resnet50(self): | ||
themattinthehatt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Test forward pass with resnet50 (bottleneck variant).""" | ||
| backbone = ResNetBackbone({'backbone': 'resnet50'}) | ||
| x = torch.randn(2, 3, 224, 224) | ||
| output = backbone(x) | ||
|
|
||
| assert output.shape == (2, 2048, 7, 7) | ||
| assert torch.isfinite(output).all() | ||
|
|
||
| def test_gradient_flow(self, default_config): | ||
| """Test that gradients flow through the model.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| x = torch.randn(1, 3, 224, 224, requires_grad=True) | ||
| output = backbone(x) | ||
|
|
||
| loss = output.sum() | ||
| loss.backward() | ||
|
|
||
| assert x.grad is not None | ||
| assert not torch.isnan(x.grad).any() | ||
|
|
||
| def test_load_pretrained_weights_file_not_found(self, default_config): | ||
| """Test that missing checkpoint raises FileNotFoundError.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| with pytest.raises(FileNotFoundError, match="Checkpoint not found"): | ||
| backbone.load_pretrained_weights('/nonexistent/checkpoint.ckpt') | ||
|
|
||
| def test_get_last_layer_params(self, default_config): | ||
| """Test that get_last_layer_params returns parameters from layer4.""" | ||
| backbone = ResNetBackbone(default_config) | ||
| params = list(backbone.get_last_layer_params()) | ||
|
|
||
| assert len(params) > 0 | ||
| # All should be Parameters | ||
| for p in params: | ||
| assert isinstance(p, torch.nn.Parameter) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.