Implementation of generate #222
base: main
Conversation
I have created a debugging sandbox with manual tests for now. The results are as follows.

Ignoring `attention_mask`:

| Batch Size | No Flash Attention (Float32) | No Flash Attention (BF16) | Flash Attention (BF16) |
|---|---|---|---|
| 1 | Same output (same model via HF and Fast-LLM) | Same output | Different output |
| 2 | Different output | Different output | Different output |
Converting `attention_mask` (from the HF `forward`) to `sequence_lengths`:

| Batch Size | No Flash Attention (Float32) | No Flash Attention (BF16) | Flash Attention (BF16) |
|---|---|---|---|
| 1 | Fast-LLM empty output | Fast-LLM empty output | Different output |
| 2 | Fast-LLM empty output | Fast-LLM empty output | Different output |
It seems `sequence_lengths` is not supported for fused attention and does not improve the Flash Attention results. Could this be correct?
If `attention_mask` is a left-padded mask like this: `[[0, 0, 0, 1, 1, 1, 1], ...]`, I convert it to `sequence_lengths = [[3, 4], ...]`:
```python
import torch

# First non-zero index per row; argmax returns 0 if the row is all zeros (invalid row).
first_non_zero_indexes = attention_mask.argmax(dim=1)
# Check that each row is left-padded, i.e. the trailing ones are contiguous.
assert (attention_mask.sum(dim=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all()
sequence_lengths = [
    torch.tensor(
        [attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64
    )
    for el in first_non_zero_indexes.tolist()
]
```
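To sanity-check the conversion, here is a hypothetical standalone run on the example mask from above (the second, fully unpadded row is added here for illustration and is not from the original comment):

```python
import torch

# Row 0 is the left-padded example from above; row 1 has no padding at all.
attention_mask = torch.tensor(
    [[0, 0, 0, 1, 1, 1, 1],
     [1, 1, 1, 1, 1, 1, 1]]
)

# Same conversion logic as the snippet above.
first_non_zero_indexes = attention_mask.argmax(dim=1)
assert (attention_mask.sum(dim=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all()
sequence_lengths = [
    torch.tensor(
        [attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64
    )
    for el in first_non_zero_indexes.tolist()
]
print([t.tolist() for t in sequence_lengths])  # [[3, 4], [7]]
```

The padded row splits into a "padding document" of length 3 and a real document of length 4, while the unpadded row stays a single document of length 7.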
@sohamparikh @jlamypoirier Hi, I am trying to use the cross-document attention prevention that @tscholak pointed me to (https://github.com/ServiceNow/Fast-LLM/pull/177/files) to mimic left padding for documents in a batch during generation. It appears to be doing the correct things, such as building the internal mask and position IDs, but it is not working. Could you please comment on what might be wrong? Thanks!
…model and saving
```python
    completed_steps: int,
    consumed_samples: int,
    consumed_tokens: int,
) -> tuple[dict[str, any], str | None]: ...
```
use dataclass
```python
)
end_time = time.perf_counter()
time_per_iteration = (end_time - begin_time) / num_iters
model_tflops, hardware_tflops = self._get_tflops_func(phase, time_per_iteration)
```
move downstream
```python
@classmethod
def build(
    cls,
    name: str,
    eval_config: EvaluationLossConfig,
    trainer_config: TrainerConfig,
    get_tflops_func: callable,
) -> "Evaluation":
    return cls(
        name=name,
        eval_config=eval_config,
        trainer_config=trainer_config,
        get_tflops_func=get_tflops_func,
    )
```
make dataclass fields
```python
self._trainer_config = trainer_config
self._get_tflops_func = get_tflops_func

self._loss_defs = self._multi_stage.base_model.loss_defs
```
use `__post_init__`?
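Taken together, the "make dataclass fields" and `__post_init__` suggestions might look roughly like this. The class and attribute names mirror the snippets above, but the body is illustrative and not the actual Fast-LLM code:

```python
import typing
from dataclasses import dataclass, field

@dataclass
class Evaluation:
    # Constructor arguments become declared fields, so the build() classmethod
    # and hand-written __init__ assignments go away.
    name: str
    eval_config: typing.Any      # EvaluationLossConfig in the real code
    trainer_config: typing.Any   # TrainerConfig in the real code
    get_tflops_func: typing.Callable
    # Derived state, excluded from the generated __init__.
    _loss_defs: list = field(init=False, default_factory=list)

    def __post_init__(self) -> None:
        # Derived setup runs automatically after field assignment,
        # e.g. self._loss_defs = self._multi_stage.base_model.loss_defs.
        self._loss_defs = []
```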
```python
assert not args.wandb_args  # default empty string
assert not args.wandb_config_args  # default empty string
assert args.model == "hf"  # default value of 'hf'
assert not args.model_args  # default empty string
assert args.batch_size == 1  # default value of 1
assert args.max_batch_size is None
assert args.device is None
```
make sure these are raised during config class validation
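One hedged way to move these assertions into config validation, so unsupported settings raise at construction time rather than tripping asserts later (class and field names are invented for illustration):

```python
from dataclasses import dataclass

@dataclass
class LmEvalConfig:
    # Defaults mirror the asserted values above; anything else is rejected.
    model: str = "hf"
    batch_size: int = 1
    wandb_args: str = ""

    def __post_init__(self) -> None:
        if self.model != "hf":
            raise ValueError(f"unsupported model backend {self.model!r}; only 'hf' is allowed")
        if self.wandb_args:
            raise ValueError("wandb_args must be empty; wandb is configured by Fast-LLM")
```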
```python
continue

def setup_parser() -> argparse.ArgumentParser:
```
```python
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# update the evaluation tracker args with the output path and the HF token
if args.output_path:
```
Please clean this up (below): we are not pushing anything to the HF Hub during eval, and the remainder should be controlled by Fast-LLM.
```python
# utils.setup_logging(args.verbosity)
# eval_logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
```
doesn't apply
Can we please break down this PR? Otherwise it will make reviewing too difficult. Let's keep this one about the minimalistic `generate`.
Sure, eventually we can do that. @bigximik is currently iterating towards an end-to-end solution for running benchmarks, and he's solving issues as they arise. It makes sense for him to operate that way for the time being, but when the time comes to review the changes, we should separate the concerns.
@jlamypoirier, btw, we need your guidance in determining the best way to distribute generation across ranks.
✨ Description
Part of #217
Closes #
Type of change
Select all that apply:
Changes
List the key changes introduced in this PR:
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.