# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Optional, Iterable, Dict, Any

import torch
import pytorch_lightning as pl

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.itertools import Cyclic
from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
    Transformation,
    Chain,
    AddObservedValuesIndicator,
    InstanceSampler,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    RemoveFields,
    SetField,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
)
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
    DistributionOutput,
    StudentTOutput,
)

from .lightning_module import LinearLightningModule

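# Tensor field names fed to the network: prediction consumes the "past"
# fields plus the known future time features; training additionally consumes
# the future target and its observed-values mask.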
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]


class LinearEstimator(PyTorchLightningEstimator):
    """
    An estimator training a Linear model for forecasting.

    This class uses the model defined in ``LinearModel``,
    and wraps it into a ``LinearLightningModule`` for training
    purposes: training is performed using PyTorch Lightning's ``pl.Trainer``
    class.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    hidden_dimensions
        Size of hidden layers in the feed-forward network
        (default: ``[20, 20]``).
    input_size
        Number of variates in the target time series (default: 1).
    scaling
        Which scaling method to use to scale the target values
        (default: ``"mean"``).
    num_feat_dynamic_real
        Number of dynamic real features in the data (default: 0).
    num_feat_static_cat
        Number of static categorical features in the data (default: 0).
    num_feat_static_real
        Number of static real features in the data (default: 0).
    cardinality
        Number of values of each categorical feature.
        This must be set if ``num_feat_static_cat > 0`` (default: None).
    embedding_dimension
        Dimension of the embeddings for categorical features (default: None).
    time_features
        List of time features to use as inputs in addition to the provided
        data (default: None, in which case these are determined automatically
        based on ``freq``).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: ``StudentTOutput()``).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    batch_norm
        Whether to apply batch normalization (default: False).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
    train_sampler
        Controls the sampling of windows during training.
    validation_sampler
        Controls the sampling of windows during validation.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        hidden_dimensions: Optional[List[int]] = None,
        input_size: int = 1,
        scaling: Optional[str] = "mean",
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        batch_norm: bool = False,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
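        # Fill in trainer defaults; user-supplied ``trainer_kwargs`` override
        # them.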
        default_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.scaling = scaling
        self.freq = freq
        self.input_size = input_size
        self.prediction_length = prediction_length
        self.context_length = context_length or 10 * prediction_length
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )
        # TODO find way to enforce same defaults to network and estimator
        # somehow
        self.hidden_dimensions = hidden_dimensions or [20, 20]
        self.lr = lr
        self.weight_decay = weight_decay
        self.distr_output = distr_output
        self.loss = loss
        self.batch_norm = batch_norm
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

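        # Default samplers: training draws, on average, one random window per
        # series; validation splits each series at the last point that still
        # leaves ``prediction_length`` future steps.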
        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

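    # Derive feature-related hyperparameters from dataset statistics, so that
    # they need not be specified by hand.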
    @classmethod
    def derive_auto_fields(cls, train_iter):
        stats = calculate_dataset_statistics(train_iter)

        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def create_transformation(self) -> Transformation:
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

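        # Build the pipeline: drop unused fields, insert dummy static features
        # where none are given, then stack time, age, and any dynamic real
        # features into FEAT_TIME and mark which target values are observed.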
        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if self.num_feat_static_cat == 0
                else []
            )
            + (
                [
                    SetField(
                        output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
                    )
                ]
                if self.num_feat_static_real == 0
                else []
            )
            + [
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.num_feat_dynamic_real > 0
                        else []
                    ),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
            ]
        )

    def create_lightning_module(self) -> pl.LightningModule:
        return LinearLightningModule(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "input_size": self.input_size,
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "hidden_dimensions": self.hidden_dimensions,
                "scaling": self.scaling,
                "distr_output": self.distr_output,
                "batch_norm": self.batch_norm,
            },
        )

    def _create_instance_splitter(
        self, module: LinearLightningModule, mode: str
    ):
        assert mode in ["training", "validation", "test"]

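        # Pick the window sampler for the given mode; at test time exactly one
        # window, ending at the last observation of each series, is taken.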
        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: LinearLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
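        # Cycle the dataset indefinitely; the epoch length is fixed by
        # ``num_batches_per_epoch`` rather than by the dataset size.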
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: LinearLightningModule,
        **kwargs,
    ) -> Iterable:
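        # ``is_train=True`` so the splitter also emits the future target and
        # observed-values mask, which computing the validation loss requires.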
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
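        # Bundle the trained module with the full input transformation and a
        # generator that turns network outputs into distribution forecasts.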
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            forecast_generator=DistributionForecastGenerator(
                self.distr_output
            ),
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            ),
        )
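
# Example usage (a minimal sketch): train on a GluonTS dataset and generate
# probabilistic forecasts. ``get_dataset`` and the "electricity" dataset name
# are illustrative; any GluonTS ``Dataset`` works.
#
#     from gluonts.dataset.repository import get_dataset
#
#     dataset = get_dataset("electricity")
#     estimator = LinearEstimator(
#         freq=dataset.metadata.freq,
#         prediction_length=dataset.metadata.prediction_length,
#         trainer_kwargs={"max_epochs": 5},
#     )
#     predictor = estimator.train(dataset.train)
#     forecasts = list(predictor.predict(dataset.test))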