Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added NeoX/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions NeoX/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from torch import nn
import loralib as lora

def convert_model_lora(model, r=4, enable_lora=None, target_name="query_key_value"):
    """Replace matching ``nn.Linear`` children with LoRA layers, in place.

    Walks the module tree recursively. Every direct or nested child that is
    an ``nn.Linear`` registered under the attribute name *target_name* is
    swapped for a ``lora.MergedLinear`` of the same shape that re-uses the
    original weight and bias parameters. All other children are recursed
    into.

    Args:
        model: Root ``nn.Module`` to convert. It is modified in place.
        r: Rank of the LoRA update matrices.
        enable_lora: Sequence of booleans forwarded to
            ``lora.MergedLinear``; it splits the output features into
            ``len(enable_lora)`` equal contiguous groups and adapts only the
            groups marked ``True``. Defaults to ``[True]``, i.e. the whole
            fused projection is adapted. Pass e.g. ``[True, False, True]``
            only if the layer's outputs really are laid out as that many
            contiguous groups.
        target_name: Attribute name a linear layer must have to be replaced.

    Returns:
        The same *model* object, for call chaining.

    Note:
        The re-attached pretrained weight keeps whatever ``requires_grad``
        flag it had; freezing the base weights (e.g. with
        ``lora.mark_only_lora_as_trainable``) is left to the caller.
    """
    # loralib's own default for enable_lora is [False], which creates no
    # lora_A / lora_B at all, so the flag has to be passed explicitly.
    # list(): loralib uses the value as a boolean index mask.
    lora_mask = [True] if enable_lora is None else list(enable_lora)

    for child_name, child in model.named_children():
        if isinstance(child, nn.Linear) and child_name == target_name:
            new = lora.MergedLinear(
                child.in_features,
                child.out_features,
                r=r,
                enable_lora=lora_mask,
                # Mirror the source layer instead of relying on nn.Linear
                # defaults, so the adapter matrices are created next to the
                # pretrained weight and no spurious bias is allocated.
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            # Share the pretrained parameters rather than copying them.
            new.weight = child.weight
            new.bias = child.bias
            setattr(model, child_name, new)
        else:
            # TODO: nn.Conv2d -> lora.Conv2d and nn.Embedding -> lora.Embedding
            # conversions were sketched here but are not implemented yet.
            convert_model_lora(
                child, r=r, enable_lora=lora_mask, target_name=target_name
            )
    return model
27 changes: 27 additions & 0 deletions NeoX/load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

#dummy config
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from convert import convert_model_lora
import torch
initial = True
from safetensors.torch import save_file, load_file

if initial:
    # First run: load the pretrained GPT-NeoX checkpoint, convert it to the
    # desired LoRA architecture and write the result to the current directory.
    model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
    model = convert_model_lora(model)

    # Alternative kept for reference: torch.save(model.state_dict(), "./model.pt")
    # safe_serialization is a boolean flag; the string "True" only worked
    # because any non-empty string is truthy.
    model.save_pretrained("./", safe_serialization=True)
else:
    # Later runs: rebuild the same architecture, then load the saved weights.
    # Open question: can the model be built from the config alone, without
    # loading the pretrained checkpoint first?
    model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
    # The conversion has to be repeated so the module tree matches the saved
    # state dict. It could be skipped by hand-coding the adapted architecture,
    # but that would have to be done once per adapter method.
    model = convert_model_lora(model)

    loaded = load_file("./model.safetensors")
    model.load_state_dict(loaded)