|
| 1 | +from pathlib import Path |
| 2 | +from types import SimpleNamespace |
| 3 | + |
| 4 | +from pymongo import MongoClient |
| 5 | +from sklearn.pipeline import Pipeline |
| 6 | + |
| 7 | +from origami.inference import Predictor |
| 8 | +from origami.model import ORIGAMI |
| 9 | +from origami.model.vpda import ObjectVPDA |
| 10 | +from origami.preprocessing import ( |
| 11 | + DFDataset, |
| 12 | + DocPermuterPipe, |
| 13 | + DocTokenizerPipe, |
| 14 | + PadTruncTokensPipe, |
| 15 | + TargetFieldPipe, |
| 16 | + TokenEncoderPipe, |
| 17 | + UpscalerPipe, |
| 18 | + load_df_from_mongodb, |
| 19 | +) |
| 20 | +from origami.utils.common import set_seed |
| 21 | +from origami.utils.config import GuardrailsMethod, ModelConfig, PositionEncodingMethod, TrainConfig |
| 22 | +from origami.utils.guild import load_secrets, print_guild_scalars |
| 23 | + |
# populated by guild at run time (this script reads flags.n_problems, flags.max_length,
# flags.max_tokens, flags.n_embd, flags.learning_rate, flags.batch_size,
# flags.eval_every and flags.n_batches below)
flags = SimpleNamespace()
secrets = load_secrets()  # must provide at least MONGO_URI

# for reproducibility
set_seed(1234)

# classification target: the CodeNet problem id field of each document
TARGET_FIELD = "problem"
# factor handed to UpscalerPipe(n=...) below — presumably replicates each
# training document that many times; TODO confirm against UpscalerPipe
UPSCALE = 2

client = MongoClient(secrets["MONGO_URI"])
collection = client["codenet_java"].train

# restrict the experiment to the first flags.n_problems distinct problem labels
target_problems = collection.distinct(TARGET_FIELD)
num_problems = len(target_problems)

target_problems = target_problems[: flags.n_problems]
print(f"training on {flags.n_problems} problems (out of {num_problems})")
| 42 | + |
# load data into dataframes for train/test


def _load_split(collection_name):
    """Load one codenet_java split from MongoDB, restricted to the selected problems.

    Drops _id and filePath, which are not used for training.
    """
    return load_df_from_mongodb(
        # consistency fix: the distinct() query above connects via secrets["MONGO_URI"],
        # but these loads used a hard-coded "mongodb://localhost:27017" — unified on
        # the secret so all reads hit the same server.
        secrets["MONGO_URI"],
        "codenet_java",
        collection_name,
        filter={"problem": {"$in": target_problems}},
        projection={"_id": 0, "filePath": 0},
    )


train_docs_df = _load_split("train")
test_docs_df = _load_split("test")

# instance counts before any preprocessing drops (used for stats below)
num_train_inst = len(train_docs_df)
num_test_inst = len(test_docs_df)
| 63 | + |
# create train and test pipelines; stage instances are shared between the two
# pipelines via this dict, so fitted state (encoder vocab, pad length) is shared too
pipes = {
    # --- train only ---
    "upscaler": UpscalerPipe(n=UPSCALE),
    "permuter": DocPermuterPipe(shuffle_arrays=True),
    # --- train and test ---
    "target": TargetFieldPipe(TARGET_FIELD),
    "tokenizer": DocTokenizerPipe(path_in_field_tokens=False),
    "padding": PadTruncTokensPipe(length=flags.max_length),
    "encoder": TokenEncoderPipe(max_tokens=flags.max_tokens),
}


def _make_pipeline(step_names):
    """Assemble a verbose sklearn Pipeline from the named stages in `pipes`."""
    steps = [(name, pipes[name]) for name in step_names]
    return Pipeline(steps, verbose=True)


train_pipeline = _make_pipeline(("target", "upscaler", "permuter", "tokenizer", "padding", "encoder"))
test_pipeline = _make_pipeline(("target", "tokenizer", "padding", "encoder"))
| 82 | + |
# process train and test data (first fit both pipelines, then transform)
train_pipeline.fit(train_docs_df)
test_pipeline.fit(test_docs_df)

train_df = train_pipeline.transform(train_docs_df)
test_df = test_pipeline.transform(test_docs_df)

# drop the "docs" column to save space (only the encoded token arrays are used below)
train_df.drop(columns=["docs"], inplace=True)
test_df.drop(columns=["docs"], inplace=True)

# drop all rows where the tokens array doesn't end in 0, i.e. documents that were
# truncated for exceeding max_length (0 is presumably the pad token appended by
# PadTruncTokensPipe — TODO confirm)
train_df = train_df[train_df["tokens"].apply(lambda x: x[-1] == 0)]
test_df = test_df[test_df["tokens"].apply(lambda x: x[-1] == 0)]
| 97 | + |
# get stateful objects out of the fitted pipeline stages
encoder = pipes["encoder"].encoder
block_size = pipes["padding"].length

