Skip to content

Commit 50bb492

Browse files
authored
[BUG] Ensure compatibility with lightning > 2.6 and above (#625)
Fixes #623, Fixes #621 Adds a wrapper around `load_from_checkpoint` to ensure compatibility with all versions
1 parent 7f312cb commit 50bb492

File tree

4 files changed

+96
-2
lines changed

4 files changed

+96
-2
lines changed

.github/workflows/testing.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,41 @@ jobs:
5353
- name: Run test-suite
5454
run: python -m pytest -v
5555

56+
test-deps-2024:
57+
runs-on: ${{ matrix.os }}
58+
strategy:
59+
fail-fast: false
60+
matrix:
61+
os: [ubuntu-latest]
62+
python-version: ["3.12"]
63+
64+
steps:
65+
- uses: actions/checkout@v6
66+
67+
- run: git remote set-branches origin 'main'
68+
69+
- run: git fetch --depth 1
70+
71+
- name: Install uv
72+
uses: astral-sh/setup-uv@v7
73+
with:
74+
enable-cache: true
75+
76+
- name: Set up Python ${{ matrix.python-version }}
77+
uses: actions/setup-python@v6
78+
with:
79+
python-version: ${{ matrix.python-version }}
80+
81+
- name: Install main package & dependencies
82+
run: uv pip install -e .[dev,dependencies_2024]
83+
env:
84+
UV_SYSTEM_PYTHON: 1
85+
86+
- name: Show installed packages
87+
run: uv pip list
88+
89+
- name: Run test-suite
90+
run: python -m pytest -v
5691

5792
test-all-extras:
5893
runs-on: ${{ matrix.os }}

pyproject.toml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ dependencies = [
4343
"numpy<=3.0.0",
4444
"pandas>=1.1.5,<3.0.0",
4545
"scikit-learn>=1.3.0,<2.0",
46-
"pytorch-lightning>=2.0.0,<2.5.0",
46+
"pytorch-lightning>=2.0.0,<2.7.0",
4747
"scipy>=1.8,<2.0",
4848
"omegaconf>=2.3.0",
4949
"torchmetrics>=0.10.0,<1.9.0",
@@ -80,6 +80,19 @@ notebooks = [
8080
"matplotlib>3.1",
8181
]
8282

83+
# Core Dep set in Nov 2024
84+
dependencies_2024 = [
85+
"torch==2.5.0",
86+
"numpy==2.2.0",
87+
"pandas==2.2.3",
88+
"scikit-learn==1.6.0",
89+
"pytorch-lightning==2.4.0",
90+
"scipy==1.14.1",
91+
"omegaconf==2.3.0",
92+
"torchmetrics==1.5.2",
93+
"einops==0.8.0",
94+
]
95+
8396

8497
[project.urls]
8598
Homepage = "https://github.com/pytorch-tabular/pytorch_tabular"

src/pytorch_tabular/models/base_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from abc import ABCMeta, abstractmethod
99
from functools import partial
1010
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
11+
from pathlib import Path
1112

1213
import numpy as np
1314
import pytorch_lightning as pl
@@ -206,6 +207,28 @@ def embedding_layer(self):
206207
def head(self):
207208
raise NotImplementedError("head property needs to be implemented by inheriting classes")
208209

210+
@classmethod
211+
def load_from_checkpoint(
212+
cls,
213+
checkpoint_path: Union[str, Path],
214+
map_location=None,
215+
strict=True,
216+
**kwargs,
217+
):
218+
from skbase.utils.dependencies import _check_soft_dependencies
219+
220+
if not _check_soft_dependencies("pytorch_lightning<2.6", severity="none"):
221+
if "weights_only" not in kwargs:
222+
kwargs["weights_only"] = False
223+
else:
224+
kwargs.pop("weights_only", None)
225+
return super().load_from_checkpoint(
226+
checkpoint_path,
227+
map_location=map_location,
228+
strict=strict,
229+
**kwargs,
230+
)
231+
209232
def _check_and_verify(self):
210233
assert hasattr(self, "backbone"), "Model has no attribute called `backbone`"
211234
assert hasattr(self.backbone, "output_dim"), "Backbone needs to have attribute `output_dim`"

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import warnings
77
from abc import ABCMeta, abstractmethod
8-
from typing import Dict, Optional
8+
from typing import Dict, Optional, Union
9+
from pathlib import Path
910

1011
import pytorch_lightning as pl
1112
import torch
@@ -151,6 +152,28 @@ def forward(self, x: Dict):
151152
def featurize(self, x: Dict):
152153
pass
153154

155+
@classmethod
156+
def load_from_checkpoint(
157+
cls,
158+
checkpoint_path: Union[str, Path],
159+
map_location=None,
160+
strict=True,
161+
**kwargs,
162+
):
163+
from skbase.utils.dependencies import _check_soft_dependencies
164+
165+
if not _check_soft_dependencies("pytorch_lightning<2.6", severity="none"):
166+
if "weights_only" not in kwargs:
167+
kwargs["weights_only"] = False
168+
else:
169+
kwargs.pop("weights_only", None)
170+
return super().load_from_checkpoint(
171+
checkpoint_path,
172+
map_location=map_location,
173+
strict=strict,
174+
**kwargs,
175+
)
176+
154177
def predict(self, x: Dict, ret_model_output: bool = True): # ret_model_output only for compatibility
155178
assert ret_model_output, "ret_model_output must be True in case of SSL predict"
156179
return self.featurize(x)

0 commit comments

Comments
 (0)