Commit

Updated dockerfile
neel04 committed Oct 23, 2024
1 parent 96fbec7 commit 98d371b
Showing 2 changed files with 48 additions and 159 deletions.
Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ RUN pip3 install numpy pandas scipy
 RUN pip3 install -U numpy==1.26.4
 RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
-RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly
+RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly lm-eval
 RUN pip3 install git+https://github.com/deepmind/jmp
 RUN pip3 install git+https://github.com/Findus23/jax-array-info.git
 
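Note: the only Dockerfile change is adding lm-eval (EleutherAI's lm-evaluation-harness, imported below as lm_eval) to the pip install line. A minimal, hypothetical smoke test for the built image (not part of the commit; TaskManager/all_tasks are assumed to be available as in lm-eval v0.4+):

```python
# Hypothetical smoke test for the newly added lm-eval dependency (not in the commit).
import lm_eval
from lm_eval.tasks import TaskManager

tm = TaskManager()
print(f"lm-eval imported; {len(tm.all_tasks)} tasks registered")  # e.g. hellaswag, lambada_openai, ...
```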
eval.py (205 changes: 47 additions & 158 deletions)
@@ -5,6 +5,7 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from tqdm import tqdm
 import lm_eval
 from jaxtyping import Array, PRNGKeyArray
 from lm_eval.api.model import TemplateLM
@@ -94,47 +95,61 @@ def tok_encode(self, string: str, **kwargs) -> list[int]:
 
         return encoded.tolist()
 
-    def _calc_ll(self, arr: Array, target: Array, **kwargs) -> Array:
-        arr, target = jnp.asarray(arr), jnp.asarray(target)
-
-        seq = jnp.concat([arr, target])
-        seq = jnp.pad(
-            seq,
-            (0, self.args.seqlen - seq.shape[0]),
-            constant_values=self.eot_token_id,
-        )
-
+    def _calc_ll(self, seq: Array, lengths: tuple[int, int], target: Array) -> Array:
+        breakpoint()
+        arrlen, tgtlen = lengths
         pad_mask = jnp.where(seq == self.eot_token_id, 0, 1)
         key = jax.random.PRNGKey(0)
 
-        if self.args.baseline:
-            logits = self.model(seq, pad_mask, False, key)
-        else:
-            logits = self.model(
-                seq,
-                self.args.max_iters,
-                jnp.ones_like(seq),
-                False,
-                False,
-                key,
-            )[0]
+        def fwd(seq: Array, pad_mask: Array, key: PRNGKeyArray) -> Array:
+            if self.args.baseline:
+                logits = self.model(seq, pad_mask, False, key)
+            else:
+                logits = self.model(
+                    seq,
+                    self.args.max_iters,
+                    jnp.ones_like(seq),
+                    False,
+                    False,
+                    key,
+                )[0]
 
-        probs = jax.nn.log_softmax(logits, axis=-1)
+            return jax.nn.log_softmax(logits, axis=-1)
 
-        breakpoint()
-        # select logprobs for the target token
-        target_log_probs = jnp.take_along_axis(
-            probs, target[:, :, None], axis=-1
-        ).squeeze(-1)
+        probs = fwd(seq, pad_mask, key)
+        target_log_probs = probs[
+            jnp.arange(arrlen, arrlen + tgtlen), target[tgtlen]
+        ]
 
         return target_log_probs.sum()
 
     def _loglikelihood_tokens(self, requests: list, **kwargs) -> list[tuple[float, bool]]:
-        breakpoint()
-        for request in requests:
+        output = []
+        reqlen = len(requests)
+        requests = requests[:100]
+
+        for request in tqdm(requests):
             context, target = request[-2], request[-1]
-            self._calc_ll(context, target, kwargs=kwargs)
 
-        ...
+            arr, target = jnp.asarray(context), jnp.asarray(target)
+
+            seq = jnp.concat([arr, target])
+            seq = jnp.pad(
+                seq,
+                (0, self.args.seqlen - seq.shape[0]),
+                constant_values=self.eot_token_id,
+            )
+
+            target = jnp.pad(
+                target,
+                (0, self.args.seqlen - target.shape[0]),
+                constant_values=self.eot_token_id,
+            )
+
+            ll = self._calc_ll(seq, (len(arr), len(target)), target)
+            output.append((ll.item(), 1))
+
+        return output * (reqlen - 100)
 
     def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[float]:
         return super().loglikelihood_rolling(requests, disable_tqdm)
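Note: the rewritten _calc_ll scores a padded context + continuation sequence and sums the log-probabilities assigned to the continuation tokens. A self-contained sketch of that computation with dummy logits (the one-position shift and the explicit ctx_len/tgt_len slicing follow the conventional next-token-scoring formulation and are not code from this commit):

```python
# Standalone sketch (not from this commit): sum log P(continuation | context)
# given per-position next-token logits from a causal LM.
import jax
import jax.numpy as jnp


def continuation_logprob(logits: jnp.ndarray, seq: jnp.ndarray, ctx_len: int, tgt_len: int) -> jnp.ndarray:
    """logits: (seqlen, vocab) next-token logits for `seq`;
    seq: (seqlen,) token ids, context then continuation then padding."""
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    # The token at position i is predicted by the logits at position i - 1, so
    # continuation tokens in [ctx_len, ctx_len + tgt_len) are scored by rows
    # [ctx_len - 1, ctx_len + tgt_len - 1).
    rows = jnp.arange(ctx_len - 1, ctx_len + tgt_len - 1)
    token_ids = seq[ctx_len : ctx_len + tgt_len]
    return logprobs[rows, token_ids].sum()


# Tiny usage example with random logits standing in for a model forward pass.
key = jax.random.PRNGKey(0)
seq = jnp.array([5, 3, 8, 2, 0, 0])        # 3 context tokens, 1 continuation token, 2 pad
logits = jax.random.normal(key, (6, 16))   # (seqlen, vocab)
print(continuation_logprob(logits, seq, ctx_len=3, tgt_len=1))
```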
@@ -144,132 +159,6 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
 
         return super().generate_until(requests, disable_tqdm)
 
-    '''
-    def _loglikelihood(self, requests: list[Instance]):
-        res = []
-        for request in requests:
-            context, continuation = request.args
-            inp = self.encode_fn(context + continuation)
-            context_enc = self.encode_fn(context)
-            cont_enc = inp[len(context_enc) :]
-            if len(inp) > self.args.seqlen:
-                inp = inp[-self.args.seqlen :]
-                context_enc = context_enc[-(self.args.seqlen - len(cont_enc)) :]
-            pad_amount = self.args.seqlen - len(inp)
-            inp = jnp.pad(
-                jnp.array(inp),
-                (0, pad_amount),
-                constant_values=self.args.pad_token,
-            )
-            key = jax.random.PRNGKey(0)
-            pad_mask = jnp.where(inp == self.args.pad_token, 0, 1)
-            if self.args.baseline:
-                logits = self.model(
-                    inp[None, :], pad_mask[None, :], False, key
-                )[0, : len(context_enc) + len(cont_enc) - 1]
-            else:
-                logits = self.model(
-                    inp[None, :],
-                    self.args.max_iters,
-                    pad_mask[None, :],
-                    False,
-                    False,
-                    key,
-                )[0][0, : len(context_enc) + len(cont_enc) - 1]
-            log_probs = jax.nn.log_softmax(logits, axis=-1)
-            cont_log_probs = (
-                log_probs[len(context_enc) - 1 :]
-                .at[jnp.arange(len(cont_enc)), jnp.array(cont_enc)]
-                .get()
-            )
-            total_log_prob = cont_log_probs.sum()
-            res.append((total_log_prob.item(), True))
-        return res
-
-    def loglikelihood_rolling(self, requests):
-        res = []
-        for context, continuation in requests:
-            inp = self.encode_fn(context + continuation)
-            total_log_prob = 0.0
-            for i in range(len(context), len(inp)):
-                window = inp[max(0, i - self.args.seqlen + 1) : i + 1]
-                pad_amount = self.args.seqlen - len(window)
-                window = jnp.pad(
-                    jnp.array(window),
-                    (pad_amount, 0),
-                    constant_values=self.args.pad_token,
-                )
-                key = jax.random.PRNGKey(0)
-                pad_mask = jnp.where(window == self.args.pad_token, 0, 1)
-                if self.args.baseline:
-                    logits = self.model(
-                        window[None, :], pad_mask[None, :], False, key
-                    )[0, -1]
-                else:
-                    logits = self.model(
-                        window[None, :],
-                        self.args.max_iters,
-                        pad_mask[None, :],
-                        False,
-                        False,
-                        key,
-                    )[0][0, -1]
-                log_probs = jax.nn.log_softmax(logits, axis=-1)
-                total_log_prob += log_probs[inp[i]].item()
-            res.append((total_log_prob, True))
-        return res
-
-    def generate_until(self, requests):
-        res = []
-        for context, until in requests:
-            inp = self.encode_fn(context)
-            generated = inp.copy()
-            while not any(self.decode_fn(generated).endswith(u) for u in until):
-                if len(generated) >= self.args.seqlen:
-                    break
-                window = generated[-self.args.seqlen :]
-                pad_amount = self.args.seqlen - len(window)
-                window = jnp.pad(
-                    jnp.array(window),
-                    (pad_amount, 0),
-                    constant_values=self.args.pad_token,
-                )
-                key = jax.random.PRNGKey(0)
-                pad_mask = jnp.where(window == self.args.pad_token, 0, 1)
-                if self.args.baseline:
-                    logits = self.model(
-                        window[None, :], pad_mask[None, :], False, key
-                    )[0, -1]
-                else:
-                    logits = self.model(
-                        window[None, :],
-                        self.args.max_iters,
-                        pad_mask[None, :],
-                        False,
-                        False,
-                        key,
-                    )[0][0, -1]
-                next_token = jax.random.categorical(key, logits).item()
-                generated.append(next_token)
-            res.append(self.decode_fn(generated[len(inp) :]))
-        return res
-    '''
 
     lm_obj = MyLM(
         model=model,
         encode_fn=self.encode_input,
@@ -286,7 +175,7 @@ def generate_until(self, requests):
         task_manager=task_manager,
     )
 
-    return results
+    return results['results']  # type: ignore
 
 
 if __name__ == "__main__":
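Note: lm_eval.simple_evaluate returns a dictionary whose "results" entry maps each task name to its metric values, next to bookkeeping entries such as configs and versions, which is why the last hunk returns results['results']. A hedged, standalone illustration using the stock Hugging Face backend (model and task names are placeholders, not taken from this repository):

```python
# Hypothetical illustration of simple_evaluate's return structure (not part of the commit);
# uses the stock Hugging Face backend rather than the custom MyLM wrapper.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-70m",
    tasks=["lambada_openai"],
    limit=8,  # a handful of documents, just to exercise the pipeline
)

# The full dict also holds run metadata ("configs", "versions", ...);
# returning results["results"] keeps only the per-task metrics.
print(results["results"])  # e.g. {"lambada_openai": {"acc,none": ..., ...}}
```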
