diff --git a/examples/06_pytorch_oxe_dataloader.py b/examples/06_pytorch_oxe_dataloader.py index 046d0802..ca68c676 100644 --- a/examples/06_pytorch_oxe_dataloader.py +++ b/examples/06_pytorch_oxe_dataloader.py @@ -2,16 +2,22 @@ This example shows how to use the `octo.data` dataloader with PyTorch by wrapping it in a simple PyTorch dataloader. The config below also happens to be our exact pretraining config (except for the batch size and shuffle buffer size, which are reduced for demonstration purposes). + +Before running the script, please download the OXE data (e.g. via the script here: +https://github.com/kpertsch/rlds_dataset_mod/blob/main/prepare_open_x.sh) + +Then modify the `DATA_PATH` variable below to point to the directory where you downloaded the data to. """ import numpy as np -from octo.data.dataset import make_interleaved_dataset -from octo.data.oxe import make_oxe_dataset_kwargs_and_weights import tensorflow as tf import torch from torch.utils.data import DataLoader import tqdm -DATA_PATH = "gs://rail-orca-central2/resize_256_256" +from octo.data.dataset import make_interleaved_dataset +from octo.data.oxe import make_oxe_dataset_kwargs_and_weights + +DATA_PATH = "/path/to/your/oxe/directory" tf.config.set_visible_devices([], "GPU")