diff --git a/training/data.py b/training/data.py
index 963b0f01..43b898d2 100644
--- a/training/data.py
+++ b/training/data.py
@@ -15,16 +15,24 @@
 
 # This file is heavily inspired by https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
 
+import io
 import itertools
 import json
+import logging
 import math
+import os
 import random
 import re
+import sys
+import time
 from typing import List, Optional, Union
 
+import pyarrow as pa
+import s3fs
 import webdataset as wds
 from braceexpand import braceexpand
-from torch.utils.data import default_collate
+from PIL import Image
+from torch.utils.data import DataLoader, IterableDataset, default_collate
 from torchvision import transforms
 from transformers import PreTrainedTokenizer
 from webdataset.tariterators import (
@@ -34,6 +42,8 @@
     valid_sample,
 )
 
+logger = logging.getLogger(__name__)
+
 person_token = ["a person", "someone", "somebody"]
 
 
@@ -286,7 +296,13 @@ def __init__(
         vae_checkpoint: Optional[str] = None,
         text_encoder_checkpoint: Optional[str] = None,
         use_filtered_dataset: bool = False,
+        use_m4_laion_text_2_image_dataset: bool = False,
     ):
+        num_batches = math.ceil(num_train_examples / global_batch_size)
+        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
+        num_batches = num_worker_batches * num_workers
+        num_samples = num_batches * global_batch_size
+
         transform = ImageNetTransform(resolution, center_crop, random_flip)
 
         def tokenize(text):
@@ -296,61 +312,99 @@ def tokenize(text):
             ).input_ids
             return input_ids[0]
 
-        if not isinstance(train_shards_path_or_url, str):
-            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
-            # flatten list using itertools
-            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
-
-        if not isinstance(eval_shards_path_or_url, str):
-            eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
-            # flatten list using itertools
-            eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))
+        if not use_m4_laion_text_2_image_dataset:
+            if not isinstance(train_shards_path_or_url, str):
+                train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+                # flatten list using itertools
+                train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
 
-        if not is_pre_encoded:
-            processing_pipeline = [
-                wds.decode("pil", handler=wds.ignore_and_continue),
-                wds.rename(image="jpg;png;jpeg;webp", input_ids="text;txt;caption", handler=wds.warn_and_continue),
-                wds.map(filter_keys(set(["image", "input_ids"]))),
-                wds.map_dict(image=transform.train_transform, input_ids=tokenize),
-                wds.to_tuple("image", "input_ids"),
-            ]
+            if not isinstance(eval_shards_path_or_url, str):
+                eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
+                # flatten list using itertools
+                eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))
         else:
-            # lowercase and replace / with .
-            vae_checkpoint = vae_checkpoint.lower().replace("/", ".")
-            text_encoder_checkpoint = text_encoder_checkpoint.lower().replace("/", ".")
-            processing_pipeline = [
-                wds.decode(wds.handle_extension("pth", wds.autodecode.torch_loads), handler=wds.ignore_and_continue),
-                wds.rename(
-                    input_ids=f"{vae_checkpoint}.pth",
-                    encoder_hidden_states=f"{text_encoder_checkpoint}.pth",
-                    handler=wds.warn_and_continue,
+            if train_shards_path_or_url is not None or eval_shards_path_or_url is not None:
+                raise ValueError(
+                    "`train_shards_path_or_url` and `eval_shards_path_or_url` must both be None when"
+                    " `use_m4_laion_text_2_image_dataset` is set"
+                )
+
+        if not use_m4_laion_text_2_image_dataset:
+            if not is_pre_encoded:
+                processing_pipeline = [
+                    wds.decode("pil", handler=wds.ignore_and_continue),
+                    wds.rename(image="jpg;png;jpeg;webp", input_ids="text;txt;caption", handler=wds.warn_and_continue),
+                    wds.map(filter_keys(set(["image", "input_ids"]))),
+                    wds.map_dict(image=transform.train_transform, input_ids=tokenize),
+                    wds.to_tuple("image", "input_ids"),
+                ]
+            else:
+                # lowercase and replace / with .
+                vae_checkpoint = vae_checkpoint.lower().replace("/", ".")
+                text_encoder_checkpoint = text_encoder_checkpoint.lower().replace("/", ".")
+                processing_pipeline = [
+                    wds.decode(
+                        wds.handle_extension("pth", wds.autodecode.torch_loads), handler=wds.ignore_and_continue
+                    ),
+                    wds.rename(
+                        input_ids=f"{vae_checkpoint}.pth",
+                        encoder_hidden_states=f"{text_encoder_checkpoint}.pth",
+                        handler=wds.warn_and_continue,
+                    ),
+                    wds.map(filter_keys(set(["input_ids", "encoder_hidden_states"]))),
+                    wds.to_tuple("input_ids", "encoder_hidden_states"),
+                ]
+
+            # Create train dataset and loader
+            pipeline = [
+                wds.ResampledShards(train_shards_path_or_url),
+                tarfile_to_samples_nothrow,
+                wds.select(
+                    WebdatasetFilter(min_size=256, max_pwatermark=0.5, aesthetic_threshold=4.9)
+                    if use_filtered_dataset
+                    else lambda x: True
                 ),
-                wds.map(filter_keys(set(["input_ids", "encoder_hidden_states"]))),
-                wds.to_tuple("input_ids", "encoder_hidden_states"),
+                wds.shuffle(shuffle_buffer_size),
+                *processing_pipeline,
+                wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
             ]
-
-        # Create train dataset and loader
-        pipeline = [
-            wds.ResampledShards(train_shards_path_or_url),
-            tarfile_to_samples_nothrow,
-            wds.select(
-                WebdatasetFilter(min_size=256, max_pwatermark=0.5, aesthetic_threshold=4.9)
-                if use_filtered_dataset
-                else lambda x: True
-            ),
-            wds.shuffle(shuffle_buffer_size),
-            *processing_pipeline,
-            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
-        ]
-
-        num_batches = math.ceil(num_train_examples / global_batch_size)
-        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
-        num_batches = num_worker_batches * num_workers
-        num_samples = num_batches * global_batch_size
+            # each worker is iterating over this
+            train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
 
-        # each worker is iterating over this
-        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
-        self._train_dataloader = wds.WebLoader(
+            eval_dataset_pipeline = [
+                wds.SimpleShardList(eval_shards_path_or_url),
+                wds.split_by_worker,
+                wds.tarfile_to_samples(handler=wds.ignore_and_continue),
+                *processing_pipeline,
+                wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+            ]
+            eval_dataset = wds.DataPipeline(*eval_dataset_pipeline)
+        else:
+            train_dataset = M4LaionShards(type="train")
+            train_dataset = M4LaionResampledShards(train_dataset)
+            train_dataset = M4LaionDatasetStream(train_dataset)
+
+            if use_filtered_dataset:
+                train_dataset = M4LaionDatasetFilter(train_dataset, min_size=256)
+
+            train_dataset = M4LaionDatasetShuffle(train_dataset, shuffle_buffer_size)
+            train_dataset = M4LaionDatasetProcessingPipeline(train_dataset, transform.train_transform, tokenize)
+            train_dataset = M4LaionDatasetBatched(
+                train_dataset, per_gpu_batch_size, partial=False, collation_fn=default_collate
+            )
+            train_dataset = M4LaionDatasetWithEpoch(train_dataset, num_worker_batches)
+
+            eval_dataset = M4LaionShards(type="eval")
+            eval_dataset = M4LaionDatasetStream(eval_dataset)
+            eval_dataset = M4LaionDatasetSplitByWorker(eval_dataset)
+            eval_dataset = M4LaionDatasetProcessingPipeline(eval_dataset, transform.train_transform, tokenize)
+            eval_dataset = M4LaionDatasetBatched(
+                eval_dataset, per_gpu_batch_size, partial=False, collation_fn=default_collate
+            )
+
+        self._train_dataset = train_dataset
+        self._train_dataloader = DataLoader(
             self._train_dataset,
             batch_size=None,
             shuffle=False,
@@ -363,15 +417,8 @@ def tokenize(text):
         self._train_dataloader.num_samples = num_samples
 
         # Create eval dataset and loader
-        pipeline = [
-            wds.SimpleShardList(eval_shards_path_or_url),
-            wds.split_by_worker,
-            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
-            *processing_pipeline,
-            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
-        ]
-        self._eval_dataset = wds.DataPipeline(*pipeline)
-        self._eval_dataloader = wds.WebLoader(
+        self._eval_dataset = eval_dataset
+        self._eval_dataloader = DataLoader(
             self._eval_dataset,
             batch_size=None,
             shuffle=False,
@@ -395,3 +442,220 @@ def eval_dataset(self):
     @property
     def eval_dataloader(self):
         return self._eval_dataloader
+
+
+# We must run the `s3fs.ls` call in the dataloader subprocess.
+# s3fs requires non-fork based multi processing if we create an instance in the parent process.
+class M4LaionShards(IterableDataset):
+    def __init__(self, type):
+        self.type = type
+
+    def __iter__(self):
+        # s3 handle is not fork safe, just create a new handle when the stream is created
+        s3 = s3fs.S3FileSystem()
+
+        split_urls = braceexpand("s3://m4-datasets/LAION_data/laion_dataset_filtered_dedup/{0..199}")
+
+        shard_urls = []
+
+        for split_url in split_urls:
+            for shard_url in s3.ls(split_url):
+                shard_urls.append(shard_url)
+
+        if self.type == "train":
+            shard_urls = shard_urls[:-4]
+        elif self.type == "eval":
+            # Of the last 4 shards, 2 are malformed, so this effectively selects 2 eval shards
+            shard_urls = shard_urls[-4:]
+        else:
+            raise ValueError(f"unknown shard type: {self.type}")
+
+        for shard in shard_urls:
+            yield shard
+
+
+class M4LaionResampledShards(IterableDataset):
+    def __init__(
+        self,
+        urls_iterable,
+        nshards=sys.maxsize,
+        worker_seed=None,
+        deterministic=False,
+    ):
+        self.urls = None
+        self.urls_iterable = urls_iterable
+        self.nshards = nshards
+        self.worker_seed = wds.utils.pytorch_worker_seed if worker_seed is None else worker_seed
+        self.deterministic = deterministic
+        self.epoch = -1
+
+    def __iter__(self):
+        if self.urls is None:
+            self.urls = [x for x in self.urls_iterable]
+
+        self.epoch += 1
+
+        if self.deterministic:
+            seed = wds.utils.make_seed(self.worker_seed(), self.epoch)
+        else:
+            seed = wds.utils.make_seed(
+                self.worker_seed(),
+                self.epoch,
+                os.getpid(),
+                time.time_ns(),
+                os.urandom(4),
+            )
+
+        self.rng = random.Random(seed)
+
+        for _ in range(self.nshards):
+            index = self.rng.randint(0, len(self.urls) - 1)
+
+            yield self.urls[index]
+
+
+class M4LaionDatasetStream(IterableDataset):
+    def __init__(self, iterable):
+        self.iterable = iterable
+
+    def __iter__(self):
+        # s3 handle is not fork safe, just create a new handle when the stream is created
+        s3 = s3fs.S3FileSystem()
+
+        for shard_url in self.iterable:
+            with s3.open(shard_url, "rb") as f:
+                in_memory_stream = pa.input_stream(f)
+                try:
+                    opened_stream = pa.ipc.open_stream(in_memory_stream)
+                except pa.lib.ArrowInvalid as e:
+                    logger.warning(str(e))
+                    continue
+                pa_table = opened_stream.read_all()
+
+            table = pa_table.to_pydict()
+
+            for i in range(len(table["text"])):
+                image_bytes = table["image"][i]["bytes"]
+                image_bytes = io.BytesIO(image_bytes)
+                try:
+                    image = Image.open(image_bytes)
+                    image = image.convert("RGB")
+                except Exception as e:
+                    logger.warning(str(e))
+                    continue
+
+                text = table["text"][i]
+
+                meta = table["meta"][i]
+
+                yield {"image": image, "text": text, "meta": meta}
+
+
+class M4LaionDatasetFilter(IterableDataset):
+    def __init__(self, iterable, min_size=256):
+        self.iterable = iterable
+        self.min_size = min_size
+
+    def __iter__(self):
+        for sample in self.iterable:
+            original_width = sample["meta"].get("original_width", 0.0) or 0.0
+            original_height = sample["meta"].get("original_height", 0.0) or 0.0
+
+            filter_size = original_width >= self.min_size and original_height >= self.min_size
+
+            if filter_size:
+                yield sample
+
+
+class M4LaionDatasetShuffle(IterableDataset):
+    def __init__(self, iterable, bufsize=1000, initial=100):
+        self.iterable = iterable
+        self.bufsize = bufsize
+        self.initial = initial
+
+    def __iter__(self):
+        data = iter(self.iterable)
+
+        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
+        initial = min(self.initial, self.bufsize)
+        buf = []
+        for sample in data:
+            buf.append(sample)
+            if len(buf) < self.bufsize:
+                try:
+                    buf.append(next(data))  # skipcq: PYL-R1708
+                except StopIteration:
+                    pass
+            if len(buf) >= initial:
+                yield pick_random(buf, rng)
+        while len(buf) > 0:
+            yield pick_random(buf, rng)
+
+
+class M4LaionDatasetProcessingPipeline(IterableDataset):
+    def __init__(self, iterable, transform, tokenize):
+        self.iterable = iterable
+        self.transform = transform
+        self.tokenize = tokenize
+
+    def __iter__(self):
+        for sample in self.iterable:
+            image = sample["image"]
+            text = sample["text"]
+
+            image = self.transform(image)
+            text = self.tokenize(text)
+
+            yield image, text
+
+
+class M4LaionDatasetBatched(IterableDataset):
+    def __init__(self, iterable, batchsize=20, collation_fn=default_collate, partial=True):
+        self.iterable = iterable
+        self.batchsize = batchsize
+        self.collation_fn = collation_fn
+        self.partial = partial
+
+    def __iter__(self):
+        batch = []
+        for sample in self.iterable:
+            if len(batch) >= self.batchsize:
+                if self.collation_fn is not None:
+                    batch = self.collation_fn(batch)
+                yield batch
+                batch = []
+            batch.append(sample)
+        if len(batch) == 0:
+            return
+        elif len(batch) == self.batchsize or self.partial:
+            if self.collation_fn is not None:
+                batch = self.collation_fn(batch)
+            yield batch
+
+
+class M4LaionDatasetWithEpoch(IterableDataset):
+    def __init__(self, iterable, num_epochs):
+        self.iterable = iterable
+        self.num_epochs = num_epochs
+
+    def __iter__(self):
+        for _ in range(self.num_epochs):
+            for sample in self.iterable:
+                yield sample
+
+
+class M4LaionDatasetSplitByWorker(IterableDataset):
+    def __init__(self, iterable):
+        self.iterable = iterable
+
+    def __iter__(self):
+        for x in wds.split_by_worker(self.iterable):
+            yield x
+
+
+def pick_random(buf, rng):
+    k = rng.randint(0, len(buf) - 1)
+    sample = buf[k]
+    buf[k] = buf[-1]
+    buf.pop()
+    return sample
diff --git a/training/train_muse.py b/training/train_muse.py
index 1fe7f8c0..9bf547e9 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -27,7 +27,6 @@
 import plotly.express as px
 import torch
 import torch.nn.functional as F
-import wandb
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedType, set_seed
@@ -46,6 +45,7 @@
 
 import muse
 import muse.training_utils
+import wandb
 from muse import (
     MOVQ,
     EMAModel,
@@ -435,9 +435,18 @@ def save_model_hook(models, weights, output_dir):
     else:
         dataset_cls = Text2ImageDataset
 
+    use_m4_laion_text_2_image_dataset = dataset_config.get("use_m4_laion_text_2_image_dataset", False)
+
+    if use_m4_laion_text_2_image_dataset:
+        train_shards_path_or_url = None
+        eval_shards_path_or_url = None
+    else:
+        train_shards_path_or_url = dataset_config.train_shards_path_or_url
+        eval_shards_path_or_url = dataset_config.eval_shards_path_or_url
+
     dataset = dataset_cls(
-        train_shards_path_or_url=dataset_config.train_shards_path_or_url,
-        eval_shards_path_or_url=dataset_config.eval_shards_path_or_url,
+        train_shards_path_or_url=train_shards_path_or_url,
+        eval_shards_path_or_url=eval_shards_path_or_url,
         tokenizer=tokenizer,
         max_seq_length=preproc_config.max_seq_length,
         num_train_examples=config.experiment.max_train_examples,
@@ -454,6 +463,7 @@ def save_model_hook(models, weights, output_dir):
         vae_checkpoint=config.model.vq_model.pretrained,
         text_encoder_checkpoint=config.model.text_encoder.pretrained,
         use_filtered_dataset=dataset_config.get("use_filtered_dataset", False),
+        use_m4_laion_text_2_image_dataset=use_m4_laion_text_2_image_dataset,
     )
 
     train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
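
Note (not part of the patch): the sketch below is a rough, hedged illustration of how the new M4 LAION pieces compose outside of Text2ImageDataset, e.g. for a local smoke test of the streaming path. The `training.data` import path, the torchvision transform, and the `openai/clip-vit-large-patch14` tokenizer are placeholder assumptions (the real run uses ImageNetTransform and the muse text encoder's tokenizer), and it requires AWS credentials with read access to the m4-datasets bucket.

    # Illustrative only -- placeholder preprocessing and import path, see note above.
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from transformers import AutoTokenizer

    from training.data import (  # assumed import path for training/data.py
        M4LaionDatasetBatched,
        M4LaionDatasetProcessingPipeline,
        M4LaionDatasetShuffle,
        M4LaionDatasetStream,
        M4LaionDatasetWithEpoch,
        M4LaionResampledShards,
        M4LaionShards,
    )

    # Placeholder stand-ins for ImageNetTransform and the text encoder tokenizer.
    transform = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()]
    )
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # placeholder tokenizer


    def tokenize(text):
        return tokenizer(
            text, max_length=77, padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids[0]


    # Same composition Text2ImageDataset builds internally when
    # use_m4_laion_text_2_image_dataset is set.
    dataset = M4LaionShards(type="train")
    dataset = M4LaionResampledShards(dataset)
    dataset = M4LaionDatasetStream(dataset)
    dataset = M4LaionDatasetShuffle(dataset, bufsize=1000)
    dataset = M4LaionDatasetProcessingPipeline(dataset, transform, tokenize)
    dataset = M4LaionDatasetBatched(dataset, batchsize=8, partial=False)
    dataset = M4LaionDatasetWithEpoch(dataset, num_epochs=10)

    # batch_size=None because batching already happens inside the pipeline.
    dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=2, pin_memory=True)

    for images, input_ids in dataloader:
        print(images.shape, input_ids.shape)
        break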