# print data stats; train count is scaled by UPSCALE because the upscaler
# multiplied the raw instances before any rows were dropped
print(
    f"dropped {(1 - (len(train_df) / (UPSCALE * num_train_inst))) * 100:.2f}% training instances, and "
    f"{(1 - (len(test_df) / num_test_inst)) * 100:.2f}% test instances."
)
print(f"vocab size {encoder.vocab_size}")
print(f"block size {block_size}")

# confirm that all target labels round-trip through the vocabulary
# (deduplicated the previously copy-pasted train/test loops; also fixed the
# garbled assertion message "token not {target} represented")
for df in (train_df, test_df):
    for target in df["target"].unique():
        enc = encoder.encode(target)
        assert target == encoder.decode(enc), f"token {target} not represented in vocab."
| 118 | + |
# create datasets, VPDA and model

# model and train configs
model_config = ModelConfig.from_preset("small")
model_config.position_encoding = PositionEncodingMethod.KEY_VALUE
model_config.vocab_size = encoder.vocab_size
model_config.block_size = block_size
model_config.n_embd = flags.n_embd
model_config.mask_field_token_losses = False
model_config.tie_weights = False
model_config.guardrails = GuardrailsMethod.STRUCTURE_ONLY
model_config.fuse_pos_with_mlp = True

train_config = TrainConfig()
train_config.learning_rate = flags.learning_rate
train_config.batch_size = flags.batch_size
train_config.n_warmup_batches = 100
train_config.eval_every = flags.eval_every

# datasets
train_dataset = DFDataset(train_df)
test_dataset = DFDataset(test_df)

# pushdown automaton over the token vocabulary — presumably backs the
# STRUCTURE_ONLY guardrails configured above; TODO confirm against ObjectVPDA
vpda = ObjectVPDA(encoder)
model = ORIGAMI(model_config, train_config, vpda=vpda)

# resume from an existing checkpoint if present
# (consistency fix: reuse checkpoint_file instead of repeating the path literal)
checkpoint_file = Path("./gpt-codenet-snapshot.pt")
if checkpoint_file.is_file():
    model.load(str(checkpoint_file))
    print(f"loading existing checkpoint at batch_num {model.batch_num}...")


# create a predictor used for accuracy evaluation during and after training
predictor = Predictor(model, encoder, TARGET_FIELD)
| 154 | + |
| 155 | + |
def progress_callback(model):
    """Batch-end callback: log training scalars, periodically evaluate, and snapshot.

    Registered via model.set_callback("on_batch_end", ...) below, so this runs
    after every training batch.
    """
    print_guild_scalars(
        step=f"{int(model.batch_num)}",
        epoch=model.epoch_num,
        batch_num=model.batch_num,
        batch_dt=f"{model.batch_dt * 1000:.2f}",  # presumably seconds -> ms; confirm batch_dt units
        batch_loss=f"{model.loss:.4f}",
        lr=f"{model.learning_rate:.2e}",
    )
    # evaluate on a small test sample every eval_every batches
    if model.batch_num % train_config.eval_every == 0:
        try:
            # train_acc = predictor.accuracy(train_dataset.sample(n=100))
            test_acc = predictor.accuracy(test_dataset.sample(n=100), show_progress=True)
            print_guild_scalars(
                step=f"{int(model.batch_num)}",
                # train_acc=f"{train_acc:.4f}",
                test_acc=f"{test_acc:.4f}",
            )
            # print(f"Train accuracy @ 100: {train_acc:.4f}, Test accuracy @ 100: {test_acc:.4f}")
        except AssertionError as e:
            # presumably raised inside predictor.accuracy — log and keep training
            print(e)
            print("continuing...")

    # NOTE(review): this save runs on EVERY batch, not only on eval batches —
    # confirm per-batch checkpointing is intended (it may dominate runtime).
    model.save("gpt-codenet-snapshot.pt")
    print("model saved to gpt-codenet-snapshot.pt")
| 181 | + |
| 182 | + |
model.set_callback("on_batch_end", progress_callback)

# train; Ctrl-C aborts training early but still runs the final save/eval below
try:
    model.train_model(train_dataset, batches=flags.n_batches)
except KeyboardInterrupt:
    pass

# final save
model.save("gpt-codenet-snapshot.pt")
print("model saved to gpt-codenet-snapshot.pt")

# final accuracy on the full (filtered) test set
test_acc = predictor.accuracy(test_dataset, show_progress=True)
print_guild_scalars(
    step=f"{int(model.batch_num / train_config.eval_every)}",
    test_acc=f"{test_acc:.4f}",
)

# test instances dropped during preprocessing (over-long documents) never reach
# the model, so treat them as misclassified in the overall number
dropped_ratio = 1 - (len(test_df) / num_test_inst)
adjusted_acc = (1 - dropped_ratio) * test_acc
# bug fix: accuracy is a fraction in [0, 1]; the old message appended "%"
# without multiplying by 100, printing e.g. "0.8231%"
print(f"Final test accuracy when taking into account the dropped instances: {adjusted_acc * 100:.2f}%")
# (removed stray GitHub page footer "0 commit comments" from the pasted source)