84 changes: 84 additions & 0 deletions examples/multiple-trajectory-prediction/README.md
@@ -0,0 +1,84 @@
# ZOD Multi-Trajectory Pipeline

This repository provides a pipeline for training, evaluating, and visualizing multi-trajectory prediction models on the ZOD dataset.

The method predicts multiple possible future paths that a driver may take, based on a single frontal camera image. Predictions are expressed in ground-plane coordinates, derived from GPS/OXTS data, relative to the vehicle's position at the time of capture.

Ground-truth trajectories are generated from the recorded GPS data, i.e. the path the vehicle actually traveled after each image was captured.

For visualization, both predicted and ground-truth trajectories are projected into the **camera image plane** using calibration parameters. This overlays the paths directly on top of the corresponding camera frame, making it possible to compare predicted trajectories with the driver’s true path.
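
To make the projection step concrete, here is a minimal sketch using an idealized pinhole camera. It is only an illustration under that assumption: the actual pipeline uses the ZOD calibration parameters (which also model lens distortion), and the intrinsic matrix and trajectory points below are made up.

```python
import numpy as np

def project_to_image(points_cam: np.ndarray, K: np.ndarray) -> np.ndarray:
    """Project Nx3 points (already in the camera frame) to Nx2 pixel coordinates."""
    uvw = (K @ points_cam.T).T          # homogeneous image coordinates, shape (N, 3)
    return uvw[:, :2] / uvw[:, 2:3]     # divide by depth to obtain pixel coordinates

# Hypothetical intrinsics and a few trajectory points in front of the camera.
K = np.array([[1000.0, 0.0, 640.0],
              [0.0, 1000.0, 360.0],
              [0.0, 0.0, 1.0]])
trajectory_cam = np.array([[0.0, 1.5, 5.0], [0.5, 1.5, 10.0], [1.0, 1.5, 20.0]])
print(project_to_image(trajectory_cam, K))  # pixel locations where the path would be drawn
```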


## Example Results

After training, the pipeline stores prediction visualizations in the `results/` directory.
Here is an example of predicted trajectories overlaid on a ZOD camera frame, where the purple lines are predictions and the green line is the ground truth:
<p align="center">
<img src="results/example_prediction.png" alt="Example prediction" width="400"/>
</p>

---

## Setup

### 1. Create and activate a virtual environment
```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

### 2. Configure paths
Update dataset and model configuration in:

- `model/model_config.py`
- `dataset/zod_configs.py`

In particular, set the dataset paths, for example:
```python
STORED_GROUND_TRUTH_PATH = "/mnt/ZOD/ground_truth.json"
DATASET_ROOT = "/mnt/ZOD"
```

Other parameters may be left as default or modified as needed.

---

## Data Preparation

### 1. Resize images
The pipeline has been tested with images resized to **256×256**.

Run the resize script:
```bash
python resize_and_save.py
```

Make sure `IMG_SIZE` and `DATASET_ROOT` in `dataset/zod_configs.py` are set correctly.

Alternatively, you can work with full-sized images by setting
`USE_PRE_RESIZED_IMGS = False` in `dataset/zod_configs.py` and skipping this step (not tested).
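
For reference, the sketch below shows what an offline resize step of this kind typically looks like. It is not the contents of `resize_and_save.py` (that script is the authoritative implementation), and the source and output directories are hypothetical.

```python
from pathlib import Path

from PIL import Image

IMG_SIZE = 256
src_root = Path("/mnt/ZOD/images")      # hypothetical source directory
dst_root = Path("/mnt/ZOD/images_256")  # hypothetical output directory

for src in src_root.rglob("*.jpg"):
    dst = dst_root / src.relative_to(src_root)
    dst.parent.mkdir(parents=True, exist_ok=True)
    with Image.open(src) as img:
        img.resize((IMG_SIZE, IMG_SIZE)).save(dst)
```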

### 2. Generate ground truth
Ground truth is generated by looking ahead at the vehicle's future GPS poses after each keyframe. Run the ground-truth generation script:
```bash
python generate_ground_truth.py
```

This script calls utility functions defined in `dataset/groundtruth_utils.py`.

Ensure that the path specified in `STORED_GROUND_TRUTH_PATH` in `dataset/zod_configs.py` is valid.
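
For orientation, here is a rough sketch of how `generate_ground_truth.py` presumably wires these utilities together (the script itself is the authoritative version); it mirrors the imports used elsewhere in this example.

```python
from dataset.groundtruth_utils import save_ground_truth
from dataset.zod_configs import ZodConfigs
from zod import ZodFrames, constants

zod_configs = ZodConfigs()
zod_frames = ZodFrames(dataset_root=zod_configs.DATASET_ROOT, version="full")

# Write the flattened future paths for all train/val frames to STORED_GROUND_TRUTH_PATH.
save_ground_truth(
    zod_frames=zod_frames,
    training_frames=set(zod_frames.get_split(constants.TRAIN)),
    validation_frames=set(zod_frames.get_split(constants.VAL)),
    zod_configs=zod_configs,
)
```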

---

## Training & Evaluation

Run the main training script:
`python main.py`

This will:
- Train the model
- Evaluate performance
- Store example predictions in the `results/` directory

---
97 changes: 97 additions & 0 deletions examples/multiple-trajectory-prediction/dataset/groundtruth_utils.py
@@ -0,0 +1,97 @@
"""Groundtruth utilities."""

import json
from typing import Set

import numpy as np
from dataset.zod_configs import ZodConfigs
from tqdm import tqdm

from zod import ZodFrames


def get_ground_truth(zod_frames: ZodFrames, frame_id: str, zod_configs: ZodConfigs) -> np.ndarray:
"""Get true holistic path from future GPS locations.

Args:
        zod_frames (ZodFrames): ZOD frames collection
        frame_id (str): frame id
zod_configs (ZodConfigs): zod configs dataclass

Returns:
        np.ndarray: flattened true path (x, y, z at each target distance)

"""
zod_frame = zod_frames[frame_id]
oxts = zod_frame.oxts
key_timestamp = zod_frame.info.keyframe_time.timestamp()

    # get poses associated with the frame timestamp
try:
current_pose = oxts.get_poses(key_timestamp)
all_poses = oxts.poses[oxts.timestamps >= key_timestamp]
transformed_poses = np.linalg.pinv(current_pose) @ all_poses
translations = transformed_poses[:, :3, 3]
distances = np.linalg.norm(np.diff(translations, axis=0), axis=1)
accumulated_distances = np.cumsum(distances).astype(int).tolist()

        # pick the first pose reached at each of the accumulated TARGET_DISTANCES
pose_idx = [accumulated_distances.index(i) for i in zod_configs.TARGET_DISTANCES]
used_poses = transformed_poses[pose_idx]

    except Exception:
print(f"detected invalid frame: {frame_id}")
return np.array([])

points = used_poses[:, :3, -1]
return points.flatten()


def save_ground_truth(
zod_frames: ZodFrames,
training_frames: Set[str],
validation_frames: Set[str],
zod_configs: ZodConfigs,
) -> None:
"""Write ground truth as json.

Args:
        zod_frames (ZodFrames): ZOD frames collection
        training_frames (Set[str]): frame ids in the training split
        validation_frames (Set[str]): frame ids in the validation split
zod_configs (ZodConfigs): zod configs dataclass

"""
all_frames = validation_frames.copy()
all_frames.update(training_frames)

corrupted_frames = []
ground_truths = {}
for frame_id in tqdm(all_frames):
ground_truth = get_ground_truth(zod_frames, frame_id, zod_configs)

if ground_truth.shape[0] != zod_configs.NUM_OUTPUT:
corrupted_frames.append(frame_id)
continue

ground_truths[frame_id] = ground_truth.tolist()

    # Serialize the ground truth to JSON
json_object = json.dumps(ground_truths, indent=4)

    # Write to the configured ground-truth path
with open(zod_configs.STORED_GROUND_TRUTH_PATH, "w") as outfile:
outfile.write(json_object)

print(f"{corrupted_frames}")


def load_ground_truth(path: str) -> dict:
"""Load ground truth from file."""
with open(path) as json_file:
gt = json.load(json_file)

for f in gt:
gt[f] = np.array(gt[f])

return gt
43 changes: 43 additions & 0 deletions examples/multiple-trajectory-prediction/dataset/zod_configs.py
@@ -0,0 +1,43 @@
"""Global static parameters dataclass."""

from dataclasses import dataclass


@dataclass
class ZodConfigs:
"""ZOD configs class."""

NUM_OUTPUT: int = 51
IMG_SIZE: int = 256
TARGET_DISTANCES: tuple = (
5,
10,
15,
20,
25,
30,
35,
40,
50,
60,
70,
80,
95,
110,
125,
145,
165,
)

BATCH_SIZE: int = 32
    TEST_SIZE: float = 0.1  # fraction of the VAL split used as the test set
    VAL_SIZE: float = 0.1  # fraction of the train split held out for validation

# File paths
STORED_GROUND_TRUTH_PATH: str = "/mnt/ZOD/ground_truth.json"
DATASET_ROOT: str = "/mnt/ZOD"

USE_PRE_RESIZED_IMGS: bool = True

NORMALIZE_MEAN: tuple = (0.485, 0.456, 0.406)
NORMALIZE_STD: tuple = (0.229, 0.224, 0.225)
110 changes: 110 additions & 0 deletions examples/multiple-trajectory-prediction/dataset/zod_data_manager.py
@@ -0,0 +1,110 @@
"""Zod dataset manager."""

from typing import Tuple

from dataset.groundtruth_utils import load_ground_truth
from dataset.zod_configs import ZodConfigs
from dataset.zod_dataset import ZodDataset
from dataset.zod_image_generator import ZodImageGenerator
from torch import Generator
from torch.utils.data import DataLoader, RandomSampler, random_split
from torchvision import transforms

from zod import ZodFrames, constants


class ZodDatasetManager:
"""Zod dataset manager class."""

def __init__(self) -> None:
self.zod_configs = ZodConfigs()
self.zod_frames = ZodFrames(dataset_root=self.zod_configs.DATASET_ROOT, version="full")
self.transform = self._get_transform()
self.ground_truth = load_ground_truth(self.zod_configs.STORED_GROUND_TRUTH_PATH)
self.test_frames = None

def get_test_dataloader(self) -> DataLoader:
"""Load the ZOD test dataset from the VAL partition. Server side test set."""
validation_frames_all = self.zod_frames.get_split(constants.VAL)

validation_frames_all = [idx for idx in validation_frames_all if self._is_valid_frame(idx, self.ground_truth)]

validation_frames = validation_frames_all[: int(len(validation_frames_all) * self.zod_configs.TEST_SIZE)]

self.test_frames = validation_frames

testset = ZodDataset(
zod_frames=self.zod_frames,
frames_id_set=validation_frames,
stored_ground_truth=self.ground_truth,
transform=self.transform,
zod_configs=self.zod_configs,
)
print(f"Test dataset loaded. Length: {len(testset)}")
return DataLoader(testset, batch_size=self.zod_configs.BATCH_SIZE)

def get_train_val_dataloader(self, seed: int = 42) -> Tuple[DataLoader, DataLoader]:
"""Get train and validation dataloader for client side."""
train_frames = self.zod_frames.get_split(constants.TRAIN)

train_frames = [idx for idx in train_frames if self._is_valid_frame(idx, self.ground_truth)]
trainset = ZodDataset(
zod_frames=self.zod_frames,
frames_id_set=train_frames,
stored_ground_truth=self.ground_truth,
transform=self.transform,
zod_configs=self.zod_configs,
)

# Split into train/val and create DataLoader
len_val = int(len(trainset) * self.zod_configs.VAL_SIZE)
len_train = int(len(trainset) - len_val)

lengths = [len_train, len_val]
ds_train, ds_val = random_split(trainset, lengths, Generator().manual_seed(seed))
train_sampler = RandomSampler(ds_train)
trainloader = DataLoader(
ds_train,
batch_size=self.zod_configs.BATCH_SIZE,
shuffle=False,
num_workers=0,
sampler=train_sampler,
)
valloader = DataLoader(ds_val, batch_size=self.zod_configs.BATCH_SIZE, num_workers=0)

return trainloader, valloader

def get_image_generator(self) -> ZodImageGenerator:
"""Get image generator for ZOD hollistic path."""
if self.test_frames is None:
self.get_test_dataloader()
return ZodImageGenerator(self.test_frames, self.zod_frames)

def _get_transform(self) -> transforms.Compose:
"""Get transform to use."""
return (
transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(self.zod_configs.NORMALIZE_MEAN, self.zod_configs.NORMALIZE_STD),
]
)
if self.zod_configs.USE_PRE_RESIZED_IMGS
else transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize(
(self.zod_configs.IMG_SIZE, self.zod_configs.IMG_SIZE),
antialias=True,
),
transforms.Normalize(self.zod_configs.NORMALIZE_MEAN, self.zod_configs.NORMALIZE_STD),
]
)
)

def _is_valid_frame(self, frame_id: str, ground_truth: dict) -> bool:
"""Check if frame is valid."""
if frame_id == "005350":
return False

return frame_id in ground_truth
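
For orientation, a minimal usage sketch of this manager follows; it is an assumption that `main.py` (not shown in this diff) drives it roughly like this.

```python
from dataset.zod_data_manager import ZodDatasetManager

manager = ZodDatasetManager()
trainloader, valloader = manager.get_train_val_dataloader(seed=42)
testloader = manager.get_test_dataloader()
image_generator = manager.get_image_generator()  # for overlaying predictions on camera frames
```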