diff --git a/examples/stardoc_config.yaml b/examples/stardoc_config.yaml
new file mode 100644
index 000000000..0c170dea5
--- /dev/null
+++ b/examples/stardoc_config.yaml
@@ -0,0 +1,78 @@
+training:
+ train_iters: 1000
+ num_workers: 2
+ logs:
+ interval: 10
+ checkpoint:
+ interval: 1000
+ keep: 10
+ export:
+ interval: 1000
+ validation:
+ iterations: null
+ test_iters: 0
+pretrained:
+ path: ".../stardoc_checkpoint"
+ format: huggingface
+batch:
+ sequence_length: 8192
+ micro_batch_size: 1
+ batch_size: 8
+data:
+ split: [0.9, 0.1, 0]
+ path: ".../stardoc_data_config.json"
+ tokenizer:
+    format: TokenizerFromFile
+ path: ".../Mistral-7B-v0.3/tokenizer.json"
+ special_tokens:
+      eos_token: "</s>"
+      bos_token: "<s>"
+ pad_token: "[control_8]"
+ image_placeholder_token: "[control_9]"
+optimizer:
+ learning_rate:
+ base: 1.0e-05
+ decay_style: constant
+ warmup_iterations: 0
+ weight_decay: 0.1
+ beta_1: 0.9
+ beta_2: 0.95
+model:
+ base_model:
+ transformer:
+ normalization:
+ type: rms_norm
+ epsilon: 1.0e-05
+ num_layers: 32
+ hidden_size: 4096
+ ffn_hidden_size: 14336
+ num_attention_heads: 32
+ head_groups: 8
+ add_linear_biases: false
+ use_rotary_embeddings: true
+ gated: true
+ activation_type: silu
+ triton_rotary: true
+ kv_channels: 128
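+      # rotary_embedding_scale is -ln(rope_theta); -9.2103... corresponds to rope_theta = 10000
+      # (see the MappedConfigParamConverter mapping in the StarDoc conversion code below).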
+ rotary_embedding_scale: -9.210340371976184
+ window_size: 4096
+ init_method_std: 0.009021
+ attention_dropout: 0.0
+ hidden_dropout: 0.0
+ multimodal_model:
+ image_encoder_hidden_size: 1024
+ num_image_tokens: 256
+ max_num_images: 10
+ image_resolution: 448
+ image_encoder_type: clip
+ vocab_size: 32000
+ tie_word_embeddings: false
+ multi_stage:
+ zero_stage: 3
+ distributed:
+ training_dtype: bf16
+ distributed_timeout: 3600
+ seed: 984059
+
+run:
+ experiment_dir: stardoc
\ No newline at end of file
diff --git a/examples/train_stardoc.sh b/examples/train_stardoc.sh
new file mode 100755
index 000000000..5607f1537
--- /dev/null
+++ b/examples/train_stardoc.sh
@@ -0,0 +1,124 @@
+# Required or optional environment variables
+# export PROJECT_DIR=
+# export PROJECT_NAME=
+# export PROJECT_VERSION=
+# export DATA_PATH=
+# export PRETRAINED_STARDOC_PATH=
+# export TOKENIZER_PATH=
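+# export WANDB_ENTITY_NAME=
+# export HOST_NODE_ADDR=   # required by run_c10d for the multi-node rendezvous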
+
+# export HF_HOME=
+# export HF_TOKEN=
+
+export CMD_ARGS="fast-llm train stardoc"
+
+export MODEL_ARGS_PRETRAINED="\
+--pretrained_checkpoint_type=huggingface \
+--pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \
+--use_pretrained_config=1 \
+"
+
+export MODEL_ARGS_ARCHITECTURE="\
+--num_layers=32 \
+--hidden_size=4096 \
+--vocab_size=32000 \
+--num_attention_heads=32 \
+--head_groups=8 \
+--add_linear_biases=0 \
+--ffn_hidden_size=14336 \
+--kv_channels=128 \
+--use_rotary_embeddings=1 \
+--rotary_embedding_scale=-9.210340371976184 \
+--gated=1 \
+--activation_type=silu \
+--normalization_type=rms_norm \
+--tie_word_embeddings=0 \
+--window_size=8192 \
+"
+
+export MULTIMODAL_ARGS="\
+--image_encoder_hidden_size=1024 \
+--num_image_tokens=256 \
+--max_num_images=10 \
+--image_encoder_type=clip \
+"
+
+export DATA_ARGS="\
+--split=9998,2,0 \
+--dataset_type=stardoc \
+--dataset_source=multimodal \
+--data_path=$DATA_PATH \
+--tokenizer_type=PreTrainedTokenizer \
+--tokenizer_path=$TOKENIZER_PATH \
+"
+
+export TRAINING_ARGS="\
+--batch_size=8 \
+--sequence_length=8192 \
+--train_iters=500000 \
+--weight_decay=0.1 \
+--adam_beta1=0.9 \
+--adam_beta2=0.95 \
+--clip_grad=1.0 \
+--lr=0.0001 \
+--lr_warmup_iters=1000 \
+--lr_decay_style=cosine \
+--lr_decay_iters=500000 \
+--min_lr=0.000003 \
+"
+
+export PERFORMANCE_ARGS="\
+--micro_batch_size=1 \
+--training_dtype=bf16 \
+--zero_stage=3 \
+--num_workers=8 \
+"
+
+export MONITORING_ARGS="\
+--validation_iters=25 \
+--validation_interval=1000 \
+--log_interval=10 \
+--log_offset=0 \
+--checkpoint_interval=500 \
+--max_checkpoints=5 \
+--export_interval=25000 \
+--wandb_status_interval=25000 \
+--wandb_entity_name=$WANDB_ENTITY_NAME \
+--wandb_project_name=$PROJECT_NAME \
+--wandb_group_name=$PROJECT_VERSION \
+"
+
+export ALL_ARGS="\
+$CMD_ARGS \
+$MODEL_ARGS_PRETRAINED \
+$MODEL_ARGS_ARCHITECTURE \
+$MULTIMODAL_ARGS \
+$DATA_ARGS \
+$TRAINING_ARGS \
+$PERFORMANCE_ARGS \
+$MONITORING_ARGS \
+"
+
+export PROFILE_ARGS="\
+--profile_cuda=1 \
+--profile_skip=10 \
+--profile_wait=95 \
+--profile_warmup=2 \
+--profile_cycles=3 \
+--profile_export=1 \
+"
+
+run_local () { # run(name, num_gpus, base_cmd)
+ echo $1 $2 $3
+ export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python"
+  $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/${PROJECT_NAME}_${PROJECT_VERSION}/$1
+}
+
+run_c10d () { # run(name, num_nodes, base_cmd)
+ echo $1 $2 $3
+ export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR"
+  $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/${PROJECT_NAME}_${PROJECT_VERSION}/$1
+}
+
+run_local stardoc_example 8 "$ALL_ARGS"
+# run_c10d stardoc_example 16 "$ALL_ARGS"
+# run_c10d stardoc_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50"
diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py
index f973a32c6..62c741f26 100644
--- a/fast_llm/data/config.py
+++ b/fast_llm/data/config.py
@@ -23,6 +23,7 @@ class DatasetSource(str, enum.Enum):
file = "file"
sample = "sample"
random = "random"
+ multimodal = "multimodal"
class MultiprocessingContext(str, enum.Enum):
@@ -99,10 +100,48 @@ def _validate(self):
Assert.in_range_incl(self.rate, 0, 1)
-EOD = "<|endoftext|>"
TokenizerFromFile = "TokenizerFromFile"
+@config_class()
+class SpecialTokensConfig(Config):
+ """
+    Defines special tokens such as BOS, EOS, PAD, and the image placeholder token.
+ """
+
+ bos_token: str | None = Field(
+ default=None,
+ desc="Beginning of sequence token",
+ hint=FieldHint.core,
+ )
+ eos_token: str | None = Field(
+ default="<|endoftext|>",
+ desc="End of sequence token",
+ hint=FieldHint.core,
+ )
+ pad_token: str | None = Field(
+ default=None,
+ desc="Pad token",
+ hint=FieldHint.core,
+ )
+ image_placeholder_token: str | None = Field(
+ default=None,
+ desc="Placeholder token for images. Used only in multi-modal models",
+ hint=FieldHint.core,
+ )
+
+ def get_special_tokens(self):
+ special_tokens = [
+ self.bos_token,
+ self.eos_token,
+ self.pad_token,
+ self.image_placeholder_token,
+ ]
+
+ # Only return special tokens that are set
+ return [token for token in special_tokens if token is not None]
+
+
@config_class()
class TokenizerConfig(Config):
"""
@@ -114,13 +153,17 @@ class TokenizerConfig(Config):
default="TokenizerFromFile",
desc="Unused.",
hint=FieldHint.deprecated,
- valid=check_field(Assert.eq, TokenizerFromFile),
)
path: str | None = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
)
+ special_tokens: SpecialTokensConfig = Field(
+ default_factory=SpecialTokensConfig,
+ desc="Define special tokens.",
+ hint=FieldHint.core,
+ )
@config_class
diff --git a/fast_llm/data/gpt/data.py b/fast_llm/data/gpt/data.py
index 4c3ffcbc7..1c183a977 100644
--- a/fast_llm/data/gpt/data.py
+++ b/fast_llm/data/gpt/data.py
@@ -112,7 +112,7 @@ def __init__(
for name, prefix in zip(dataset_names, dataset_prefixes)
}
self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)}
-
+
def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
"""
Load the datasets, and prepare or load the samplings.
diff --git a/fast_llm/data/stardoc_data_utils/constants.py b/fast_llm/data/stardoc_data_utils/constants.py
new file mode 100644
index 000000000..db134ee6a
--- /dev/null
+++ b/fast_llm/data/stardoc_data_utils/constants.py
@@ -0,0 +1,9 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "./demo_logs"
+
+# Model Constants
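+# IGNORE_INDEX marks label positions that are excluded from the loss, IMAGE_TOKEN_INDEX is the
+# sentinel id spliced into input_ids by tokenizer_image_token() for each image placeholder,
+# and DEFAULT_IMAGE_TOKEN is the placeholder string used in prompts.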
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<|image|>"
\ No newline at end of file
diff --git a/fast_llm/data/stardoc_data_utils/conversation.py b/fast_llm/data/stardoc_data_utils/conversation.py
new file mode 100644
index 000000000..c5a9ef4b2
--- /dev/null
+++ b/fast_llm/data/stardoc_data_utils/conversation.py
@@ -0,0 +1,303 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+from fast_llm.data.stardoc_data_utils.constants import DEFAULT_IMAGE_TOKEN
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ TWO_NO_SYS = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+ # init_msg = init_msg[0].replace("", "").strip()
+ # if 'mmtag' in self.version:
+ # messages[0] = (init_role, init_msg)
+ # messages.insert(0, (self.roles[0], ""))
+ # messages.insert(1, (self.roles[1], "Received."))
+ # else:
+ # messages[0] = (init_role, "\n" + init_msg)
+ init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+ messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
+ seps = [self.sep, self.sep2]
+ ret = ""
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize336":
+ image = image.resize((336, 336))
+ elif image_process_mode == "Resize":
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if longest_edge != max(image.size):
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+ msg = img_str + msg.replace('<|image|>', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+)
+
+conv_mplug_owl2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO_NO_SYS,
+ sep=" ",
+    sep2="</s>",
+)
+
+# default_conversation = conv_vicuna_v1
+default_conversation = conv_mplug_owl2
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "mplug_owl2": conv_mplug_owl2,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
\ No newline at end of file
diff --git a/fast_llm/data/stardoc_data_utils/docowl_processor.py b/fast_llm/data/stardoc_data_utils/docowl_processor.py
new file mode 100644
index 000000000..dfc71b99e
--- /dev/null
+++ b/fast_llm/data/stardoc_data_utils/docowl_processor.py
@@ -0,0 +1,218 @@
+from einops import rearrange, repeat
+import torch
+from torchvision import transforms
+from PIL import Image, ImageFile
+import random
+from torchvision.ops.boxes import box_area
+
+from torchvision.transforms.transforms import InterpolationMode
+from torchvision.transforms import functional as F
+import numpy as np
+#from icecream import ic
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+ImageFile.MAX_IMAGE_PIXELS = None
+Image.MAX_IMAGE_PIXELS = None
+
+def box_iou(boxes1, area1, boxes2, eps=1e-5):
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / (union+eps)
+ return iou, union
+
+def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
+ # anchors x1 y1 x2 y2
+
+ # image_size: (h, w)
+ # xyxy
+ input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
+
+ boxes1 = anchors
+ boxes2 = input_image_bbox
+ boxes3 = anchors.clone()
+ # y2
+    boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2]  # used to compute a resolution-independent IoU
+
+ area1 = anchors_areas
+
+ iou, _ = box_iou(boxes1, area1, boxes2)
+ iou = iou.squeeze(1)
+ shape_iou, _ = box_iou(boxes1, area1, boxes3)
+ shape_iou = shape_iou.diag()
+    # Prefer the anchor with the closest shape, then the closest resolution
+ index = torch.argmax(shape_iou*100+iou,dim=0)
+ return index
+
+class AnchorResize(torch.nn.Module):
+
+ def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None):
+ super().__init__()
+ # xyxy
+ self.anchors = torch.tensor(
+ [[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
+ for _ in anchors], requires_grad=False
+ )
+
+ self.anchor_areas = box_area(self.anchors)
+
+ self.interpolation = interpolation
+ self.antialias = antialias
+
+ def forward(self, img, skip_resize=False):
+ """
+ Args:
+ img (PIL Image or Tensor): Image to be scaled.
+
+ Returns:
+ PIL Image or Tensor: Rescaled image.
+ """
+ selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
+ target_size = self.anchors[selected_anchor][2:].tolist() # w,h
+ if skip_resize:
+ # for debug
+ return selected_anchor
+ return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
+
+ def __repr__(self) -> str:
+ detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
+ return f"{self.__class__.__name__}{detail}"
+
+grid_dict = {
+ 'grid_1':[
+ (1,1)],
+ 'grid_4':[
+ (1,1),
+ (1,2),(2,1),
+ (1,3),(3,1),
+ (2,2),(1,4),(4,1)],
+ 'grid_9':[
+ (1,1),
+ (1,2),(2,1),
+ (1,3),(3,1),
+ (2,2),(1,4),(4,1),
+ (1,5),(5,1),
+ (1,6),(6,1),(2,3),(3,2),
+ (1,7),(7,1),
+ (4,2),(2,4),(1,8),(8,1),
+ (3,3),(1,9),(9,1)],
+ 'grid_3x3':[
+ (3,3)],
+ 'grid_20':[
+ (1, 1),
+ (1, 2), (2, 1),
+ (1, 3), (3, 1), (1, 4), (2, 2), (4, 1),
+ (1, 5), (5, 1),
+ (1, 6), (2, 3), (3, 2), (6, 1),
+ (1, 7), (7, 1),
+ (1, 8), (2, 4), (4, 2), (8, 1),
+ (1, 9), (3, 3), (9, 1),
+ (1, 10), (2, 5), (5, 2), (10, 1),
+ (1, 11), (11, 1),
+ (2, 6), (3, 4), (4, 3), (6, 2),
+ (2, 7), (7, 2),
+ (3, 5), (5, 3),
+ (2, 8), (4, 4), (8, 2),
+ (2, 9), (3, 6), (6, 3), (9, 2),
+ (2, 10), (4, 5), (5, 4), (10, 2)]
+}
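+# Each entry lists candidate cropping grids, interpreted as (num_h, num_w) when building
+# patch positions below. AnchorResize picks the grid that best matches the input image and
+# DocProcessor cuts the resized image into num_h*num_w crops (plus an optional global image).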
+
+class DocProcessor():
+    def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False, media_token="<|image|>"):
+ self.add_global_img = add_global_img
+ self.add_textual_crop_indicator = add_textual_crop_indicator
+ self.media_token= media_token
+ # h,w
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ self.image_size = image_size
+ # h,w
+ anchors = grid_dict[anchors]
+ self.anchors = [tuple(_) for _ in anchors]
+ self.anchor_max = max([max(_) for _ in self.anchors])
+ # xywh -> xyxy
+ self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC)
+ self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
+ self.image_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ def _process_image(self, images):
+ new_images = []
+ new_patch_position = []
+ num_image_mult = []
+ for image in images:
+ if self.add_global_img:
+ nocut_image = self.image_transform(self.old_resizer(image)).unsqueeze(0)
+
+ image, selected_anchor = self.resizer(image)
+ image_input = self.image_transform(image) # h,w,3 -> 3,h,w
+ # rearrange(x,'B C (n1 h) (n2 w) -> (B n1 n2) C h w', n1=self.down_sample[0], n2=self.down_sample[1])
+ image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
+
+ if self.add_global_img:
+ image_input = torch.cat([nocut_image, image_input], dim=0)
+
+ anchor = self.anchors[selected_anchor] # w,h
+ patch_position = torch.cat([
+ repeat(torch.arange(anchor[0]), 'num_h -> num_h num_w 1', num_w=anchor[1]),
+ repeat(torch.arange(anchor[1]), 'num_w -> num_h num_w 1', num_h=anchor[0])],dim=2)
+ patch_position = rearrange(patch_position, 'num_h num_w p-> (num_h num_w) p', p=2) # num_patch, (ph,pw)
+
+ if self.add_global_img:
+ patch_position = torch.cat([torch.ones(1,2).long()*self.anchor_max, patch_position], dim=0)
+
+ new_images.append(image_input)
+ new_patch_position.append(patch_position)
+ num_image_mult.append(patch_position.shape[0])
+
+ new_images = torch.cat(new_images,dim=0)
+ new_patch_position = torch.cat(new_patch_position, dim=0)
+ return new_images, new_patch_position, num_image_mult
+
+ def __call__(self, images=None, query=None):
+ assert images is not None
+
+ if not isinstance(images, list):
+ images = [images]
+ image_pils = []
+ for image in images:
+ if isinstance(image, str):
+ image = Image.open(image).convert('RGB')
+ else:
+ image = image.convert('RGB')
+ # ic(image.size)
+ image_pils.append(image)
+
+ image_data, patch_position, num_image_mult = self._process_image(image_pils)
+
+ assert self.media_token in query
+ text_list = query.split(self.media_token)
+ text = text_list[0]
+ image_token_ptr = 0
+ for next_text in text_list[1:]:
+ if self.add_textual_crop_indicator:
+                # generate image placeholders with interleaved textual crop indicator
+                # e.g. <global_img><|image|><crop_img_row0_col0><|image|><crop_img_row0_col1><|image|>...
+ for patch_pos in patch_position.tolist():
+ # global non-crop image
+ if patch_pos[0] == self.anchor_max and patch_pos[1] == self.anchor_max:
+                        text += '<global_img>' + self.media_token
+ else:
+ row_col = 'row'+str(patch_pos[0])+'_col'+str(patch_pos[1])
+                        text += '<crop_img_'+row_col+'>' + self.media_token
+ else:
+                # generate successive image placeholders for an image, 1 crop img == 1 <|image|>
+ text += self.media_token*num_image_mult[image_token_ptr]
+ text += next_text
+ image_token_ptr += 1
+
+ return image_data, patch_position, text
\ No newline at end of file
diff --git a/fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py b/fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py
new file mode 100644
index 000000000..88c6b1023
--- /dev/null
+++ b/fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py
@@ -0,0 +1,110 @@
+import torch
+import transformers
+from fast_llm.data.stardoc_data_utils import conversation as conversation_lib
+from fast_llm.data.stardoc_data_utils.mm_utils import tokenizer_image_token
+from fast_llm.data.stardoc_data_utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from typing import Dict
+
+
+def docowl_text_preprocess_v1(
+ source,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ split: str = "train",
+) -> Dict:
+ """
+ source: list of {'role':'user'/'assistant', 'content':xxxx}
+ """
+ conv = conversation_lib.conv_mplug_owl2.copy()
+ # conv.roles: ("USER", "ASSISTANT")
+ roles = {"user": conv.roles[0], "assistant": conv.roles[1]}
+
+ if split == "train" or split == "val" or split == "test":
+
+ # Apply prompt templates
+ conversations = []
+
+ # Skip the first one if it is not from human
+ if roles[source[0]["role"]] != conv.roles[0]:
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["role"]]
+ assert role == conv.roles[j % 2]
+ conv.append_message(role, sentence["content"])
+
+        # conv.get_prompt(): USER: {content} ASSISTANT: {content}</s>USER: {content} ASSISTANT: {content}</s>...
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer.tokenize(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": " # ' ASSISTANT: '
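+        # Everything up to and including ' ASSISTANT: ' in each round is masked with IGNORE_INDEX,
+        # so the loss is only computed on the assistant replies.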
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+            rounds = conversation.split(conv.sep2)  # split rounds by the sep2 token ('</s>')
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep) # split each round by ' ASSISTANT: '
+ if len(parts) != 2:
+ break
+ parts[0] += sep # input query, ignore for loss
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.max_seq_length: # ignore padding
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+ else:
+ text = source[0]["content"]
+ roles = conv.roles # ("USER", "ASSISTANT")
+ conv.append_message(conv.roles[0], text)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
+ stop_str = conv.sep2
+ keywords = [stop_str]
+ return dict(
+ input_ids=input_ids,
+ labels=input_ids,
+ stop_str=stop_str,
+ keywords=keywords,
+ )
\ No newline at end of file
diff --git a/fast_llm/data/stardoc_data_utils/mm_utils.py b/fast_llm/data/stardoc_data_utils/mm_utils.py
new file mode 100644
index 000000000..01cb71053
--- /dev/null
+++ b/fast_llm/data/stardoc_data_utils/mm_utils.py
@@ -0,0 +1,111 @@
+from PIL import Image
+from io import BytesIO
+import base64
+
+import torch
+from transformers import StoppingCriteria
+from fast_llm.data.stardoc_data_utils.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg=None):
+ if model_cfg is not None:
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ else:
+ image_aspect_ratio = 'resize'
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ elif image_aspect_ratio == 'resize':
+ for image in images:
+ max_edge = max(image.size)
+ image = image.resize((max_edge, max_edge))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
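+    # Example (assuming DEFAULT_IMAGE_TOKEN == "<|image|>" and a tokenizer that prefixes BOS):
+    #   "Describe <|image|> please" -> [bos_id, ids("Describe "), IMAGE_TOKEN_INDEX, ids(" please")]
+    # i.e. each image placeholder becomes a single IMAGE_TOKEN_INDEX (-200) entry in input_ids.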
+ prompt_chunks = [tokenizer.tokenize(chunk, max_length=tokenizer.max_seq_length, truncation=True) if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == 'pt':
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
+ return True
+ outputs = self.tokenizer.detokenize_batch(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
\ No newline at end of file
diff --git a/fast_llm/data/stardoc_data_utils/utils.py b/fast_llm/data/stardoc_data_utils/utils.py
new file mode 100644
index 000000000..15a9a46d2
--- /dev/null
+++ b/fast_llm/data/stardoc_data_utils/utils.py
@@ -0,0 +1,33 @@
+from PIL import Image
+import io
+
+def convert_queries_and_annotations_to_messages(queries, annotations):
+ messages = []
+ # Add each query and annotation as a user-assistant pair
+ for i, (q, a) in enumerate(zip(queries, annotations)):
+ if i == 0:
+ # Prepend "<|image|>" to the first query
+ q = f"<|image|>{q}"
+ messages.append({"role": "user", "content": q})
+ messages.append({"role": "assistant", "content": a})
+ return messages
+
+def image_loading_function(images):
+ """
+    Load images given as file paths, raw bytes, or PIL images and return a list of PIL images.
+ """
+ assert images is not None
+ if not isinstance(images, list):
+ images = [images]
+ image_pils = []
+ for image in images:
+ if isinstance(image, bytes):
+ image = Image.open(io.BytesIO(image))
+ elif isinstance(image, str):
+ image = Image.open(image)
+ elif isinstance(image, Image.Image):
+ pass
+ else:
+ raise ValueError(f"Unsupported image type: {type(image)}")
+ image_pils.append(image)
+ return image_pils
\ No newline at end of file
diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py
index 2061d6b6f..811e2d173 100644
--- a/fast_llm/data/tokenizer.py
+++ b/fast_llm/data/tokenizer.py
@@ -1,6 +1,6 @@
from transformers import PreTrainedTokenizerFast
-from fast_llm.data.config import EOD, TokenizerConfig
+from fast_llm.data.config import TokenizerConfig
from fast_llm.engine.config_utils.run import log_main_rank
@@ -9,15 +9,17 @@ class Tokenizer:
A wrapper around Huggingface (transformers) tokenizer.
"""
- def __init__(self, config: TokenizerConfig):
+ def __init__(self, config: TokenizerConfig, max_sequence_length=None):
log_main_rank(f"> loading tokenizer from {config.path} ...")
- special_tokens = [EOD]
- self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.path, errors="replace", max_len=None)
+ self._config = config
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.path, errors="replace", max_len=max_sequence_length)
+ special_tokens = config.special_tokens.get_special_tokens()
self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
- self.eod_id = self.tokenizer.vocab[EOD]
+
# Token->id mapping for additional special-tokens
self.special_tokens = {tok: self.tokenizer.vocab[tok] for tok in special_tokens}
self._inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
+ self._max_sequence_length = max_sequence_length
@property
def vocab_size(self):
@@ -31,6 +33,42 @@ def vocab(self):
def inv_vocab(self):
return self._inv_vocab
+ @property
+ def max_sequence_length(self):
+ return self._max_sequence_length
+
+ @property
+ def bos_token_id(self):
+ bos_token = self._config.special_tokens.bos_token
+ if bos_token is not None:
+ return self.special_tokens[bos_token]
+ else:
+ raise ValueError("BOS token not set in tokenizer")
+
+ @property
+ def eos_token_id(self):
+ eos_token = self._config.special_tokens.eos_token
+ if eos_token is not None:
+ return self.special_tokens[eos_token]
+ else:
+ raise ValueError("EOS token not set in tokenizer")
+
+ @property
+ def pad_token_id(self):
+ pad_token = self._config.special_tokens.pad_token
+ if pad_token is not None:
+ return self.special_tokens[pad_token]
+ else:
+ raise ValueError("PAD token not set in tokenizer")
+
+ @property
+ def image_placeholder_token_id(self):
+ image_placeholder_token = self._config.special_tokens.image_placeholder_token
+ if image_placeholder_token is not None:
+ return self.special_tokens[image_placeholder_token]
+ else:
+ raise ValueError("Image placeholder token not set in tokenizer")
+
def tokenize(self, text):
return self.tokenizer.encode(text)
@@ -39,4 +77,4 @@ def detokenize(self, token_ids):
@property
def eod(self):
- return self.eod_id
+ return self.eos_token_id
\ No newline at end of file
diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py
index c0f899144..b16e22a29 100644
--- a/fast_llm/engine/multi_stage/stage_base.py
+++ b/fast_llm/engine/multi_stage/stage_base.py
@@ -239,6 +239,8 @@ def setup(
def _replace(module: torch.nn.Module):
nonlocal i
for key in module._parameters: # noqa
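+                # Optional parameters (e.g. disabled biases in imported modules such as the
+                # open_clip image encoder) may be registered as None; skip them.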
+ if module._parameters[key] is None:
+ continue
meta = typing.cast(ParameterMeta, module._parameters[key]) # noqa
module._parameters[key] = self._parameter_buffers[self._parameter_index[meta.tensor_name]] # noqa
i += 1
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index b88f45540..b48d72ba8 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -28,6 +28,7 @@ class LanguageModelKwargs:
# TODO: These are generic
labels = "labels"
phase = "phase"
+ tokens = "tokens"
@config_class()
diff --git a/fast_llm/layers/multimodal_model/adapter.py b/fast_llm/layers/multimodal_model/adapter.py
new file mode 100644
index 000000000..b70d21391
--- /dev/null
+++ b/fast_llm/layers/multimodal_model/adapter.py
@@ -0,0 +1,58 @@
+import logging
+import copy
+import torch
+from torch import nn
+
+from fast_llm.layers.common.linear import Linear
+from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
+from fast_llm.layers.multimodal_model.config import MultimodalModelBaseConfig, MultimodalModelDimNames, MultimodalModelKwargs
+from fast_llm.layers.language_model.config import LanguageModelBaseConfig
+from fast_llm.tensor import ParameterMeta, TensorMeta, TensorSpace, TensorDim, init_normal_
+
+logger = logging.getLogger(__name__)
+
+class Adapter(torch.nn.Module):
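+    """
+    Projects image-encoder features (image_encoder_hidden_size) to the language-model hidden size
+    so they can be fused with the token embeddings by MultiModalLanguageModelEmbedding.
+    """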
+
+ # Ensure the layer is on its own stage.
+ layer_count: float = 1000.0
+
+ def __init__(
+ self,
+ config: LanguageModelBaseConfig,
+ tensor_space: TensorSpace,
+ ):
+ super(Adapter, self).__init__()
+ self._distributed_config = tensor_space.distributed_config
+ self._tensor_space = tensor_space
+ self._residual_dtype = (
+ self._distributed_config.optimization_dtype
+ if config.transformer.full_precision_residual
+ else self._distributed_config.training_dtype
+ ).torch
+
+ in_dim = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.image_encoder_hidden_size)
+ out_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
+
+ self.dropout = nn.Dropout(p=0.1)
+ self.adapter_fc = Linear(
+ in_dim,
+ out_dim,
+ bias=True,
+ weight_init_method=init_normal_(std=config.transformer.init_method_std),
+ )
+
+ def _forward(self, input_: torch.Tensor, losses: dict | None = None, metrics: dict | None = None):
+ hidden_states = self.dropout(input_)
+ out = self.adapter_fc(hidden_states)
+
+ return out.to(dtype=self._residual_dtype)
+
+ def forward(self, input_, kwargs, losses: dict | None = None, metrics: dict | None = None):
+ if isinstance(input_, TensorMeta):
+ return TensorMeta.from_dims(
+ kwargs[MultimodalModelKwargs.adapter_hidden_dims],
+ tensor_name="Adapter output",
+ dtype=self._residual_dtype,
+ )
+
+ return self._forward(input_)
\ No newline at end of file
diff --git a/fast_llm/layers/multimodal_model/config.py b/fast_llm/layers/multimodal_model/config.py
new file mode 100644
index 000000000..be689fab4
--- /dev/null
+++ b/fast_llm/layers/multimodal_model/config.py
@@ -0,0 +1,76 @@
+import enum
+from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
+from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
+from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
+
+from fast_llm.utils import Assert
+
+class MultimodalModelDimNames:
+ # Image encoder dimensions
+ max_num_images = "max_num_images"
+ image_pixel_count = "image_pixel_count"
+ num_image_tokens = "num_image_tokens"
+ image_encoder_hidden_size = "image_encoder_hidden_size"
+
+class MultimodalModelKwargs:
+ image_encoder_hidden_dims = "image_encoder_hidden_dims"
+ adapter_hidden_dims = "adapter_hidden_dims"
+
+class ImageEncoderType(str, enum.Enum):
+ clip = "clip"
+ docowl = "docowl"
+
+@config_class()
+class MultimodalModelArchitectureConfig(BaseModelArchitectureConfig):
+ _abstract = False
+
+ image_encoder_hidden_size: int = Field(
+ default=1024,
+ desc="Hidden size of image encoder.",
+ hint=FieldHint.core,
+ valid=check_field(Assert.gt, 0),
+ )
+ num_image_tokens: int = Field(
+ default=256,
+ desc="Number of image tokens.",
+ hint=FieldHint.core,
+ valid=check_field(Assert.gt, 0),
+ )
+ max_num_images: int = Field(
+ default=10,
+ desc="Max. number of images in a sample. We pad to ensure shapes are consistent.",
+ hint=FieldHint.core,
+ valid=check_field(Assert.gt, 0),
+ )
+ image_resolution: int = Field(
+ default=448,
+        desc="Side resolution of the input images, in pixels.",
+ hint=FieldHint.core,
+ valid=check_field(Assert.gt, 0),
+ )
+
+ def _validate(self):
+ super()._validate()
+
+ def setup_tensor_space(self, tensor_space: TensorSpace):
+ tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.max_num_images, self.max_num_images))
+ tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.num_image_tokens, self.num_image_tokens))
+ tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.image_pixel_count, self.image_resolution * self.image_resolution))
+ tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.image_encoder_hidden_size, self.image_encoder_hidden_size))
+
+
+@config_class()
+class MultimodalModelBaseConfig(MultimodalModelArchitectureConfig, BaseModelConfig):
+ """
+    Configuration for the image encoder and adapter components of the multi-modal model.
+ """
+ _abstract = False
+
+ image_encoder_type: ImageEncoderType = Field(
+ default=ImageEncoderType.clip,
+ desc="Type of image encoder",
+ hint=FieldHint.feature,
+ )
+
+ def _validate(self):
+ super()._validate()
\ No newline at end of file
diff --git a/fast_llm/layers/multimodal_model/image_encoder.py b/fast_llm/layers/multimodal_model/image_encoder.py
new file mode 100644
index 000000000..6e5b1191d
--- /dev/null
+++ b/fast_llm/layers/multimodal_model/image_encoder.py
@@ -0,0 +1,100 @@
+import logging
+import copy
+import torch
+
+from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs
+from fast_llm.layers.multimodal_model.config import MultimodalModelKwargs, MultimodalModelBaseConfig
+from fast_llm.engine.config_utils.tensor_space import TensorSpace
+from fast_llm.tensor import ParameterMeta, TensorMeta, TensorDim, init_normal_
+
+logger = logging.getLogger(__name__)
+
+class ImageEncoder(torch.nn.Module):
+
+ # Ensure the layer is on its own stage.
+ layer_count: float = 1000.0
+
+ def __init__(
+ self,
+ config: LanguageModelBaseConfig,
+ tensor_space: TensorSpace,
+ ):
+ super(ImageEncoder, self).__init__()
+ self._distributed_config = tensor_space.distributed_config
+ self._tensor_space = tensor_space
+ self._residual_dtype = (
+ self._distributed_config.optimization_dtype
+ if config.transformer.full_precision_residual
+ else self._distributed_config.training_dtype
+ ).torch
+ self.image_encoder_type = config.multimodal_model.image_encoder_type
+
+ if self.image_encoder_type.lower() == "clip":
+ import open_clip
+
+ model, _, _ = open_clip.create_model_and_transforms(
+ "ViT-L-14", pretrained="laion2b_s32b_b82k"
+ )
+
+ self.visual_encoder = model.visual
+ self.visual_encoder.output_tokens = True
+ self.ln_vision = copy.deepcopy(self.visual_encoder.ln_post)
+ else:
+ logger.error(f'Unknown image encoder specified: {self.image_encoder_type.lower()}')
+
+ # Replace all parameters with Parameter(MetaParameter(...))
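+        # Wrapping them as Fast-LLM ParameterMeta (on the meta device) lets the multi-stage
+        # setup allocate and manage these weights like any native Fast-LLM layer.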
+ with torch.no_grad():
+ for name, param in self.named_parameters():
+ module = self
+ name_parts = name.split('.')
+ # We have to traverse to the correct parent module and change the parameter there
+ for part in name_parts[:-1]:
+ module = getattr(module, part)
+
+                # Replace the parameter with a Fast-LLM meta parameter
+ setattr(module, name_parts[-1], self.get_fastllm_parameter(name, param))
+
+ def get_fastllm_parameter(self, param_name, param):
+ param_dims = tuple([TensorDim(name=f'{param_name}_{idx}', global_size=x, parallel_dim=None) for idx, x in enumerate(param.shape)])
+ return ParameterMeta(param.to("meta"), tensor_name=param_name, dims=param_dims, init_method=init_normal_(std=0.02), allow_no_grad=True)
+
+ def _forward(self, input_: tuple[torch.Tensor], losses: dict | None = None, metrics: dict | None = None):
+ if not self.image_encoder_type.lower() == "clip":
+            raise ValueError('clip is the only image encoder type currently supported')
+
+ # TODO: Remove padding images
+ # _bsz_im, num_img, ch, im_width, im_height = image_input
+ # image_input = image_input.view(_bsz_im * num_img, *image_input.shape[2:])
+ # num_values_per_image = image_input.shape[1:].numel()
+ # real_images_inds = (image_input == 0.0).sum(dim=(-1, -2, -3)) != num_values_per_image
+ # image_input = image_input[real_images_inds].contiguous()
+
+ # (bsz, num_img, ch, im_h, im_w) -> (bsz*num_img, ch, im_h, im_w)
+
+
+ # Convert the input images tensor to residual dtype. This is torch.float32 by default
+ input_ = input_.to(self._residual_dtype)
+
+ _bsz_im, num_img, ch, im_width, im_height = input_.shape
+ input_ = input_.view(_bsz_im * num_img, *input_.shape[2:]).contiguous()
+
+ out = self.visual_encoder(input_)[1]
+ out = self.ln_vision(out)
+
+ # (bsz*num_img, im_tokens, h) -> (bsz, num_img, im_tokens, h)
+ out = out.view(_bsz_im, num_img, *out.shape[1:]).contiguous()
+
+ return out.to(dtype=self._residual_dtype)
+
+ def forward(self, input_, kwargs, losses: dict | None = None, metrics: dict | None = None):
+ if input_ is None:
+ raise ValueError(f'You must define a max_num_images > 0 if image_encoder is enabled')
+
+ if isinstance(input_, TensorMeta):
+ return TensorMeta.from_dims(
+ kwargs[MultimodalModelKwargs.image_encoder_hidden_dims],
+ tensor_name="Image encoder output",
+ dtype=self._residual_dtype,
+ )
+
+ return self._forward(input_)
\ No newline at end of file
diff --git a/fast_llm/layers/multimodal_model/multimodal_language_embedding.py b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py
new file mode 100644
index 000000000..9d5275715
--- /dev/null
+++ b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py
@@ -0,0 +1,98 @@
+import torch
+
+from fast_llm.core.distributed import set_generator
+from fast_llm.core.ops import reduce_forward, split
+from fast_llm.engine.base_model.base_model import Layer
+from fast_llm.engine.config_utils.tensor_space import TensorSpace
+from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs
+from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
+from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
+from fast_llm.utils import Assert
+
+WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight"
+
+
+class MultiModalLanguageModelEmbedding(Layer):
+ """
+    An embedding layer that fuses multi-modal (image) features into the language token embeddings.
+    Consists of word embeddings (tensor-parallel),
+    together with optional absolute position embeddings and dropout.
+ """
+
+ # Ensure the layer is on its own stage.
+ layer_count: float = 1000.0
+
+ def __init__(
+ self,
+ config: LanguageModelBaseConfig,
+ tensor_space: TensorSpace,
+ ):
+ super().__init__()
+ config.validate()
+ self._distributed_config = tensor_space.distributed_config
+ self._tensor_space = tensor_space
+ self._residual_dtype = (
+ self._distributed_config.optimization_dtype
+ if config.transformer.full_precision_residual
+ else self._distributed_config.training_dtype
+ ).torch
+ self._group_size = self._distributed_config.tensor_parallel
+ self._sequence_parallel = self._distributed_config.sequence_tensor_parallel
+ self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings
+ self._dropout_p = config.transformer.hidden_dropout
+ self._use_absolute_position_embeddings = config.use_absolute_position_embeddings
+
+ hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
+ vocab_dim = tensor_space.get_tensor_dim(
+ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+ )
+
+ if self._parallel_embeddings:
+ self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size
+ self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size
+
+ self.word_embeddings_weight = ParameterMeta.from_dims(
+ (vocab_dim, hidden_dim),
+ init_method=init_normal_(std=config.init_method_std_embed),
+ )
+ if self._use_absolute_position_embeddings:
+ self.position_embeddings_weight = ParameterMeta.from_dims(
+ (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim),
+ init_method=init_normal_(std=config.init_method_std_embed),
+ allow_sequence_tensor_parallel=not config.parallel_embeddings,
+ )
+
+ @torch.compile
+ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, tokens: torch.Tensor | None):
+ Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
+        assert tokens is not None
+
+ text_embeddings = torch.embedding(self.word_embeddings_weight, tokens)
+
+ bsz, num_imgs, _, hidden_size = input_.shape
+
+ # TODO: Hardcoded image token
+ image_token_mask = tokens == 11
+
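+        # input_ is the adapter output with shape (bsz, num_imgs, num_image_tokens, hidden);
+        # the number of image-placeholder positions in `tokens` must equal
+        # bsz * num_imgs * num_image_tokens for the masked assignment below to line up.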
+ embeddings = text_embeddings.clone()
+ embeddings[image_token_mask] = input_.view(-1, hidden_size)
+
+ if self._use_absolute_position_embeddings:
+ embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight)
+
+ with set_generator(
+ self._tensor_space.distributed.tp_generator
+ if self._sequence_parallel
+ else self._tensor_space.distributed.pp_generator
+ ):
+ embeddings = torch.dropout(embeddings, self._dropout_p, self.training)
+ return embeddings.to(dtype=self._residual_dtype)
+
+ def forward(self, input_, kwargs, losses: dict | None = None, metrics: dict | None = None):
+ if isinstance(input_, TensorMeta):
+ return TensorMeta.from_dims(
+ kwargs[TransformerKwargs.hidden_dims],
+ tensor_name="Embedding output",
+ dtype=self._residual_dtype,
+ )
+ return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.tokens))
\ No newline at end of file
diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py
index 6437e3b32..d6232ea6d 100644
--- a/fast_llm/models/auto.py
+++ b/fast_llm/models/auto.py
@@ -1,5 +1,6 @@
from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig
from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig
+from fast_llm.models.stardoc.config import StarDocModelConfig, StarDocTrainerConfig
from fast_llm.utils import Registry
model_registry = Registry(
@@ -8,6 +9,7 @@
model.model_name: model
for model in [
GPTModelConfig,
+ StarDocModelConfig,
CustomModelConfig,
]
},
@@ -19,7 +21,8 @@
trainer.get_field("model").type.model_name: trainer
for trainer in [
GPTTrainerConfig,
+ StarDocTrainerConfig,
CustomTrainerConfig,
]
},
-)
+)
\ No newline at end of file
diff --git a/fast_llm/models/stardoc/__init__.py b/fast_llm/models/stardoc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fast_llm/models/stardoc/config.py b/fast_llm/models/stardoc/config.py
new file mode 100644
index 000000000..390069eac
--- /dev/null
+++ b/fast_llm/models/stardoc/config.py
@@ -0,0 +1,92 @@
+import typing
+
+from fast_llm.config import Field, FieldHint, config_class
+from fast_llm.data.config import DataConfig
+from fast_llm.engine.training.config import TrainerConfig
+
+from fast_llm.models.gpt.config import (
+ GPTArchitectureConfig,
+ GPTBaseModelConfig,
+ GPTTrainerConfig,
+)
+
+from fast_llm.layers.multimodal_model.config import MultimodalModelArchitectureConfig, MultimodalModelBaseConfig
+from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig
+from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
+
+if typing.TYPE_CHECKING:
+ from fast_llm.engine.multi_stage.conversion import ModelConverter
+
+
+@config_class()
+class StarDocDataConfig(DataConfig):
+ # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything.
+ pass
+
+
+@config_class()
+class StarDocArchitectureConfig(GPTArchitectureConfig):
+ multimodal_model: MultimodalModelArchitectureConfig = Field(
+ default_factory=MultimodalModelArchitectureConfig,
+ desc="Configuration for the multimodal components (image encoder and adapter).",
+ hint=FieldHint.core,
+ )
+
+ def setup_tensor_space(self, tensor_space: TensorSpace):
+ super().setup_tensor_space(tensor_space)
+ self.multimodal_model.setup_tensor_space(tensor_space)
+
+ @classmethod
+ def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]:
+ from fast_llm.models.stardoc.conversion import AutoStarDocConverter
+
+ return AutoStarDocConverter if model_type is None else AutoStarDocConverter.converter_map[model_type]
+
+@config_class()
+class StarDocBaseModelConfig(GPTBaseModelConfig, StarDocArchitectureConfig):
+ architecture_cls = StarDocArchitectureConfig
+
+ multimodal_model: MultimodalModelBaseConfig = Field(
+ default_factory=MultimodalModelBaseConfig,
+ desc="Configuration for the multimodal components (image encoder and adapter).",
+ hint=FieldHint.core,
+ )
+
+
+
+@config_class()
+class StarDocModelConfig(FastLLMModelConfig):
+ _abstract = False
+ base_model: StarDocBaseModelConfig = Field(default_factory=StarDocBaseModelConfig)
+
+ @classmethod
+ def get_model_class(cls):
+ from fast_llm.models.stardoc.model import StarDocModel
+
+ return StarDocModel
+
+
+@config_class()
+class PretrainedStarDocModelConfig(PretrainedFastLLMModelConfig):
+ _abstract = False
+ model: StarDocModelConfig = Field(default_factory=StarDocModelConfig)
+
+
+@config_class()
+class StarDocTrainerConfig(PretrainedStarDocModelConfig, GPTTrainerConfig):
+ @classmethod
+ def get_trainer_class(cls):
+ from fast_llm.models.stardoc.trainer import StarDocTrainer
+
+ return StarDocTrainer
+
+
+class HuggingfaceModelType:
+ """
+ An enum for the huggingface models with conversion support.
+ """
+
+ starcoder2 = "starcoder2"
+ llama = "llama"
+ mistral = "mistral"
+ mixtral = "mixtral"
\ No newline at end of file
diff --git a/fast_llm/models/stardoc/conversion.py b/fast_llm/models/stardoc/conversion.py
new file mode 100644
index 000000000..e5186a34e
--- /dev/null
+++ b/fast_llm/models/stardoc/conversion.py
@@ -0,0 +1,644 @@
+import abc
+import math
+import typing
+
+import torch
+
+from fast_llm.engine.multi_stage.conversion import (
+ AutoModelConverter,
+ ConstantExportParamConverter,
+ ConstantImportParamConverter,
+ HuggingfaceModelConverter,
+ IgnoreImportParamConverter,
+ IgnoreWeightConverter,
+ MappedConfigParamConverter,
+ ParamConverter,
+ SplitWeightConverter,
+ WeightConverter,
+)
+from fast_llm.functional.config import ActivationType
+from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
+from fast_llm.layers.common.config import NormalizationType
+from fast_llm.layers.transformer.config import RoutingType
+from fast_llm.models.stardoc.config import StarDocArchitectureConfig, StarDocBaseModelConfig, HuggingfaceModelType
+from fast_llm.tensor import SafeTensorSlice
+
+if typing.TYPE_CHECKING:
+ pass
+
+
+class QueryWeightConverter(WeightConverter):
+ # Hf uses the real format for rotary embeddings.
+ _config: StarDocArchitectureConfig
+
+ def export_weight(
+ self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+ (query,) = weight
+ if self._config.transformer.complex_rotary_embeddings:
+ query = convert_rotary_complex_to_real(query[:], self._config.transformer.kv_channels, 0)
+ return (query,)
+
+ def import_weight(
+ self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+ (query,) = weight
+ if self._config.transformer.complex_rotary_embeddings:
+ query = convert_rotary_real_to_complex(query[:], self._config.transformer.kv_channels, 0)
+ return (query,)
+
+
+class KeyValueWeightConverter(WeightConverter):
+ # Hf uses the real format for rotary embeddings, and keeps the key and value separate.
+ _config: StarDocArchitectureConfig
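+    # Fast-LLM stores key and value concatenated in a single tensor: they are split on export
+    # and re-concatenated on import.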
+
+ def export_weight(
+ self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+ (key_value,) = weight
+ key, value = key_value[:].chunk(2)
+ if self._config.transformer.complex_rotary_embeddings:
+ key = convert_rotary_complex_to_real(key, self._config.transformer.kv_channels, 0)
+ return key, value
+
+ def import_weight(
+ self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+ key, value = weight
+ if self._config.transformer.complex_rotary_embeddings:
+ key = convert_rotary_real_to_complex(key[:], self._config.transformer.kv_channels, 0)
+ key_value = torch.cat([key[:], value[:]])
+ return (key_value,)
+
+
+class MLPLayer2Converter(WeightConverter):
+ # Similar to SplitWeightConverter, but handles the optional MLP transpose.
+ # Still ok for non-gated (trivial split) and biases (trivial 1d transpose)
+ _config: StarDocArchitectureConfig
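+    # Export transposes the merged weight and splits it along the last dimension; import
+    # concatenates the pieces and transposes back.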
+
+ def export_weight(
+ self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+ (merged_weight,) = weight
+ return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1))
+
+ def import_weight(
+ self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+ merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1)
+ return (merged_weight.t().contiguous(),)
+
+
+class CommonHuggingfaceConverter(HuggingfaceModelConverter):
+    """
+    Common converter for Llama-style Hugging Face models (Starcoder2, Llama, Mistral, Mixtral).
+    """
+
+    config: StarDocArchitectureConfig
+    _base_model_cls = StarDocBaseModelConfig
+
+ @abc.abstractmethod
+ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+ pass
+
+ @classmethod
+ def _create_config_converters(cls) -> list[ParamConverter]:
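+        # Constant importers pin values that have no Hugging Face equivalent (e.g. the
+        # multimodal sizes below); the remaining converters map Fast-LLM config fields to
+        # their Hugging Face counterparts.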
+ return super()._create_config_converters() + [
+ ConstantImportParamConverter(("multimodal_model", "image_encoder_hidden_size",), None, 1024),
+ ConstantImportParamConverter(("multimodal_model", "num_image_tokens",), None, 256),
+ ConstantImportParamConverter(("multimodal_model", "max_num_images",), None, 10),
+ ConstantImportParamConverter(("use_position_embeddings",), None, False),
+ ConstantImportParamConverter(("transformer", "use_rotary_embeddings"), None, True),
+ MappedConfigParamConverter(
+ ("transformer", "rotary_embedding_scale"), "rope_theta", lambda x: -math.log(x), lambda x: math.exp(-x)
+ ),
+ MappedConfigParamConverter(
+ ("transformer", "activation_type"),
+ "hidden_act",
+ ActivationType.from_hf_name,
+ lambda activation_type: activation_type.hf_name,
+ ),
+ ParamConverter(("transformer", "num_layers"), "num_hidden_layers"),
+ ParamConverter(("transformer", "hidden_size"), "hidden_size"),
+ ParamConverter(("transformer", "num_attention_heads"), "num_attention_heads"),
+ ParamConverter(("transformer", "head_groups"), "num_key_value_heads"),
+ ParamConverter(("transformer", "ffn_hidden_size"), "intermediate_size"),
+ ParamConverter(("vocab_size",), "vocab_size"),
+ ParamConverter(("tie_word_embeddings",), "tie_word_embeddings"),
+ ]
+
+ def _create_weight_converters(self) -> list[WeightConverter]:
+ converters = []
+ num_layers = self.config.transformer.num_layers
+ norm_bias: bool = self.config.transformer.normalization.type == NormalizationType.layer_norm
+ linear_bias: bool = self.config.transformer.add_linear_biases
+
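+        # Fast-LLM layer indexing used below: layer 0 is the vision encoder, layer 1 the
+        # multimodal adapter, layer 2 the word embeddings, layers 3 .. num_layers + 2 the
+        # transformer blocks, and layer num_layers + 3 holds the final norm and output weights.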
+        # Vision encoder (CLIP ViT): weights are copied through unchanged.
+        for name in (
+            "class_embedding",
+            "positional_embedding",
+            "proj",
+            "conv1.weight",
+            "ln_pre.weight",
+            "ln_pre.bias",
+        ):
+            converters.append(
+                WeightConverter(f"layers.0.visual_encoder.{name}", f"visual_encoder.{name}")
+            )
+        # The 24 residual blocks of the CLIP vision transformer share an identical layout.
+        for block in range(24):
+            for suffix in (
+                "ln_1.weight",
+                "ln_1.bias",
+                "attn.in_proj_weight",
+                "attn.in_proj_bias",
+                "attn.out_proj.weight",
+                "attn.out_proj.bias",
+                "ln_2.weight",
+                "ln_2.bias",
+                "mlp.c_fc.weight",
+                "mlp.c_fc.bias",
+                "mlp.c_proj.weight",
+                "mlp.c_proj.bias",
+            ):
+                converters.append(
+                    WeightConverter(
+                        f"layers.0.visual_encoder.transformer.resblocks.{block}.{suffix}",
+                        f"visual_encoder.transformer.resblocks.{block}.{suffix}",
+                    )
+                )
+        for name in ("ln_post.weight", "ln_post.bias"):
+            converters.append(
+                WeightConverter(f"layers.0.visual_encoder.{name}", f"visual_encoder.{name}")
+            )
+        converters.append(WeightConverter("layers.0.ln_vision.weight", "ln_vision.weight"))
+        converters.append(WeightConverter("layers.0.ln_vision.bias", "ln_vision.bias"))
+
+ # Adapter
+ converters.append(WeightConverter("layers.1.adapter_fc.weight", "c_fc.weight"))
+ converters.append(WeightConverter("layers.1.adapter_fc.bias", "c_fc.bias"))
+
+        # Embedding and output
+        converters.append(WeightConverter("layers.2.word_embeddings_weight", "model.embed_tokens.weight"))
+        if self.config.tie_word_embeddings:
+            converters.append(IgnoreWeightConverter((), "lm_head.weight"))
+        else:
+            converters.append(WeightConverter(f"layers.{num_layers + 3}.output_weights", "lm_head.weight"))
+
+ # Final norm
+ converters += self._get_weight_and_bias_converters(
+ f"layers.{num_layers + 3}.final_norm", "model.norm", norm_bias
+ )
+
+ for i in range(num_layers):
+ # Self-attn
+ converters += self._get_weight_and_bias_converters(
+ f"layers.{i+3}.self_attn.query",
+ f"model.layers.{i}.self_attn.q_proj",
+ linear_bias,
+ QueryWeightConverter,
+ )
+ converters += self._get_weight_and_bias_converters(
+ f"layers.{i+3}.self_attn.key_value",
+ (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
+ linear_bias,
+ KeyValueWeightConverter,
+ )
+ converters += self._get_weight_and_bias_converters(
+ f"layers.{i+3}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", linear_bias
+ )
+
+ # Norm
+ converters += self._get_weight_and_bias_converters(
+ f"layers.{i+3}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias
+ )
+ converters += self._get_weight_and_bias_converters(
+ f"layers.{i+3}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias
+ )
+
+ # MLP
+ converters += self._get_mlp_converters(f"layers.{i+3}", f"model.layers.{i}")
+
+ return converters
+
+ def _get_weight_and_bias_converters(
+ self,
+ fast_llm_prefix: str | tuple[str, ...],
+ hf_prefix: str | tuple[str, ...],
+ use_bias: bool,
+ cls=WeightConverter,
+ ):
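+        # Build a weight converter (and a bias converter when `use_bias` is set) for each pair
+        # of Fast-LLM / Hugging Face prefixes, using `cls` to handle any layout conversion.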
+ if isinstance(fast_llm_prefix, str):
+ fast_llm_prefix = (fast_llm_prefix,)
+ if isinstance(hf_prefix, str):
+ hf_prefix = (hf_prefix,)
+ converters = [
+ cls(
+ tuple(f"{prefix}.weight" for prefix in fast_llm_prefix),
+ tuple(f"{prefix}.weight" for prefix in hf_prefix),
+ self.config,
+ )
+ ]
+ if use_bias:
+ converters.append(
+ cls(
+ tuple(f"{prefix}.bias" for prefix in fast_llm_prefix),
+ tuple(f"{prefix}.bias" for prefix in hf_prefix),
+ self.config,
+ )
+ )
+ return converters
+
+
+class Starcoder2HuggingfaceConverter(CommonHuggingfaceConverter):
+ model_type = HuggingfaceModelType.starcoder2
+
+ @classmethod
+ def _create_config_converters(cls) -> list[ParamConverter]:
+ return super()._create_config_converters() + [
+ ConstantExportParamConverter(None, "architectures", ["Starcoder2ForCausalLM"]),
+ ConstantImportParamConverter(
+ ("transformer", "normalization", "type"), None, NormalizationType.layer_norm
+ ),
+ ParamConverter(("transformer", "normalization", "epsilon"), "norm_epsilon"),
+ ConstantImportParamConverter(("transformer", "gated"), None, False),
+ ConstantImportParamConverter(("transformer", "add_linear_biases"), None, True),
+ ]
+
+ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+ linear_bias: bool = self.config.transformer.add_linear_biases
+ return [
+ *self._get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias
+ ),
+ *self._get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", linear_bias, MLPLayer2Converter
+ ),
+ ]
+
+
+class CommonLlamaHuggingfaceConverter(CommonHuggingfaceConverter, abc.ABC):
+ @classmethod
+ def _create_config_converters(cls) -> list[ParamConverter]:
+ return super()._create_config_converters() + [
+ ConstantImportParamConverter(
+ ("transformer", "normalization", "type"), None, NormalizationType.rms_norm
+ ),
+ ParamConverter(("transformer", "normalization", "epsilon"), "rms_norm_eps"),
+ ConstantImportParamConverter(("transformer", "gated"), None, True),
+ ConstantImportParamConverter(("transformer", "add_linear_biases"), None, False),
+ ]
+
+
+class LlamaHuggingfaceConverter(CommonLlamaHuggingfaceConverter):
+ model_type = HuggingfaceModelType.llama
+
+ @classmethod
+ def _create_config_converters(cls) -> list[ParamConverter]:
+ return super()._create_config_converters() + [
+ ConstantExportParamConverter(None, "architectures", ["LlamaForCausalLM"]),
+ # TODO: Llama supports biases
+ ConstantExportParamConverter(None, "attention_bias", False),
+ ConstantExportParamConverter(None, "mlp_bias", False),
+ ConstantExportParamConverter(None, "rope_scaling", False),
+ ]
+
+ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+ linear_bias: bool = self.config.transformer.add_linear_biases
+ return [
+ *self._get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1",
+ (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+ linear_bias,
+ SplitWeightConverter,
+ ),
+ *self._get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2",
+ f"{hf_prefix}.mlp.down_proj",
+ linear_bias,
+ MLPLayer2Converter,
+ ),
+ ]
+
+
+class MistralHuggingfaceConverter(CommonLlamaHuggingfaceConverter):
+ model_type = HuggingfaceModelType.mistral
+
+ @classmethod
+ def _create_config_converters(cls) -> list[ParamConverter]:
+ return super()._create_config_converters() + [
+ ConstantExportParamConverter(None, "architectures", ["MistralForCausalLM"]),
+ IgnoreImportParamConverter(None, "sliding_window", None),
+ ]
+
+ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+ return [
+ SplitWeightConverter(
+ f"{fast_llm_prefix}.mlp.layer_1.weight",
+ (f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"),
+ ),
+ MLPLayer2Converter(
+ f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}.mlp.down_proj.weight", self.config
+ ),
+ ]
+
+
+class MixtralHuggingfaceConverter(CommonLlamaHuggingfaceConverter):
+ model_type = HuggingfaceModelType.mixtral
+
+ @classmethod
+ def _create_config_converters(cls) -> list[ParamConverter]:
+ return super()._create_config_converters() + [
+ ConstantExportParamConverter(None, "architectures", ["MixtralForCausalLM"]),
+ ConstantImportParamConverter(("transformer", "expert_routing_type"), None, RoutingType.topk),
+ ParamConverter(("transformer", "num_experts"), "num_local_experts"),
+ ParamConverter(("transformer", "num_experts_per_token"), "num_experts_per_tok"),
+ IgnoreImportParamConverter(None, "sliding_window", None),
+ ]
+
+ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+ num_experts = self.config.transformer.num_experts
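+        # In the Hugging Face Mixtral layout, each expert i stores w1 (gate), w3 (up) and w2 (down) projections;
+        # Fast-LLM keeps all experts' gate/up weights concatenated in mlp.layer_1 and all down weights in mlp.layer_2.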
+ return [
+ WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"),
+ SplitWeightConverter(
+ f"{fast_llm_prefix}.mlp.layer_1.weight",
+ tuple(
+ f"{hf_prefix}.block_sparse_moe.experts.{i}.{w}.weight"
+ for i in range(num_experts)
+ for w in ("w1", "w3")
+ ),
+ ),
+ MLPLayer2Converter(
+ f"{fast_llm_prefix}.mlp.layer_2.weight",
+ tuple(f"{hf_prefix}.block_sparse_moe.experts.{i}.w2.weight" for i in range(num_experts)),
+ self.config,
+ ),
+ ]
+
+
+class AutoStarDocConverter(AutoModelConverter, HuggingfaceModelConverter, abc.ABC):
+ converter_map = {
+ HuggingfaceModelType.starcoder2: Starcoder2HuggingfaceConverter,
+ HuggingfaceModelType.llama: LlamaHuggingfaceConverter,
+ HuggingfaceModelType.mistral: MistralHuggingfaceConverter,
+ HuggingfaceModelType.mixtral: MixtralHuggingfaceConverter,
+ }
\ No newline at end of file
diff --git a/fast_llm/models/stardoc/data.py b/fast_llm/models/stardoc/data.py
new file mode 100644
index 000000000..3d39b0da2
--- /dev/null
+++ b/fast_llm/models/stardoc/data.py
@@ -0,0 +1,177 @@
+import json
+import logging
+import math
+import pathlib
+import typing
+import warnings
+import numpy as np
+
+from fast_llm.models.stardoc.config import StarDocDataConfig
+from fast_llm.models.stardoc.stardoc_dataset import StarDocDataset
+from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.config_utils.run import get_run, log_main_rank
+from fast_llm.data.data import Data
+from fast_llm.data.tokenizer import Tokenizer
+from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler
+from fast_llm.engine.schedule.config import BatchConfig
+from fast_llm.utils import Assert
+
+logger = logging.getLogger(__name__)
+
+
+def normalize_probs(p: list[float]) -> list[float]:
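+    """Normalize non-negative weights so they sum to 1, e.g. [3, 1, 0] -> [0.75, 0.25, 0.0]."""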
+ p = np.array(p)
+ Assert.custom(lambda x: np.all(x >= 0), p)
+ p_sum = p.sum()
+ Assert.gt(p_sum, 0)
+ return (p / p_sum).tolist()
+
+
+class StarDocData(Data):
+ """
+    A class handling all dataset needs (loading, sampling and blending) for StarDoc.
+ """
+ _sampled_datasets: dict[PhaseType, dict[str, SampledDataset]]
+ _blended_datasets: dict[PhaseType, SampledDataset]
+ _tokenizer: Tokenizer | None
+ _distributed: Distributed
+ _cache_dir: pathlib.Path | None
+ _samples_per_phase: dict[PhaseType, int]
+ _phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test)
+
+ def __init__(
+ self,
+ config: StarDocDataConfig,
+ distributed_config: DistributedConfig,
+ vocab_size: int,
+ max_sequence_length: int,
+ ):
+ """
+        Create the data object and gather basic information on the dataset(s).
+        `setup` must be called before use.
+ """
+ self._config = config.validate()
+ self._distributed_config = distributed_config.validate()
+ self._vocab_size = vocab_size
+ self._max_sequence_length = max_sequence_length
+ Assert.eq(len(self._config.split), len(self._phases))
+ self._phase_split = {
+ phase: ratio for phase, ratio in zip(self._phases, normalize_probs(self._config.split)) if ratio > 0
+ }
+ Assert.eq(len(self._config.path), 1)
+ data_path = pathlib.Path(self._config.path[0])
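+        # The data config is a JSON file listing the datasets to blend; illustrative structure (paths are examples):
+        # {"datasets": [{"prefix": "docvqa/train", "weight": 0.7}, {"prefix": "chartqa/train", "weight": 0.3}]}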
+ dataset_defs = json.load(data_path.open("r"))
+ data_base_path = data_path.parent
+ dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]]
+ dataset_weights = normalize_probs([dataset_def["weight"] for dataset_def in dataset_defs["datasets"]])
+ self._build_and_sample_dataset = self._build_and_sample_stardoc_dataset
+
+ dataset_names = [
+ f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}"
+ for i, prefix in enumerate(dataset_prefixes)
+ ]
+ self._num_datasets = len(dataset_names)
+ self._dataset_prefixes = {
+ name: (
+ None
+ if prefix is None
+ else (
+ pathlib.Path(prefix).resolve()
+ if data_base_path is None
+ else (pathlib.Path(data_base_path) / prefix).resolve()
+ )
+ )
+ for name, prefix in zip(dataset_names, dataset_prefixes)
+ }
+ self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)}
+
+ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
+ """
+        Load the datasets. This may take a while and use a significant amount of CPU memory.
+ """
+ run = get_run()
+ Assert.leq(set(samples_per_phase), set(self._phase_split))
+ log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.")
+ self._tokenizer = Tokenizer(self._config.tokenizer, max_sequence_length=self._max_sequence_length)
+ self._distributed = distributed
+ self._cache_dir = run.dataset_cache_dir
+ self._samples_per_phase = samples_per_phase
+ if self._cache_dir is None:
+ warnings.warn(f"Using the dataset directory for the index cache.")
+
+ # Build and split datasets.
+ self._sampled_datasets = {phase: {} for phase in self._samples_per_phase}
+ for i, (name, weight) in enumerate(self._dataset_weights.items()):
+ if i % 100 == 0 and i > 0:
+ log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.")
+ dataset_samples_per_phase = {}
+ for phase, samples_per_phase in self._samples_per_phase.items():
+ expected_samples = self._dataset_weights[name] * samples_per_phase
+ # Add 5 times the standard deviation (of a binomial distribution)
+ # so the probability of sampling more than this amount during blending is negligible.
+ dataset_samples_per_phase[phase] = math.ceil(
+ expected_samples
+ + 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name]))
+ )
+ sampled_datasets = self._build_and_sample_dataset(name, dataset_samples_per_phase)
+ for phase, dataset in sampled_datasets.items():
+ self._sampled_datasets[phase][name] = dataset
+
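+        # Each phase either uses its single sampled dataset directly or blends all of them with the configured weights.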
+ self._blended_datasets = {
+ phase: (
+ list(datasets.values())[0]
+ if len(datasets) == 1
+ else BlendedDataset(
+ list(datasets.values()),
+ weights=[self._dataset_weights[name] for name in datasets],
+ name=phase.value,
+ num_samples=self._samples_per_phase[phase],
+ cache_dir=self._cache_dir,
+ group=self._distributed.world_group,
+ verbose=run.is_main_rank,
+ data_sample_warn_time_ms=self._config.data_sample_warn_time_ms,
+ )
+ )
+ for phase, datasets in self._sampled_datasets.items()
+ }
+
+ def get_iterator(
+ self,
+ batch_config: BatchConfig,
+ phase: PhaseType,
+ *,
+ consumed_samples: int,
+ num_workers: int,
+ prefetch_factor: int | None = None,
+ ):
+ # TODO: Adjust or reimplement.
+ return super().get_iterator(
+ batch_config,
+ phase,
+ consumed_samples=consumed_samples,
+ num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
+ )
+
+ def _build_and_sample_stardoc_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]):
+ sampled_datasets = {}
+ for phase, num_samples in dataset_samples_per_phase.items():
+ if num_samples == 0:
+ continue
+
+ # TODO: Get image handling parameters from config
+ sampled_datasets[phase] = StarDocDataset(
+ self._dataset_prefixes[name],
+ im_size=224,
+ num_samples=num_samples,
+ num_im_tokens=256,
+ transforms=False,
+ multi_imgs=True,
+ split=phase,
+ tokenizer=self._tokenizer,
+ config=self._config,
+ )
+
+ return sampled_datasets
\ No newline at end of file
diff --git a/fast_llm/models/stardoc/head.py b/fast_llm/models/stardoc/head.py
new file mode 100644
index 000000000..786e36929
--- /dev/null
+++ b/fast_llm/models/stardoc/head.py
@@ -0,0 +1,6 @@
+from fast_llm.layers.language_model.head import LanguageModelHead
+
+
+class CustomHead(LanguageModelHead):
+ # TODO: Implement custom parts
+ pass
diff --git a/fast_llm/models/stardoc/huggingface.py b/fast_llm/models/stardoc/huggingface.py
new file mode 100644
index 000000000..7db4e73f8
--- /dev/null
+++ b/fast_llm/models/stardoc/huggingface.py
@@ -0,0 +1,18 @@
+from fast_llm.models.custom.config import CustomModelConfig
+from fast_llm.models.custom.model import CustomModel
+from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM
+
+
+class HuggingfaceCustomModelConfig(HuggingfaceGPTModelConfig):
+ model_type = "fast_llm_gpt_custom"
+ model_config_class = CustomModelConfig
+ fast_llm_config: CustomModelConfig
+
+
+class HuggingfaceCustomModelForCausalLM(HuggingfaceGPTModelForCausalLM):
+ # TODO: Implement changes in huggingface interface, if any.
+ # Ex.: Return predictions instead of logits.
+ config_class = HuggingfaceCustomModelConfig
+ config: HuggingfaceCustomModelConfig
+ model_class = CustomModel
+ _fast_llm_model: CustomModel
diff --git a/fast_llm/models/stardoc/model.py b/fast_llm/models/stardoc/model.py
new file mode 100644
index 000000000..f5f264481
--- /dev/null
+++ b/fast_llm/models/stardoc/model.py
@@ -0,0 +1,345 @@
+import logging
+
+import torch
+
+from fast_llm.engine.base_model.base_model import BaseModel, LossDef
+from fast_llm.engine.base_model.config import BaseModelConfig
+from fast_llm.engine.config_utils.tensor_space import TensorDim
+from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
+from fast_llm.engine.schedule.config import BatchConfig
+from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames
+from fast_llm.layers.multimodal_model.config import MultimodalModelDimNames, MultimodalModelKwargs
+from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding
+from fast_llm.layers.language_model.head import LanguageModelHead
+from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor
+from fast_llm.layers.multimodal_model.multimodal_language_embedding import MultiModalLanguageModelEmbedding
+from fast_llm.layers.multimodal_model.image_encoder import ImageEncoder
+from fast_llm.layers.multimodal_model.adapter import Adapter
+
+from fast_llm.layers.transformer.config import (
+ RoutingType,
+ TransformerDimNames,
+ TransformerKwargs,
+ TransformerLossNames,
+)
+from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor
+from fast_llm.layers.transformer.transformer import TransformerLayer
+from fast_llm.models.stardoc.config import StarDocBaseModelConfig, StarDocModelConfig
+from fast_llm.tensor import ParameterMeta, TensorMeta
+from fast_llm.utils import Assert, div
+
+logger = logging.getLogger(__name__)
+
+
+class StarDocBaseModel(BaseModel):
+ """
+ A transformer-based language model generalizing the StarDoc model architecture.
+ """
+
+ _is_setup: bool = False
+ _config: StarDocBaseModelConfig
+ _rotary_embedding_frequencies: torch.Tensor
+ _position_ids: torch.Tensor
+ _mask: torch.Tensor
+ _mask_value: torch.Tensor
+ _tensor_cache_max_sequence_length: int = -1
+ config_cls = StarDocBaseModelConfig
+
+ def __init__(
+ self,
+ config: BaseModelConfig,
+ distributed_config: DistributedConfig,
+ ):
+ super().__init__(config, distributed_config)
+ self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config)
+ if self._config.use_absolute_position_embeddings:
+ self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space)
+ if self._config.transformer.use_rotary_position_embeddings:
+ self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor(
+ self._config.transformer, self._tensor_space
+ )
+ if not self._use_flash_attention:
+ self._backup_attention_preprocessor = BackupAttentionPreprocessor(
+ self._config.transformer, self._tensor_space
+ )
+
+ def get_layers(self):
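+        # Layer stack: image encoder -> adapter -> multimodal embedding (merges image and text tokens)
+        # -> transformer decoder layers -> language-model head.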
+ return [
+ ImageEncoder(self._config, self._tensor_space),
+ Adapter(self._config, self._tensor_space),
+ MultiModalLanguageModelEmbedding(self._config, self._tensor_space),
+ *[
+ TransformerLayer(
+ self._config.transformer,
+ self._tensor_space,
+ layer_index=i + 1,
+ )
+ for i in range(self._config.transformer.num_layers)
+ ],
+ LanguageModelHead(self._config, self._tensor_space),
+ ]
+
+ def setup(self, distributed: Distributed):
+ assert not self._is_setup
+ assert distributed.config is self._tensor_space.distributed_config
+ self._tensor_space.setup(distributed)
+ self._is_setup = True
+
+ def preprocess_meta(self, input_: BatchConfig | torch.Tensor, phase: PhaseType) -> list[tuple[TensorMeta, dict]]:
+ # TODO: How much of this is generalizable?
+ # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence
+
+ if isinstance(input_, BatchConfig):
+ micro_batch_size = input_.micro_batch_size
+ sequence_length = input_.sequence_length
+ micro_sequence_length = input_.micro_sequence_length
+ else:
+ micro_batch_size, sequence_length = input_.shape
+ if phase != PhaseType.inference:
+ sequence_length -= 1
+ micro_sequence_length = sequence_length
+
+        logger.debug(f"Sequence length for meta: {sequence_length}")
+
+ batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data)
+ batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data)
+
+ if isinstance(input_, BatchConfig):
+ micro_sequence_length = input_.micro_sequence_length
+
+ if micro_sequence_length is None:
+ micro_sequence_length = sequence_length
+ else:
+ Assert.multiple(sequence_length, micro_sequence_length)
+
+ local_micro_sequence_length = div(
+ micro_sequence_length, self._tensor_space.distributed_config.sequence_data_parallel
+ )
+
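+        # Sequence-first layout is required whenever the sequence is split across ranks (sequence-tensor-parallel
+        # or micro-sequences); the config may force it on, but cannot turn it off in that case.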
+ need_sequence_first = (
+ self._tensor_space.distributed_config.sequence_tensor_parallel
+ or sequence_length > local_micro_sequence_length
+ )
+ if self._config.sequence_first is None:
+ sequence_first = need_sequence_first
+ else:
+ sequence_first = self._config.sequence_first
+ assert not (need_sequence_first and not sequence_first)
+
+ sequence_q_dim = TensorDim(TransformerDimNames.sequence_q, local_micro_sequence_length)
+
+ # TODO: Calculate hidden dims elsewhere?
+ hidden_sequence_q_dim = (
+ TensorDim(
+ TransformerDimNames.sequence_q_tp,
+ micro_sequence_length,
+ self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor),
+ )
+ if self._tensor_space.distributed_config.sequence_tensor_parallel
+ else sequence_q_dim
+ )
+ hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
+ hidden_dims = (
+ (hidden_sequence_q_dim, batch_dim, hidden_dim)
+ if sequence_first
+ else (batch_dim, hidden_sequence_q_dim, hidden_dim)
+ )
+
+ max_num_images = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.max_num_images)
+ image_pixel_count = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.image_pixel_count)
+ num_image_tokens = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.num_image_tokens)
+ image_encoder_hidden_size = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.image_encoder_hidden_size)
+
+        image_encoder_hidden_dims = (batch_dim, max_num_images, num_image_tokens, image_encoder_hidden_size)
+        adapter_hidden_dims = (batch_dim, max_num_images, num_image_tokens, hidden_dim)
+
+ common_kwargs = {
+ LanguageModelKwargs.phase: phase,
+ TransformerKwargs.sequence_first: sequence_first,
+ TransformerKwargs.hidden_dims: hidden_dims,
+ TransformerKwargs.sequence_length: sequence_length,
+ TransformerKwargs.sequence_q_dim: sequence_q_dim,
+ MultimodalModelKwargs.image_encoder_hidden_dims: image_encoder_hidden_dims,
+ MultimodalModelKwargs.adapter_hidden_dims: adapter_hidden_dims,
+ }
+
+        # For StarDoc, image and text tokens need to be merged, so sequence parallelism is not supported yet:
+        # require the full sequence on every rank.
+ Assert.eq(micro_sequence_length, sequence_length)
+ Assert.eq(local_micro_sequence_length, sequence_length)
+
+ preprocessed_meta = []
+ for sequence_k_past in range(
+ local_micro_sequence_length * self._tensor_space.distributed_config.sequence_data_rank,
+ sequence_length,
+ micro_sequence_length,
+ ):
+ sequence_k = sequence_k_past + local_micro_sequence_length
+ sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k)
+
+ tokens = TensorMeta.from_dims(
+ hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64
+ )
+
+ image_data = TensorMeta.from_dims(
+ (
+ batch_dim,
+ max_num_images,
+ image_pixel_count,
+ ),
+ tensor_name="image_data",
+ dtype=torch.float32,
+ )
+
+ kwargs = {
+ **common_kwargs,
+ LanguageModelKwargs.tokens: tokens,
+ TransformerKwargs.sequence_k_dim: sequence_k_dim,
+ }
+ if phase != PhaseType.inference:
+ kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims(
+ hidden_dims[:2], tensor_name="labels", dtype=torch.int64
+ )
+ if self._config.use_absolute_position_embeddings:
+ self._position_embedding_preprocessor.preprocess_meta(kwargs)
+ if self._config.transformer.use_rotary_position_embeddings:
+ self._rotary_embedding_preprocessor.preprocess_meta(kwargs)
+ if not self._use_flash_attention:
+ self._backup_attention_preprocessor.preprocess_meta(kwargs)
+ preprocessed_meta.append((image_data, kwargs))
+
+ return preprocessed_meta
+
+ def preprocess(
+ self,
+ batch: dict,
+ preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None,
+ *,
+ phase: PhaseType,
+ iteration: int,
+ metrics: dict | None = None,
+ ) -> list[tuple[torch.Tensor, dict]]:
+ # TODO: How much of this is generalizable?
+ assert self._is_setup
+
+ if preprocessed_meta is None:
+ preprocessed_meta = self.preprocess_meta(batch, phase)
+
+ _, common_kwargs = preprocessed_meta[0]
+ sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size
+ sequence_first = common_kwargs[TransformerKwargs.sequence_first]
+ sequence_length = common_kwargs[TransformerKwargs.sequence_length]
+
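+        # The batch comes from `StarDocDataset.__getitem__`: padded token ids, label targets and processed image crops.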
+ tokens = batch["input_ids"]
+ labels = batch["labels"]
+ image_data = batch["images"]
+
+ # Move input_ids, labels and images to device
+ tokens = tokens.to(
+ device=self._tensor_space.distributed.device,
+ dtype=torch.int64,
+ non_blocking=True,
+ ).contiguous()
+ labels = labels.to(
+ device=self._tensor_space.distributed.device,
+ dtype=torch.int64,
+ non_blocking=True,
+ ).contiguous()
+ image_data = image_data.to(
+ device=self._tensor_space.distributed.device,
+ dtype=torch.float32,
+ non_blocking=True,
+ ).contiguous()
+
+ if self._config.use_absolute_position_embeddings:
+ self._position_embedding_preprocessor.create_tensors(sequence_length)
+ if self._config.transformer.use_rotary_position_embeddings:
+ self._rotary_embedding_preprocessor.create_tensors(sequence_length)
+ if not self._use_flash_attention:
+ self._backup_attention_preprocessor.create_tensors(sequence_length)
+
+ # TODO: Pasts and presents for inference?
+ preprocessed = []
+ presents = None
+ for tokens_meta, kwargs_meta in preprocessed_meta:
+ sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size
+ tokens = tokens[:, sequence_k - sequence_q : sequence_k].contiguous()
+
+ pasts = presents
+ presents = None if sequence_k == sequence_length else []
+ kwargs = {
+ **kwargs_meta,
+ LanguageModelKwargs.tokens: tokens,
+ TransformerKwargs.past_key_values: pasts,
+ TransformerKwargs.presents: presents,
+ }
+ if phase != PhaseType.inference:
+ labels = labels[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous()
+ kwargs[LanguageModelKwargs.labels] = labels
+
+ if self._config.use_absolute_position_embeddings:
+ self._position_embedding_preprocessor.preprocess(kwargs)
+ if self._config.transformer.use_rotary_position_embeddings:
+ self._rotary_embedding_preprocessor.preprocess(kwargs)
+ if not self._use_flash_attention:
+ self._backup_attention_preprocessor.preprocess(kwargs)
+ preprocessed.append((image_data, kwargs))
+
+ return preprocessed
+
+ @property
+ def embedding(self) -> LanguageModelEmbedding:
+ return self.layers[0]
+
+ @property
+ def transformer_layers(self) -> list[TransformerLayer]:
+ return self.layers[1:-1]
+
+ @property
+ def model_head(self) -> LanguageModelHead:
+ return self.layers[-1]
+
+ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
+ return (
+ {WORD_EMBEDDINGS_WEIGHT: (self.embedding.word_embeddings_weight, (0, len(self) - 1))}
+ if self._config.tie_word_embeddings
+ else {}
+ )
+
+ @property
+ def loss_defs(self) -> list[LossDef]:
+ loss_defs = [
+ LossDef(name=LanguageModelLossNames.language_model_loss, formatted_name="language model loss", count=1)
+ ]
+ if (
+ self._config.transformer.num_experts > 1
+ and self._config.transformer.expert_routing_type == RoutingType.topk
+ ):
+ loss_defs.append(
+ LossDef(
+ name=TransformerLossNames.load_balancing_loss,
+ formatted_name="load balancing loss",
+ count=self._config.transformer.num_layers,
+ )
+ )
+ if self._config.transformer.expert_z_loss_coefficient:
+ loss_defs.append(
+ LossDef(
+ name=TransformerLossNames.router_z_loss,
+ formatted_name="router z loss",
+ count=self._config.transformer.num_layers,
+ )
+ )
+        if self._config.logit_z_loss:
+            loss_defs.append(LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1))
+ return loss_defs
+
+
+class StarDocModel(FastLLMModel):
+ config_class = StarDocModelConfig
+ base_model_class = StarDocBaseModel
\ No newline at end of file
diff --git a/fast_llm/models/stardoc/stardoc_dataset.py b/fast_llm/models/stardoc/stardoc_dataset.py
new file mode 100644
index 000000000..1fcee9367
--- /dev/null
+++ b/fast_llm/models/stardoc/stardoc_dataset.py
@@ -0,0 +1,141 @@
+import os
+import logging
+
+import torch
+from torch.utils.data import Dataset
+from datasets import load_dataset, load_from_disk
+
+from fast_llm.data.config import DataConfig
+from fast_llm.data.tokenizer import Tokenizer
+from fast_llm.data.stardoc_data_utils.docowl_processor import DocProcessor
+from fast_llm.data.stardoc_data_utils.utils import (
+ convert_queries_and_annotations_to_messages,
+ image_loading_function,
+)
+from fast_llm.data.stardoc_data_utils.docowl_stardoc_processor import docowl_text_preprocess_v1
+from fast_llm.data.stardoc_data_utils.constants import IGNORE_INDEX
+from fast_llm.engine.distributed.config import PhaseType
+
+
+logger = logging.getLogger(__name__)
+
+
+class StarDocDataset(Dataset):
+ def __init__(
+ self,
+ dataset_path: str | None = None,
+ im_size: int = 224,
+ num_samples: int = -1,
+ num_im_tokens: int = 256,
+ transforms: bool = False,
+ multi_imgs: bool = True,
+ split: str = "train",
+ tokenizer: Tokenizer | None = None,
+ config: DataConfig | None = None,
+ ):
+ self.im_size = im_size
+ self.transforms = transforms
+ self.num_samples = num_samples
+ self.num_im_tokens = num_im_tokens
+ self.multi_imgs = multi_imgs
+ self.tokenizer = tokenizer
+ phase_map = {
+ PhaseType.training: "train",
+ PhaseType.validation: "val",
+ PhaseType.test: "test",
+ }
+        self.split = phase_map[split]
+
+ # Use DocOwl processor
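+        # Assumed behaviour: the DocOwl processor splits each page into a grid of crops (plus a global view) and
+        # returns a query string with one image placeholder per crop, which is expanded into image tokens below.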
+        self.processor = DocProcessor(
+            image_size=self.im_size,
+            anchors="grid_9",
+            add_global_img=True,
+            add_textual_crop_indicator=True,
+            media_token=self.tokenizer._config.special_tokens.image_placeholder_token,
+        )
+
+ assert dataset_path is not None
+
+        # TODO: this flag should come from the config (currently hardcoded due to a config validation issue)
+        multimodal_load_local = True
+ if multimodal_load_local:
+ # Load from a locally cached copy of the dataset
+ self.data_dict = load_from_disk(dataset_path)
+ self.data = self.data_dict[self.split]
+ else:
+            # Load the required split from the Hugging Face hub
+            # TODO: configurable cache_dir
+            self.data = load_dataset(
+                dataset_path,
+                split=self.split,
+                cache_dir="/mnt/core_llm/cache/",
+                num_proc=os.cpu_count() - 1,
+            )
+
+ if self.num_samples != -1:
+ self.data = self.data.select(range(self.num_samples))
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ sample = self.data[idx]
+ images = sample.get('image', [])
+
+ if self.multi_imgs and not isinstance(images, list):
+ images = [images]
+
+ images = image_loading_function(self.data[idx]["image"])
+
+ sample_id = sample["sample_id"]
+ dataset_name = sample["dataset_name"]
+ queries = sample["queries"]
+ annotations = sample["annotations"]
+ task_name = sample["task_name"]
+
+ if images[0].size[0] < 10 or images[0].size[1] < 10:
+ logger.error("Dummy images with small resolution of < 10x10 seen. Handling these is not implemented")
+
+ sample_tokenized_buffer = []
+ labels = []
+
+ # Add BOS token at the beginning of the sample
+ sample_tokenized_buffer.append(self.tokenizer.bos_token_id)
+
+        # Tokenized IDs for the "USER: " and " ASSISTANT: " role prefixes
+ user_ids = self.tokenizer.tokenize("USER: ")
+ assistant_ids = self.tokenizer.tokenize(" ASSISTANT: ")
+ sample_tokenized_buffer.extend(user_ids)
+
+ # Add dummy tokens for all image tokens
+ if len(images) > 0:
+ # Get all the crops and process them
+            all_images, _, processed_query = self.processor(
+                images=images, query=self.tokenizer._config.special_tokens.image_placeholder_token
+            )
+ crop_splits = processed_query.split(self.tokenizer._config.special_tokens.image_placeholder_token)[:-1]
+ assert len(crop_splits) == len(all_images)
+ for crop_split_part in crop_splits:
+ sample_tokenized_buffer.extend(self.tokenizer.tokenize(crop_split_part.strip()))
+ sample_tokenized_buffer.extend([self.tokenizer.image_placeholder_token_id] * self.num_im_tokens)
+
+ # Don't learn on any image tokens
+            labels.extend([IGNORE_INDEX] * len(sample_tokenized_buffer))
+
+        assert len(queries) == len(annotations)
+        for i, (q, a) in enumerate(zip(queries, annotations)):
+            if i > 0:
+ sample_tokenized_buffer.extend(user_ids)
+ sample_tokenized_buffer.extend(self.tokenizer.tokenize(q))
+ sample_tokenized_buffer.extend(assistant_ids)
+ sample_tokenized_buffer.extend(self.tokenizer.tokenize(a))
+
+ # Add EOS token at the end of the sample
+ sample_tokenized_buffer.append(self.tokenizer.eos_token_id)
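+        # Resulting token layout (sketch):
+        # [BOS] "USER: " <crop text + image placeholder tokens> q1 " ASSISTANT: " a1 "USER: " q2 ... [EOS]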
+        labels.extend(sample_tokenized_buffer[len(labels):])
+ assert len(sample_tokenized_buffer) == len(labels)
+
+        # Right-pad to the maximum sequence length
+        n_pad_tokens = self.tokenizer.max_sequence_length - len(sample_tokenized_buffer)
+        sample_tokenized_buffer = sample_tokenized_buffer + n_pad_tokens * [self.tokenizer.pad_token_id]
+
+ # Add an extra pad token to the labels at the end to support shifting left
+        labels = labels + (n_pad_tokens + 1) * [IGNORE_INDEX]
+
+ return {
+ 'input_ids': torch.tensor(sample_tokenized_buffer),
+ 'labels': torch.tensor(labels),
+ 'images': all_images,
+ }
\ No newline at end of file
diff --git a/fast_llm/models/stardoc/trainer.py b/fast_llm/models/stardoc/trainer.py
new file mode 100644
index 000000000..9f77eb790
--- /dev/null
+++ b/fast_llm/models/stardoc/trainer.py
@@ -0,0 +1,19 @@
+from fast_llm.models.stardoc.config import StarDocTrainerConfig
+from fast_llm.models.stardoc.data import StarDocData
+from fast_llm.models.stardoc.model import StarDocModel
+from fast_llm.models.gpt.trainer import GPTTrainer
+
+
+class StarDocTrainer(GPTTrainer):
+ _abstract = False
+ _config: StarDocTrainerConfig
+ config_class = StarDocTrainerConfig
+ model_class = StarDocModel
+
+ def _get_data(self):
+ return StarDocData(
+ config=self._config.data,
+ distributed_config=self._config.distributed,
+ vocab_size=self._config.base_model.vocab_size,
+ max_sequence_length=self._config.batch.sequence_length,
+ )