Skip to content

Commit

Permalink
Convert numpy array to jax in the training loop
Browse files Browse the repository at this point in the history
  • Loading branch information
neel04 committed Apr 14, 2024
1 parent a572c47 commit 9f28c97
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions ReAct/data/minipile.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def create_dataloader(self, slice: str = '100%'):

print(f'Loaded {self.split} dataset from HuggingFace Hub')

- dataset.set_format(type='jax')
+ dataset.set_format(type='numpy')

return dataset

except (FileNotFoundError, ValueError):
if os.path.exists(data_path):
print(f'Loading dataset from {data_path}...')
dataset = self.load_data(data_path)
- return self.numpify(dataset)
+ return dataset
else:
print(f'Building dataset from scratch... [split: {self.split}] | [bsz: {self.bsz}]')

Expand All @@ -128,7 +128,7 @@ def create_dataloader(self, slice: str = '100%'):
dataset = dataset.map(self.shift_tokens, batched=True, batch_size=self.bsz,
keep_in_memory=True, drop_last_batch=True, num_proc=None)

- dataset.set_format(type='jax')
+ dataset.set_format(type='numpy')

self.upload_dataset(dataset,
hub_path=f'Neel-Gupta/minipile-processed_{self.bsz}') # upload the processed dataset to the Hub
Expand Down
4 changes: 2 additions & 2 deletions ReAct/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def evaluate_acc(self, model: eqx.Module, loader: DataLoader, eval_iters: int, k
metric = []

for step, batch in tqdm(enumerate(loader), total=len(loader), desc='Validating'):
- seq, label, pad_mask = batch['text']
+ seq, label, pad_mask = jnp.asarray(batch['text'])
acc, loss, ppl = self.compute_metrics(model, seq, label, pad_mask, eval_iters, self.num_classes, keys)
metric.extend([acc, loss, ppl])

Expand Down Expand Up @@ -315,7 +315,7 @@ def train(self):
for step, batch in tqdm(enumerate(self.trainloader), total=len(self.trainloader), desc=f'Epoch {epoch}'):
step += step_done # for multiple epochs

- seq, label, pad_mask = batch['text']
+ seq, label, pad_mask = jnp.asarray(batch['text'])

loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
rndm_n, rndm_k, optim, self.num_classes, keys)
Expand Down

0 comments on commit 9f28c97

Please sign in to comment.