diff --git a/examples/multiple-trajectory-prediction/README.md b/examples/multiple-trajectory-prediction/README.md
new file mode 100644
index 0000000..9ffaad0
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/README.md
@@ -0,0 +1,84 @@
+# ZOD Multi-Trajectory Pipeline
+
+This example provides a pipeline for training, evaluating, and visualizing multi-trajectory prediction models on the ZOD dataset.
+
+The method predicts multiple possible future paths that a driver may take, based on a single frontal camera image. Predictions are expressed in ground-plane coordinates, derived from GPS/OXTS data, relative to the vehicle’s position at the moment the image was captured.
+
+Ground truth trajectories are generated from the GPS data of where the vehicle actually traveled after each image was captured.
+
+For visualization, both predicted and ground-truth trajectories are projected into the **camera image plane** using calibration parameters. This overlays the paths directly on top of the corresponding camera frame, making it possible to compare predicted trajectories with the driver’s true path.
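+
+As a rough sketch of that projection step (assuming a loaded `zod_frame` and an `(N, 3)` array `points` of ego-frame waypoints; the full version lives in `dataset/zod_image_generator.py`):
+
+```python
+import numpy as np
+from zod.constants import Camera
+from zod.utils.geometry import project_3d_to_2d_kannala, transform_points
+
+calib = zod_frame.calibration
+t_inv = np.linalg.pinv(calib.get_extrinsics(Camera.FRONT).transform)
+camera_points = transform_points(points, t_inv)  # ego frame -> camera frame
+# (the full pipeline also filters points to the camera field of view here)
+xy = project_3d_to_2d_kannala(  # camera frame -> pixel coordinates
+    camera_points,
+    calib.cameras[Camera.FRONT].intrinsics[..., :3],
+    calib.cameras[Camera.FRONT].distortion,
+)
+```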
+
+
+## Example Results
+
+After training, the pipeline stores prediction visualizations in the `results/` directory.
+Here is an example of predicted trajectories overlaid on a ZOD frame, where the purple lines are predictions and the green line is the ground truth:
+
+![Example prediction](results/example_prediction.png)
+
+
+---
+
+## Setup
+
+### 1. Create and activate a virtual environment
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+```
+
+### 2. Configure paths
+Update dataset and model configuration in:
+
+- `model/model_config.py`
+- `dataset/zod_configs.py`
+
+In particular, set the dataset paths, for example:
+```python
+STORED_GROUND_TRUTH_PATH = "/mnt/ZOD/ground_truth.json"
+DATASET_ROOT = "/mnt/ZOD"
+```
+
+Other parameters may be left as default or modified as needed.
+
+---
+
+## Data Preparation
+
+### 1. Resize images
+The pipeline has been tested with images resized to **256×256**.
+
+Run the resize script:
+```bash
+python resize_and_save.py
+```
+
+Make sure `IMG_SIZE` and `DATASET_ROOT` in `dataset/zod_configs.py` are set correctly.
+
+Alternatively, you can work with full-sized images by setting
+`USE_PRE_RESIZED_IMGS = False` in `dataset/zod_configs.py` and skipping this step (not tested).
+
+### 2. Generate ground truth
+Ground truth is generated by looking ahead at the vehicle's future GPS positions and sampling waypoints at the distances listed in `TARGET_DISTANCES`. Run the ground truth generation script:
+```bash
+python generate_ground_truth.py
+```
+
+This script calls utility functions defined in `dataset/groundtruth_utils.py`.
+
+Ensure that the path specified in `STORED_GROUND_TRUTH_PATH` in `dataset/zod_configs.py` is valid.
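+
+The resulting JSON maps each frame id to its flattened (x, y, z) waypoints: one waypoint per entry in `TARGET_DISTANCES`, i.e. 17 distances x 3 coordinates = 51 values. A made-up entry for illustration:
+
+```json
+{
+    "000784": [4.9, 0.1, 0.0, 9.8, 0.3, 0.1, ...]
+}
+```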
+
+---
+
+## Training & Evaluation
+
+Run the main training script:
+```bash
+python main.py
+```
+
+This will:
+- Train the model
+- Evaluate performance
+- Store example predictions in the `results/` directory
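+
+Training hyperparameters such as `EPOCHS`, `LEARNING_RATE`, and `NR_OF_MODES` can be adjusted in `model/model_config.py`.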
+
+---
\ No newline at end of file
diff --git a/examples/multiple-trajectory-prediction/dataset/groundtruth_utils.py b/examples/multiple-trajectory-prediction/dataset/groundtruth_utils.py
new file mode 100644
index 0000000..a0d975a
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/dataset/groundtruth_utils.py
@@ -0,0 +1,97 @@
+"""Groundtruth utilities."""
+
+import json
+from typing import Set
+
+import numpy as np
+from dataset.zod_configs import ZodConfigs
+from tqdm import tqdm
+
+from zod import ZodFrames
+
+
+def get_ground_truth(zod_frames: ZodFrames, frame_id: int, zod_configs: ZodConfigs) -> np.ndarray:
+    """Get the true holistic path from future GPS locations.
+
+    Args:
+        zod_frames (ZodFrames): ZOD frames dataset
+        frame_id (int): frame id
+        zod_configs (ZodConfigs): zod configs dataclass
+
+    Returns:
+        np.ndarray: true path as flattened (x, y, z) waypoints
+
+    """
+ zod_frame = zod_frames[frame_id]
+ oxts = zod_frame.oxts
+ key_timestamp = zod_frame.info.keyframe_time.timestamp()
+
+    # get poses associated with the keyframe timestamp
+    try:
+        current_pose = oxts.get_poses(key_timestamp)
+        all_poses = oxts.poses[oxts.timestamps >= key_timestamp]
+        # express all future poses in the keyframe's coordinate system
+        transformed_poses = np.linalg.pinv(current_pose) @ all_poses
+        translations = transformed_poses[:, :3, 3]
+        # accumulate the distance driven between consecutive future poses
+        distances = np.linalg.norm(np.diff(translations, axis=0), axis=1)
+        accumulated_distances = np.cumsum(distances).astype(int).tolist()
+
+        # select the first pose whose accumulated (integer) distance matches each TARGET_DISTANCES value
+        pose_idx = [accumulated_distances.index(i) for i in zod_configs.TARGET_DISTANCES]
+        used_poses = transformed_poses[pose_idx]
+
+    except Exception:  # e.g. the recorded trajectory is shorter than a target distance
+        print(f"detected invalid frame: {frame_id}")
+ return np.array([])
+
+ points = used_poses[:, :3, -1]
+ return points.flatten()
+
+
+def save_ground_truth(
+ zod_frames: ZodFrames,
+ training_frames: Set[str],
+ validation_frames: Set[str],
+ zod_configs: ZodConfigs,
+) -> None:
+ """Write ground truth as json.
+
+ Args:
+        zod_frames (ZodFrames): ZOD frames dataset
+        training_frames (Set[str]): set of training frame ids
+        validation_frames (Set[str]): set of validation frame ids
+ zod_configs (ZodConfigs): zod configs dataclass
+
+ """
+ all_frames = validation_frames.copy()
+ all_frames.update(training_frames)
+
+ corrupted_frames = []
+ ground_truths = {}
+ for frame_id in tqdm(all_frames):
+ ground_truth = get_ground_truth(zod_frames, frame_id, zod_configs)
+
+ if ground_truth.shape[0] != zod_configs.NUM_OUTPUT:
+ corrupted_frames.append(frame_id)
+ continue
+
+ ground_truths[frame_id] = ground_truth.tolist()
+
+ # Serializing json
+ json_object = json.dumps(ground_truths, indent=4)
+
+    # Write the serialized ground truth to the configured path
+ with open(zod_configs.STORED_GROUND_TRUTH_PATH, "w") as outfile:
+ outfile.write(json_object)
+
+ print(f"{corrupted_frames}")
+
+
+def load_ground_truth(path: str) -> dict:
+ """Load ground truth from file."""
+ with open(path) as json_file:
+ gt = json.load(json_file)
+
+ for f in gt:
+ gt[f] = np.array(gt[f])
+
+ return gt
diff --git a/examples/multiple-trajectory-prediction/dataset/zod_configs.py b/examples/multiple-trajectory-prediction/dataset/zod_configs.py
new file mode 100644
index 0000000..afe3740
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/dataset/zod_configs.py
@@ -0,0 +1,43 @@
+"""Global static parameters dataclass."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class ZodConfigs:
+ """ZOD configs class."""
+
+ NUM_OUTPUT: int = 51
+ IMG_SIZE: int = 256
+ TARGET_DISTANCES: tuple = (
+ 5,
+ 10,
+ 15,
+ 20,
+ 25,
+ 30,
+ 35,
+ 40,
+ 50,
+ 60,
+ 70,
+ 80,
+ 95,
+ 110,
+ 125,
+ 145,
+ 165,
+ )
+
+ BATCH_SIZE: int = 32
+    TEST_SIZE: float = 0.1  # fraction of the VAL split to use as test data
+    VAL_SIZE: float = 0.1  # fraction of train data to use for validation
+
+ # File paths
+ STORED_GROUND_TRUTH_PATH: str = "/mnt/ZOD/ground_truth.json"
+ DATASET_ROOT: str = "/mnt/ZOD"
+
+ USE_PRE_RESIZED_IMGS: bool = True
+
+ NORMALIZE_MEAN: tuple = (0.485, 0.456, 0.406)
+ NORMALIZE_STD: tuple = (0.229, 0.224, 0.225)
diff --git a/examples/multiple-trajectory-prediction/dataset/zod_data_manager.py b/examples/multiple-trajectory-prediction/dataset/zod_data_manager.py
new file mode 100644
index 0000000..ec1015e
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/dataset/zod_data_manager.py
@@ -0,0 +1,110 @@
+"""Zod dataset manager."""
+
+from typing import Tuple
+
+from dataset.groundtruth_utils import load_ground_truth
+from dataset.zod_configs import ZodConfigs
+from dataset.zod_dataset import ZodDataset
+from dataset.zod_image_generator import ZodImageGenerator
+from torch import Generator
+from torch.utils.data import DataLoader, RandomSampler, random_split
+from torchvision import transforms
+
+from zod import ZodFrames, constants
+
+
+class ZodDatasetManager:
+ """Zod dataset manager class."""
+
+ def __init__(self) -> None:
+ self.zod_configs = ZodConfigs()
+ self.zod_frames = ZodFrames(dataset_root=self.zod_configs.DATASET_ROOT, version="full")
+ self.transform = self._get_transform()
+ self.ground_truth = load_ground_truth(self.zod_configs.STORED_GROUND_TRUTH_PATH)
+ self.test_frames = None
+
+ def get_test_dataloader(self) -> DataLoader:
+ """Load the ZOD test dataset from the VAL partition. Server side test set."""
+ validation_frames_all = self.zod_frames.get_split(constants.VAL)
+
+ validation_frames_all = [idx for idx in validation_frames_all if self._is_valid_frame(idx, self.ground_truth)]
+
+ validation_frames = validation_frames_all[: int(len(validation_frames_all) * self.zod_configs.TEST_SIZE)]
+
+ self.test_frames = validation_frames
+
+ testset = ZodDataset(
+ zod_frames=self.zod_frames,
+ frames_id_set=validation_frames,
+ stored_ground_truth=self.ground_truth,
+ transform=self.transform,
+ zod_configs=self.zod_configs,
+ )
+ print(f"Test dataset loaded. Length: {len(testset)}")
+ return DataLoader(testset, batch_size=self.zod_configs.BATCH_SIZE)
+
+ def get_train_val_dataloader(self, seed: int = 42) -> Tuple[DataLoader, DataLoader]:
+ """Get train and validation dataloader for client side."""
+ train_frames = self.zod_frames.get_split(constants.TRAIN)
+
+ train_frames = [idx for idx in train_frames if self._is_valid_frame(idx, self.ground_truth)]
+ trainset = ZodDataset(
+ zod_frames=self.zod_frames,
+ frames_id_set=train_frames,
+ stored_ground_truth=self.ground_truth,
+ transform=self.transform,
+ zod_configs=self.zod_configs,
+ )
+
+ # Split into train/val and create DataLoader
+ len_val = int(len(trainset) * self.zod_configs.VAL_SIZE)
+ len_train = int(len(trainset) - len_val)
+
+ lengths = [len_train, len_val]
+ ds_train, ds_val = random_split(trainset, lengths, Generator().manual_seed(seed))
+ train_sampler = RandomSampler(ds_train)
+ trainloader = DataLoader(
+ ds_train,
+ batch_size=self.zod_configs.BATCH_SIZE,
+ shuffle=False,
+ num_workers=0,
+ sampler=train_sampler,
+ )
+ valloader = DataLoader(ds_val, batch_size=self.zod_configs.BATCH_SIZE, num_workers=0)
+
+ return trainloader, valloader
+
+ def get_image_generator(self) -> ZodImageGenerator:
+ """Get image generator for ZOD hollistic path."""
+ if self.test_frames is None:
+ self.get_test_dataloader()
+ return ZodImageGenerator(self.test_frames, self.zod_frames)
+
+ def _get_transform(self) -> transforms.Compose:
+ """Get transform to use."""
+ return (
+ transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(self.zod_configs.NORMALIZE_MEAN, self.zod_configs.NORMALIZE_STD),
+ ]
+ )
+ if self.zod_configs.USE_PRE_RESIZED_IMGS
+ else transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Resize(
+ (self.zod_configs.IMG_SIZE, self.zod_configs.IMG_SIZE),
+ antialias=True,
+ ),
+ transforms.Normalize(self.zod_configs.NORMALIZE_MEAN, self.zod_configs.NORMALIZE_STD),
+ ]
+ )
+ )
+
+ def _is_valid_frame(self, frame_id: str, ground_truth: dict) -> bool:
+ """Check if frame is valid."""
+ if frame_id == "005350":
+ return False
+
+ return frame_id in ground_truth
diff --git a/examples/multiple-trajectory-prediction/dataset/zod_dataset.py b/examples/multiple-trajectory-prediction/dataset/zod_dataset.py
new file mode 100644
index 0000000..31eae17
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/dataset/zod_dataset.py
@@ -0,0 +1,66 @@
+"""Zod dataset, dataset class."""
+
+import cv2
+from dataset.groundtruth_utils import get_ground_truth
+from dataset.zod_configs import ZodConfigs
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from zod import ZodFrames
+from zod.constants import Anonymization, Camera
+
+
+class ZodDataset(Dataset):
+ """Zod dataset class."""
+
+ def __init__(
+ self,
+        zod_frames: ZodFrames,
+ frames_id_set: list,
+ stored_ground_truth: dict = None,
+ transform: transforms = None,
+ zod_configs: ZodConfigs = None,
+ ) -> None:
+ self.zod_frames: ZodFrames = zod_frames
+ self.frames_id_set = frames_id_set
+ self.transform = transform if transform is not None else transforms.ToTensor()
+ self.stored_ground_truth = stored_ground_truth
+ self.zod_configs = zod_configs
+
+ def __len__(self) -> int:
+ """Get number of frames."""
+ return len(self.frames_id_set)
+
+ def __getitem__(self, idx: int) -> tuple:
+ """Iterator."""
+ frame_idx = self.frames_id_set[idx]
+ frame = self.zod_frames[frame_idx]
+
+ if self.zod_configs.USE_PRE_RESIZED_IMGS:
+ original_path = frame.info.get_key_camera_frame(
+ camera=Camera.FRONT, anonymization=Anonymization.DNAT
+ ).filepath
+
+ resized_image_path = original_path.rsplit(".", 1)[0] + "_resized.jpg"
+            # cv2.imread returns None instead of raising if the file is missing
+            image = cv2.imread(resized_image_path)
+            if image is None:
+                print(f"Image {resized_image_path} not found.")
+            else:
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ else:
+ image = frame.get_image(Anonymization.DNAT)
+
+ label = (
+ self.stored_ground_truth[frame_idx]
+ if self.stored_ground_truth
+ else get_ground_truth(self.zod_frames, frame_idx, self.zod_configs)
+ )
+
+ label = label.astype("float32")
+
+ if self.transform:
+ image = self.transform(image)
+ return image, label
diff --git a/examples/multiple-trajectory-prediction/dataset/zod_image_generator.py b/examples/multiple-trajectory-prediction/dataset/zod_image_generator.py
new file mode 100644
index 0000000..994bb86
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/dataset/zod_image_generator.py
@@ -0,0 +1,261 @@
+"""Visualization tools for generating predicted path on images."""
+
+import random
+from pathlib import Path
+from typing import List
+
+import cv2
+import numpy as np
+import torch
+from dataset.groundtruth_utils import get_ground_truth
+from dataset.zod_configs import ZodConfigs
+from model.model_config import ModelConfig
+from torchvision import transforms
+from torchvision.utils import make_grid, save_image
+
+from zod import ZodFrames
+from zod.constants import Anonymization, Camera
+from zod.data_classes import Calibration
+from zod.utils.geometry import (
+ get_points_in_camera_fov,
+ project_3d_to_2d_kannala,
+ transform_points,
+)
+
+
+class ZodImageGenerator:
+ """Image generator."""
+
+ def __init__(
+ self,
+ validation_frames: list,
+ zod_frames: ZodFrames,
+ n_images: int = 30,
+ ) -> None:
+ """Generate a set of random images with true and predicted path."""
+ self.zod_configs = ZodConfigs()
+ self.model_config = ModelConfig()
+ self.zod_frames = zod_frames
+ self.validation_frames = validation_frames
+ self.n_available_samples = len(self.validation_frames)
+ self.n_images = n_images
+ self._select_image_subset()
+
+ def _select_image_subset(self) -> None:
+ """Randomly select images from ZOD dataset.
+
+        If a frame is invalid, generate new random numbers until a valid frame is found.
+
+ """
+ self.images = []
+ self.frame_ids = []
+
+ transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Resize(
+ (self.zod_configs.IMG_SIZE, self.zod_configs.IMG_SIZE),
+ antialias=True,
+ ),
+ ]
+ )
+ device = torch.device("cuda" if torch.cuda.is_available() and self.model_config.USE_GPU else "cpu")
+
+ while len(self.images) < self.n_images:
+ # pick a random frame from dataset
+ random_index = random.randint(0, self.n_available_samples - 1) # noqa: S311
+ frame_id = self.validation_frames[random_index]
+
+ try:
+ zod_frame = self.zod_frames[frame_id]
+
+ image = zod_frame.get_image(Anonymization.DNAT)
+ image = transform(image)
+ image = image.to(device)
+ image = image.unsqueeze(0)
+
+                # if the image loaded successfully, add the frame to the image subset
+ self.images.append(image)
+ self.frame_ids.append(frame_id)
+
+ # if frame invalid
+ except (TypeError, ValueError, FileNotFoundError):
+ pass
+
+    def visualize_prediction_on_image(self, model: torch.nn.Module, save_tag: str) -> None:
+        """Visualize the true and predicted holistic paths on each image.
+
+        The annotated images are combined into a grid and saved under results/.
+
+        Args:
+            model (torch.nn.Module): Model
+            save_tag (str): String to tag the saved images
+
+        """
+ device = torch.device("cuda" if torch.cuda.is_available() and self.model_config.USE_GPU else "cpu")
+ model.eval()
+ model.to(device)
+
+ transform = transforms.Compose([transforms.ToTensor()])
+ image_tensors = []
+
+ for frame_id, image in zip(self.frame_ids, self.images):
+ try:
+ with torch.no_grad():
+ predicted_path = model(image)[0, :]
+ zod_frame = self.zod_frames[frame_id]
+ raw_image = zod_frame.get_image(Anonymization.DNAT)
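+                # model output layout: [mode trajectories (NUM_OUTPUT values each), mode probabilities]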
+ trajectories = list(predicted_path[: -self.model_config.NR_OF_MODES].reshape((-1, 51)).cpu().numpy())
+ mode_probabilities = list(predicted_path[-self.model_config.NR_OF_MODES :])
+
+ final_image = self._visualize_paths_on_image(
+ raw_image,
+ self.zod_frames,
+ frame_id,
+ trajectories,
+ mode_probabilities,
+ )
+
+ image_tensor = transform(final_image)
+ image_tensors.append(image_tensor)
+
+ # if frame invalid
+            except (TypeError, ValueError) as e:
+                print(f"skipping frame {frame_id} with error: {e}")
+
+ stacked_tensors = torch.stack(image_tensors)
+ grid = make_grid(stacked_tensors, nrow=3, padding=5)
+
+ transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (
+ round(self.zod_configs.IMG_SIZE * (self.n_images // 3)),
+ round(self.zod_configs.IMG_SIZE * 3),
+ ),
+ antialias=True,
+ ),
+ ]
+ )
+
+ grid = transform(grid)
+
+        grid = grid.cpu().detach()
+ model.train()
+
+ # save to results/
+ results_dir = Path("results")
+ results_dir.mkdir(parents=True, exist_ok=True)
+ out_path = results_dir / f"predictions_grid_{save_tag}.png"
+ save_image(grid, out_path)
+
+ def _draw_line(self, image: np.array, line: np.array, color: tuple) -> np.array:
+ """Draw points on image.
+
+ Args:
+ image (np.array): image
+ line (np.array): points
+ color (tuple): color RGB tuple
+
+ Returns:
+ image(np.array): image with line
+
+ """
+ return cv2.polylines(
+ image.copy(),
+ [np.round(line).astype(np.int32)],
+ isClosed=False,
+ color=color,
+ thickness=60,
+ )
+
+ def _transform_absolute_to_relative_path(
+ self,
+ image: np.array,
+ points: np.array,
+ camera: str,
+ calibrations: Calibration,
+    ) -> tuple:
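+        """Project ego-frame path points into the camera image plane and mark them on the image."""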
+ t_inv = np.linalg.pinv(calibrations.get_extrinsics(camera).transform)
+ points = points.reshape(((self.zod_configs.NUM_OUTPUT // 3), 3))
+ camerapoints = transform_points(points[:, :3], t_inv)
+
+ # filter points that are not in the camera field of view
+ points_in_fov = get_points_in_camera_fov(calibrations.cameras[camera].field_of_view, camerapoints)
+ points_in_fov = points_in_fov[0]
+
+ # project points to image plane
+ xy_array = project_3d_to_2d_kannala(
+ points_in_fov,
+ calibrations.cameras[camera].intrinsics[..., :3],
+ calibrations.cameras[camera].distortion,
+ )
+
+ points = []
+ for i in range(xy_array.shape[0]):
+ x, y = int(xy_array[i, 0]), int(xy_array[i, 1])
+ cv2.circle(image, (x, y), 2, (255, 0, 0), -1)
+ points.append([x, y])
+
+ return points, image
+
+ def _visualize_paths_on_image(
+ self,
+ image: np.ndarray,
+ zod_frames: ZodFrames,
+ frame_id: int,
+ predicted_paths: List = None,
+ probabilities: List = None,
+ ) -> np.array:
+ """Visualize oxts track on image plane."""
+ camera = Camera.FRONT
+ zod_frame = zod_frames[frame_id]
+ calibrations = zod_frame.calibration
+ true_path = get_ground_truth(zod_frames, frame_id, self.zod_configs)
+
+ # add true path to image
+ true_path_pov, image = self._transform_absolute_to_relative_path(image, true_path, camera, calibrations)
+
+ ground_truth_color = (20, 150, 61) # (19, 80, 41)
+ image = self._draw_line(image, true_path_pov, ground_truth_color)
+
+ # add predicted path to image
+ if predicted_paths is None:
+ return image
+ mid_points = []
+ predictions_color = (161, 65, 137)
+ for predicted_path in predicted_paths:
+ # transform point to camera coordinate system
+ predicted_path_pov, image = self._transform_absolute_to_relative_path(
+ image, predicted_path, camera, calibrations
+ )
+
+ image = self._draw_line(image, predicted_path_pov, predictions_color)
+ mid_points.append(self._get_mid_points(predicted_path_pov))
+ if probabilities is not None:
+ for i, probability in enumerate(probabilities):
+ image = self._add_probability_text(str(int(round(probability.item() * 100))), image, mid_points[i])
+ return image
+
+ def _get_mid_points(self, points: List) -> tuple:
+ return (
+ points[len(points) // 2][0],
+ points[len(points) // 2][1],
+ )
+
+ def _add_probability_text(self, prob: str, image: np.ndarray, point: tuple) -> np.ndarray:
+ """Adds the prob string on the image at the point."""
+ cv2.putText(
+ image,
+ prob + "%",
+ point,
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 5,
+ (255, 255, 255),
+ 10,
+ )
+ return image
diff --git a/examples/multiple-trajectory-prediction/generate_ground_truth.py b/examples/multiple-trajectory-prediction/generate_ground_truth.py
new file mode 100644
index 0000000..3d55ae7
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/generate_ground_truth.py
@@ -0,0 +1,17 @@
+"""Calls the save ground truth function with specified configs."""
+
+from dataset.groundtruth_utils import save_ground_truth
+from dataset.zod_configs import ZodConfigs
+
+from zod import constants
+from zod.zod_frames import ZodFrames
+
+
+def generate_ground_truth() -> None:
+ """Create ground truth."""
+ zod_configs = ZodConfigs()
+ zod_frames = ZodFrames(dataset_root=zod_configs.DATASET_ROOT, version="full")
+ training_frames_all = zod_frames.get_split(constants.TRAIN)
+ validation_frames_all = zod_frames.get_split(constants.VAL)
+
+ save_ground_truth(zod_frames, training_frames_all, validation_frames_all, zod_configs=zod_configs)
diff --git a/examples/multiple-trajectory-prediction/main.py b/examples/multiple-trajectory-prediction/main.py
new file mode 100644
index 0000000..132ff90
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/main.py
@@ -0,0 +1,24 @@
+"""Main script."""
+
+from dataset.zod_data_manager import ZodDatasetManager
+from model.multi_trajectory_model_manager import MultiTrajectoryModelManager
+
+
+def main() -> None:
+ """Run experiment."""
+ # get data
+ data_manager = ZodDatasetManager()
+ test_loader = data_manager.get_test_dataloader()
+ train_loader, val_loader = data_manager.get_train_val_dataloader()
+ image_generator = data_manager.get_image_generator()
+
+ # create model and train
+ model_manager = MultiTrajectoryModelManager()
+ model_manager.train(train_loader, val_loader, image_generator)
+
+ # test
+ model_manager.test(test_loader)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/multiple-trajectory-prediction/model/model.py b/examples/multiple-trajectory-prediction/model/model.py
new file mode 100644
index 0000000..141df8a
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/model/model.py
@@ -0,0 +1,108 @@
+"""Models."""
+
+from typing import Tuple
+
+import pytorch_lightning as pl
+import torch
+from model.model_config import ModelConfig
+from torch import nn
+from torch.nn import functional as f
+from torchvision import models
+
+
+class MultiTrajectoryLoss:
+ """Computes MultiTrajectoryLoss."""
+
+ def __init__(self, num_modes: int) -> None:
+ self.num_modes = num_modes
+
+ def __call__(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Computes the MultiTrajectoryLoss loss on a batch."""
+ batch_losses = torch.Tensor().requires_grad_(True).to(predictions.device)
+ trajectories, modes = self._get_trajectory_and_modes(predictions)
+ for sample_idx in range(predictions.shape[0]):
+ best_mode = self._compute_best_mode(target=targets[sample_idx], trajectories=trajectories[sample_idx])
+ best_mode_trajectory = trajectories[sample_idx, best_mode].reshape(-1)
+ regression_loss = f.smooth_l1_loss(best_mode_trajectory, targets[sample_idx])
+ mode_probabilities = modes[sample_idx].unsqueeze(0)
+ best_mode_target = torch.tensor([best_mode], device=predictions.device)
+ classification_loss = f.cross_entropy(mode_probabilities, best_mode_target)
+ loss = classification_loss + regression_loss
+ batch_losses = torch.cat((batch_losses, loss.unsqueeze(0)), 0)
+ return torch.mean(batch_losses)
+
+    def _compute_best_mode(self, target: torch.Tensor, trajectories: torch.Tensor) -> torch.Tensor:
+        """Find the index of the best mode based on the l1 norm from the ground truth."""
+        l1_norms = torch.empty(trajectories.shape[0])
+
+        for i in range(trajectories.shape[0]):
+            # trajectories has shape (num_modes, num_points): compare each mode to the target
+            l1_norm = torch.sum(torch.abs(trajectories[i] - target))
+            l1_norms[i] = l1_norm
+
+ return torch.argmin(l1_norms)
+
+ def _get_trajectory_and_modes(self, model_prediction: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Splits the predictions from the model into mode probabilities and trajectory."""
+ mode_probabilities = model_prediction[:, -self.num_modes :].clone()
+
+ desired_shape = (
+ model_prediction.shape[0],
+ self.num_modes,
+ -1,
+ )
+ trajectories = model_prediction[:, : -self.num_modes].clone().reshape(desired_shape)
+
+ return trajectories, mode_probabilities
+
+
+class Net(pl.LightningModule):
+ """Neural CNN model class."""
+
+ def __init__(self, model_configs: ModelConfig) -> None:
+        super().__init__()
+ self.model_configs = model_configs
+ self.model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V1)
+
+ self.loss_fn = MultiTrajectoryLoss(self.model_configs.NR_OF_MODES)
+
+ device = torch.device("cuda" if torch.cuda.is_available() and self.model_configs.USE_GPU else "cpu")
+
+ self.target_dists = torch.Tensor(self.model_configs.TARGET_DISTANCES).to(device)
+ self.num_target_distances = len(self.model_configs.TARGET_DISTANCES)
+
+ self.num_modes = self.model_configs.NR_OF_MODES
+
+ self._change_head_net(self.num_target_distances * 3, self.num_modes)
+
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
+ """Forward propagation."""
+ model_output = self.model(image)
+
+ trajectories = model_output[:, : -self.num_modes].reshape((-1, self.num_target_distances, 3))
+ mode_probabilities = model_output[:, -self.num_modes :]
+ scaling_factors = self.target_dists.view(-1, 1)
+ trajectories *= scaling_factors
+
+ if not self.training:
+ mode_probabilities = f.softmax(mode_probabilities, dim=-1)
+
+ trajectories_reshaped = trajectories.reshape((-1, 3 * self.num_modes * self.num_target_distances))
+
+ return torch.cat([trajectories_reshaped, mode_probabilities], dim=-1)
+
+ def _change_head_net(self, num_points: int, num_modes: int) -> None:
+ """Change the last model classifier step."""
+ num_ftrs = self.model.classifier[-1].in_features
+ head_net = nn.Sequential(
+ nn.Linear(num_ftrs, 1024, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(1024, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(
+ 512,
+ num_points * num_modes + num_modes,
+ bias=True,
+ ),
+ )
+
+ self.model.classifier[-1] = head_net
diff --git a/examples/multiple-trajectory-prediction/model/model_config.py b/examples/multiple-trajectory-prediction/model/model_config.py
new file mode 100644
index 0000000..5e4fef6
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/model/model_config.py
@@ -0,0 +1,32 @@
+"""Model specific static parameters dataclass."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class ModelConfig:
+ """Global configs class."""
+
+ LEARNING_RATE: float = 0.001
+    TARGET_DISTANCES: tuple = (
+ 5,
+ 10,
+ 15,
+ 20,
+ 25,
+ 30,
+ 35,
+ 40,
+ 50,
+ 60,
+ 70,
+ 80,
+ 95,
+ 110,
+ 125,
+ 145,
+ 165,
+ )
+ USE_GPU: bool = True
+ NR_OF_MODES: int = 2
+    EPOCHS: int = 10
diff --git a/examples/multiple-trajectory-prediction/model/multi_trajectory_model_manager.py b/examples/multiple-trajectory-prediction/model/multi_trajectory_model_manager.py
new file mode 100644
index 0000000..331db81
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/model/multi_trajectory_model_manager.py
@@ -0,0 +1,80 @@
+"""Model manager class for Holistic Path model trained with ZOD."""
+
+import time
+from typing import Tuple
+
+import numpy as np
+import torch
+from dataset.zod_image_generator import ZodImageGenerator
+from model.model import Net
+from model.model_config import ModelConfig
+from torch.utils.data import DataLoader
+
+
+class MultiTrajectoryModelManager:
+ """Zod dataset manager class."""
+
+ def __init__(self) -> None:
+ self.model_configs = ModelConfig()
+ device = torch.device("cuda" if torch.cuda.is_available() and self.model_configs.USE_GPU else "cpu")
+ self.net = Net(self.model_configs).to(device)
+
+ def train(
+ self, trainloader: DataLoader, valloader: DataLoader, image_generator: ZodImageGenerator
+ ) -> Tuple[list, list]:
+ """Trains data to update the model parameters.
+
+ Args:
+ trainloader (torch.utils.data.DataLoader): train loader
+ valloader (torch.utils.data.DataLoader): validaiton loader
+ image_generator (ZodImageGenerator): Image generator to visualize results
+
+ Returns:
+ tuple[list, list]: epoch_train_losses, epoch_val_losses
+
+ """
+ device = torch.device("cuda" if torch.cuda.is_available() and self.model_configs.USE_GPU else "cpu")
+ epochs = self.model_configs.EPOCHS
+ self.net.train()
+ opt = torch.optim.Adam(self.net.parameters(), lr=self.model_configs.LEARNING_RATE)
+ epoch_train_losses = []
+ epoch_val_losses = []
+ for epoch in range(1, epochs + 1):
+ tstart = time.time()
+ batch_train_losses = []
+ for data_, target_ in trainloader:
+ data, target = data_.to(device), target_.to(device)
+ opt.zero_grad()
+ output = self.net(data)
+ loss = self.net.loss_fn(output, target)
+ loss.backward()
+ opt.step()
+ batch_train_losses.append(loss.item())
+ epoch_train_losses.append(sum(batch_train_losses) / len(batch_train_losses))
+ val_loss, _ = self.test(valloader)
+ epoch_val_losses.append(val_loss)
+ print(
+ f"Epoch completed in {time.time() - tstart:.2f} seconds with "
+ + f"{len(trainloader)} batches of batch size {trainloader.batch_size}"
+ )
+ print(f"Train loss for epoch {epoch}: {epoch_train_losses[-1]:.2f}")
+ print(f"Validation loss for epoch {epoch}: {epoch_val_losses[-1]:.2f}")
+ image_generator.visualize_prediction_on_image(self.net, f"epoch_{str(epoch)}")
+ return epoch_train_losses, epoch_val_losses
+
+ def test(self, testloader: DataLoader) -> Tuple[float, float]:
+ """Test the model performance.
+
+        Returns a tuple (loss, accuracy); accuracy is None for this regression task.
+ """
+ device = torch.device("cuda" if torch.cuda.is_available() and self.model_configs.USE_GPU else "cpu")
+ criterion = self.net.loss_fn
+ self.net.eval()
+ loss = []
+ with torch.no_grad():
+ for images_, labels_ in testloader:
+ images, labels = images_.to(device), labels_.to(device)
+ outputs = self.net(images)
+ loss.append(criterion(outputs, labels).item())
+ self.net.train()
+ return np.mean(loss), None
diff --git a/examples/multiple-trajectory-prediction/requirements.txt b/examples/multiple-trajectory-prediction/requirements.txt
new file mode 100644
index 0000000..66d66af
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/requirements.txt
@@ -0,0 +1,6 @@
+torch
+torchvision
+pypdf
+pytorch_lightning
+zod[cli]
+opencv-python
diff --git a/examples/multiple-trajectory-prediction/resize_and_save.py b/examples/multiple-trajectory-prediction/resize_and_save.py
new file mode 100644
index 0000000..4a2fb1a
--- /dev/null
+++ b/examples/multiple-trajectory-prediction/resize_and_save.py
@@ -0,0 +1,77 @@
+"""Script to resize images and save them (robust to corrupt/truncated files)."""
+
+import glob
+import os
+from pathlib import Path
+
+from dataset.zod_configs import ZodConfigs
+from PIL import Image, UnidentifiedImageError
+from tqdm import tqdm
+
+
+def looks_like_jpeg(path: str) -> bool:
+ """Quick magic-byte check for JPEG."""
+ try:
+ with open(path, "rb") as f:
+ return f.read(2) == b"\xff\xd8"
+ except Exception:
+ return False
+
+
+def resize_images(source_directory: str, size: int) -> None:
+ """Resize script to save a new ZOD single frames dataset to specified size.
+
+ Data is saved as original_name_resized.jpg.
+ """
+ pattern = os.path.join(source_directory, "single_frames", "*", "camera_front_dnat", "*.jpg")
+ files = glob.glob(pattern)
+
+ bad_list_path = Path(source_directory) / "bad_images.txt"
+ skipped = 0
+ written = 0
+ checked = 0
+
+ with bad_list_path.open("w") as badlog:
+ for file in tqdm(files, desc="Resizing images"):
+ checked += 1
+
+ # Skip already resized files
+ if file.endswith("_resized.jpg"):
+ continue
+
+ if not looks_like_jpeg(file):
+ badlog.write(f"not_jpeg_magic:{file}\n")
+ skipped += 1
+ continue
+
+ # Try to open/verify and then reopen for actual load (verify invalidates the fp)
+ try:
+ with Image.open(file) as probe:
+ probe.verify() # quick structural check
+
+ with Image.open(file) as img:
+ img_rgb = img.convert("RGB") # ensure JPEG-compatible
+ img_resized = img_rgb.resize((size, size), Image.Resampling.LANCZOS)
+
+ destination_file = file.rsplit(".", 1)[0] + "_resized.jpg"
+ img_resized.save(destination_file, "JPEG", quality=95, optimize=True)
+ written += 1
+
+ except (UnidentifiedImageError, OSError) as e:
+ # OSError can happen on truncated/invalid images
+ badlog.write(f"unreadable:{file} reason:{type(e).__name__}: {e}\n")
+ skipped += 1
+ continue
+
+ print(
+ f"Processed {checked} files. Wrote {written} resized images. "
+ f"Skipped {skipped} bad/unreadable files.\n"
+ f"See bad image list at: {bad_list_path}"
+ )
+
+
+if __name__ == "__main__":
+ configs = ZodConfigs()
+ directory = configs.DATASET_ROOT
+ size = configs.IMG_SIZE
+ resize_images(directory, size)
diff --git a/examples/multiple-trajectory-prediction/results/example_prediction.png b/examples/multiple-trajectory-prediction/results/example_prediction.png
new file mode 100644
index 0000000..7e66a2f
Binary files /dev/null and b/examples/multiple-trajectory-prediction/results/example_prediction.png differ