
Commit 6fc35a9

added initial simple linear
1 parent 3858ad3 commit 6fc35a9

File tree

5 files changed: +2900 -0 lines changed


Linear/__init__.py

+22
@@ -0,0 +1,22 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .module import LinearModel
from .lightning_module import LinearLightningModule
from .estimator import LinearEstimator

__all__ = [
    "LinearModel",
    "LinearLightningModule",
    "LinearEstimator",
]
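These re-exports make the package root the public entry point; a one-line usage sketch, assuming the directory is importable as ``Linear`` (an assumption based on the paths in this commit):

from Linear import LinearModel, LinearLightningModule, LinearEstimator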

Linear/estimator.py

+322
@@ -0,0 +1,322 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Optional, Iterable, Dict, Any

import torch
import pytorch_lightning as pl

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.itertools import Cyclic
from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
    Transformation,
    AddObservedValuesIndicator,
    InstanceSampler,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    RemoveFields,
    SetField,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
)
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
    DistributionOutput,
    StudentTOutput,
)

from .lightning_module import LinearLightningModule

PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]

class LinearEstimator(PyTorchLightningEstimator):
    """
    An estimator training a Linear model for forecasting.

    This class uses the model defined in ``LinearModel``,
    and wraps it into a ``LinearLightningModule`` for training
    purposes: training is performed using PyTorch Lightning's ``pl.Trainer``
    class.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    hidden_dimensions
        Size of hidden layers in the feed-forward network
        (default: ``[20, 20]``).
    input_size
        Number of variates of the target time series (default: 1).
    scaling
        Which scaling method to use to scale the target values
        (default: ``"mean"``).
    num_feat_dynamic_real
        Number of dynamic real features in the data (default: 0).
    num_feat_static_cat
        Number of static categorical features in the data (default: 0).
    num_feat_static_real
        Number of static real features in the data (default: 0).
    cardinality
        Number of values of each static categorical feature; must be set
        if ``num_feat_static_cat > 0`` (default: None).
    embedding_dimension
        Dimension of the embeddings for categorical features
        (default: None).
    time_features
        List of time features to use as inputs, in addition to the
        provided data (default: None, in which case these are
        automatically determined based on ``freq``).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: ``StudentTOutput()``).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    batch_norm
        Whether to apply batch normalization (default: False).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
    train_sampler
        Controls the sampling of windows during training.
    validation_sampler
        Controls the sampling of windows during validation.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        hidden_dimensions: Optional[List[int]] = None,
        input_size: int = 1,
        scaling: Optional[str] = "mean",
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        batch_norm: bool = False,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        default_trainer_kwargs = {"max_epochs": 100, "gradient_clip_val": 10.0}
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.scaling = scaling
        self.freq = freq
        self.input_size = input_size
        self.prediction_length = prediction_length
        self.context_length = context_length or 10 * prediction_length
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )
        # TODO find way to enforce same defaults to network and estimator
        # somehow
        self.hidden_dimensions = hidden_dimensions or [20, 20]
        self.lr = lr
        self.weight_decay = weight_decay
        self.distr_output = distr_output
        self.loss = loss
        self.batch_norm = batch_norm
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    @classmethod
    def derive_auto_fields(cls, train_iter):
        stats = calculate_dataset_statistics(train_iter)

        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def create_transformation(self) -> Transformation:
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        return (
            RemoveFields(field_names=remove_field_names)
            + (
                SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
                if not self.num_feat_static_cat > 0
                else []
            )
            + (
                SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
                if not self.num_feat_static_real > 0
                else []
            )
            + AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=self.time_features,
                pred_length=self.prediction_length,
            )
            + AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=self.prediction_length,
                log_scale=True,
            )
            + VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if self.num_feat_dynamic_real > 0
                    else []
                ),
            )
            + AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            )
        )

    def create_lightning_module(self) -> pl.LightningModule:
        return LinearLightningModule(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "input_size": self.input_size,
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "hidden_dimensions": self.hidden_dimensions,
                "scaling": self.scaling,
                "distr_output": self.distr_output,
                "batch_norm": self.batch_norm,
            },
        )

    def _create_instance_splitter(self, module: LinearLightningModule, mode: str):
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: LinearLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: LinearLightningModule,
        **kwargs,
    ) -> Iterable:
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            forecast_generator=DistributionForecastGenerator(self.distr_output),
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
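For context, a minimal end-to-end sketch of how this estimator would typically be used, assuming the package is importable as ``Linear``; the dataset name, epoch count, and ``num_samples`` below are illustrative choices, not part of this commit:

from gluonts.dataset.repository.datasets import get_dataset
from gluonts.evaluation import make_evaluation_predictions

from Linear import LinearEstimator

dataset = get_dataset("electricity")  # any GluonTS dataset works here

estimator = LinearEstimator(
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    trainer_kwargs={"max_epochs": 10},  # overrides the default of 100
)

# ``train`` fits the LinearLightningModule under pl.Trainer and returns
# the PyTorchPredictor assembled by ``create_predictor`` above.
predictor = estimator.train(dataset.train)

forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test, predictor=predictor, num_samples=100
)
forecasts = list(forecast_it)

Because the predictor is built with a ``DistributionForecastGenerator``, the forecasts it yields are distribution objects parameterized by ``distr_output`` (Student's t by default) rather than sample paths.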
