
[BUG] DeepARModel and TFTModel don't work on pytorch_lightning>=1.9.1 #1130

Open
1 task done
Mr-Geekman opened this issue Feb 27, 2023 · 1 comment
Labels
bug Something isn't working priority/medium Medium priority task

Comments

@Mr-Geekman
Contributor

🐛 Bug Report

DeepARModel and TFTModel don't work on pytorch_lightning>=1.9.1.

Fitting fails with error:

AttributeError: 'tuple' object has no attribute 'items'

As far as I understand, this is related to the pytorch_forecasting issue "'tuple' object has no attribute 'items' in models".
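For context, Python produces this exact message whenever code calls `.items()` on a tuple where it expects a dict (for example, a mapping of metric names to values). The snippet below is only a generic illustration of that failure mode; `format_metrics` and `error_for` are hypothetical helpers and do not reproduce the actual pytorch_forecasting or pytorch_lightning code path.

```python
def format_metrics(metrics) -> list:
    """Hypothetical consumer that expects a mapping of metric name -> value."""
    # Works for dicts; any object without .items() raises AttributeError here.
    return [f"{name}={value}" for name, value in metrics.items()]


def error_for(value):
    """Call format_metrics and return the AttributeError message, or None on success."""
    try:
        format_metrics(value)
    except AttributeError as err:
        return str(err)
    return None


print(format_metrics({"train_loss": 0.5}))
print(error_for(("train_loss", 0.5)))
```

Passing a dict succeeds, while passing the same data as a tuple yields the message seen in the traceback.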

Expected behavior

Fitting completes without errors.

How To Reproduce

Script to reproduce with TFTModel (DeepARModel fails with the same error).

import pandas as pd
import numpy as np

from etna.datasets.tsdataset import TSDataset
from etna.pipeline import Pipeline
from etna.transforms import DateFlagsTransform
from etna.transforms import LagTransform
from etna.transforms import PytorchForecastingTransform
from pytorch_forecasting.data import GroupNormalizer
from etna.models.nn import TFTModel


original_df = pd.DataFrame(np.array([["2021-05-31", 1, 3],
                                     ["2021-06-07", 1, 6],
                                     ["2021-06-14", 1, 9],
                                     ["2021-06-21", 1, 12],
                                     ["2021-06-28", 1, 15]]),
                           columns=['timestamp', 'segment', 'target'])
original_df['timestamp'] = pd.to_datetime(original_df['timestamp'])
original_df['target'] = original_df['target'].astype(float)
df = TSDataset.to_dataset(original_df)
ts = TSDataset(df, freq="W-MON")

HORIZON = 1
transform_date = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, out_column="dateflag")
num_lags = 2
transform_lag = LagTransform(
    in_column="target",
    lags=[HORIZON + i for i in range(num_lags)],
    out_column="target_lag",
)

transform_tft = PytorchForecastingTransform(
    max_encoder_length=HORIZON,
    max_prediction_length=HORIZON,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["target"],
    time_varying_known_categoricals=["dateflag_day_number_in_week"],
    static_categoricals=["segment"],
    target_normalizer=GroupNormalizer(groups=["segment"]),
)
model_tft = TFTModel(max_epochs=5, learning_rate=[0.1], gpus=0, batch_size=64)

pipeline_tft = Pipeline(
    model=model_tft,
    horizon=HORIZON,
    transforms=[transform_lag, transform_date, transform_tft],
)

pipeline_tft.fit(ts)

The script fails on pipeline_tft.fit(ts) with the error:

AttributeError: 'tuple' object has no attribute 'items'
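Until this is fixed, a caller could check whether the installed pytorch_lightning falls in the range reported here before fitting these models. This is a minimal sketch under stated assumptions: the lower bound 1.9.1 comes from the issue title, every later version is treated as affected (no fixed upper bound is known), pre-release suffixes are ignored, and `parse_release`, `is_affected` and `installed_lightning_is_affected` are hypothetical helpers, not part of etna.

```python
import re
from importlib import metadata

# Assumption: lower bound of the affected range, taken from the issue title.
AFFECTED_FROM = (1, 9, 1)


def parse_release(version: str) -> tuple:
    """Return the leading numeric components, e.g. "1.9.1.post0" -> (1, 9, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Pad to three components so "2.0" compares like "2.0.0".
    return parts + (0,) * (3 - len(parts))


def is_affected(version: str) -> bool:
    """True if the version is at or above the reported lower bound."""
    return parse_release(version) >= AFFECTED_FROM


def installed_lightning_is_affected():
    """True/False for the installed pytorch_lightning, or None if it is not installed."""
    try:
        version = metadata.version("pytorch_lightning")
    except metadata.PackageNotFoundError:
        return None
    return is_affected(version)


print(is_affected("1.9.0"), is_affected("1.9.1"))
```

Simplification note: a pre-release such as "1.9.1rc0" is parsed as (1, 9, 1) and therefore reported as affected.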

Environment

No response

Additional context

No response

Checklist

  • Bug appears at the latest library version
Mr-Geekman added the bug (Something isn't working) label on Feb 27, 2023
@Mr-Geekman
Contributor Author

The pytorch_forecasting package was updated recently, and the new release probably fixes the incompatibility with pytorch_lightning. However, the version requirements of the other packages are very strict.
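Strict pins matter because a single installed version has to satisfy every dependent package's constraint at the same time. The sketch below illustrates that intersection with a simplified specifier checker. The package names `pkg_a`, `pkg_b`, `pkg_c` and their bounds are hypothetical, not the real requirements of etna or its dependencies, and the checker supports only `<`, `<=`, `>`, `>=`, `==` without pre-release handling.

```python
import operator
import re

OPS = {"<": operator.lt, "<=": operator.le, ">": operator.gt, ">=": operator.ge, "==": operator.eq}


def release(version: str) -> tuple:
    """Leading numeric components, padded to three: "2.0" -> (2, 0, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return parts + (0,) * (3 - len(parts))


def satisfies(version: str, spec: str) -> bool:
    """Check a comma-separated spec such as ">=1.9.0,<2.0" (simplified)."""
    for clause in spec.split(","):
        m = re.fullmatch(r"\s*(<=|>=|==|<|>)\s*(\S+)\s*", clause)
        if m is None:
            raise ValueError(f"unsupported clause: {clause!r}")
        op, bound = m.groups()
        if not OPS[op](release(version), release(bound)):
            return False
    return True


def compatible(candidates, constraints):
    """Candidate versions that satisfy every package's constraint at once."""
    return [v for v in candidates if all(satisfies(v, s) for s in constraints.values())]


candidates = ["1.8.6", "1.9.0", "1.9.1", "1.9.4", "2.0.0"]
# Hypothetical constraints for illustration only.
constraints = {"pkg_a": ">=1.9.0,<2.0", "pkg_b": "<1.9.1"}
print(compatible(candidates, constraints))
```

With these two hypothetical pins only one candidate survives; adding a third package that requires `>=1.9.1` leaves no installable version at all.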

Projects
Status: Todo
Development

No branches or pull requests

1 participant