Implementation of generate #222


Draft · wants to merge 18 commits into base: main

Conversation

@bigximik (Contributor) commented on Apr 3, 2025

✨ Description

Part of #217.

Closes #

🔍 Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

πŸ“ Changes

List the key changes introduced in this PR:

  1. Change A
  2. Change B

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • πŸ‹ I have updated the Docker configuration or dependencies, if applicable.
  • πŸ”„ I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


πŸ—’οΈ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@bigximik (Contributor, Author) commented on Apr 3, 2025

I have created a debugging sandbox with manual tests for now. The results are as follows:

Ignoring attention_mask and position_ids:

| Batch Size | No Flash Attention (Float32) | No Flash Attention (BF16) | Flash Attention (BF16) |
|---|---|---|---|
| 1 | Same output (same model via HF and Fast-LLM) | Same output | Different output |
| 2 | Different output | Different output | Different output |

Converting attention_mask (from HF forward) to sequence_lengths:

| Batch Size | No Flash Attention (Float32) | No Flash Attention (BF16) | Flash Attention (BF16) |
|---|---|---|---|
| 1 | Fast-LLM empty output | Fast-LLM empty output | Different output |
| 2 | Fast-LLM empty output | Fast-LLM empty output | Different output |

It seems sequence_lengths is not supported for fused attention and does not help with Flash Attention either. Could that be correct?

If attention_mask is a left-padded mask like
[[0, 0, 0, 1, 1, 1, 1], ...],
I convert it to sequence_lengths = [[3, 4], ...].

import torch

# Index of the first non-zero entry in each row; argmax returns 0 for an all-zero (invalid) row.
first_non_zero_indexes = attention_mask.argmax(dim=1)

# Check that each sequence is left-padded, i.e. everything after the first 1 is a contiguous run of 1s.
assert (attention_mask.sum(dim=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all()

sequence_lengths = [
    torch.tensor(
        [attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64
    )
    for el in first_non_zero_indexes.tolist()
]
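
For reference, a minimal standalone check of this conversion on the example mask above (I added an unpadded second row for contrast; the variable names simply mirror the snippet):

import torch

attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]])
first_non_zero_indexes = attention_mask.argmax(dim=1)  # tensor([3, 0])
assert (attention_mask.sum(dim=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all()
sequence_lengths = [
    torch.tensor([attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64)
    for el in first_non_zero_indexes.tolist()
]
print(sequence_lengths)  # [tensor([3, 4]), tensor([7])]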

@bigximik (Contributor, Author) commented on Apr 3, 2025

@sohamparikh @jlamypoirier Hi, I am trying to use the cross-document attention prevention that @tscholak pointed me to (https://github.com/ServiceNow/Fast-LLM/pull/177/files) to mimic left padding for documents in a batch during generation. It appears to do the right things internally, such as building the internal mask and position IDs, but the outputs still do not match (see the tables above). Could you please comment on what might be wrong? Thanks!

completed_steps: int,
consumed_samples: int,
consumed_tokens: int,
) -> tuple[dict[str, any], str | None]: ...
Collaborator: use dataclass

)
end_time = time.perf_counter()
time_per_iteration = (end_time - begin_time) / num_iters
model_tflops, hardware_tflops = self._get_tflops_func(phase, time_per_iteration)
Collaborator: move downstream

Comment on lines +44 to +57
@classmethod
def build(
    cls,
    name: str,
    eval_config: EvaluationLossConfig,
    trainer_config: TrainerConfig,
    get_tflops_func: callable,
) -> "Evaluation":
    return cls(
        name=name,
        eval_config=eval_config,
        trainer_config=trainer_config,
        get_tflops_func=get_tflops_func,
    )
Collaborator: make dataclass fields

self._trainer_config = trainer_config
self._get_tflops_func = get_tflops_func

self._loss_defs = self._multi_stage.base_model.loss_defs
@tscholak (Collaborator), Apr 22, 2025: use __post_init__?
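
Taken together, the suggestions above (use a dataclass, turn the build() arguments into dataclass fields, derive the rest in __post_init__) might look roughly like the sketch below. Only the field names and types come from the diff; the placeholder config classes and the derived attribute are hypothetical, so treat this as an illustration rather than the actual implementation.

import dataclasses
import typing


@dataclasses.dataclass
class EvaluationLossConfig:
    # Hypothetical stand-in for the real config class referenced in the diff.
    iterations: int = 10


@dataclasses.dataclass
class TrainerConfig:
    # Hypothetical stand-in for the real trainer config.
    train_iters: int = 1000


@dataclasses.dataclass(kw_only=True)
class Evaluation:
    # The build() classmethod and the hand-written __init__ plumbing collapse into fields.
    name: str
    eval_config: EvaluationLossConfig
    trainer_config: TrainerConfig
    get_tflops_func: typing.Callable

    def __post_init__(self) -> None:
        # Derived state that __init__ currently sets by hand would move here.
        self._time_per_iteration: float | None = None


evaluation = Evaluation(
    name="loss",
    eval_config=EvaluationLossConfig(),
    trainer_config=TrainerConfig(),
    get_tflops_func=lambda phase, time_per_iteration: (0.0, 0.0),
)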

Comment on lines +335 to +341
assert not args.wandb_args # default empty string
assert not args.wandb_config_args # default empty string
assert args.model == "hf" # default value of 'hf'
assert not args.model_args # default empty string
assert args.batch_size == 1 # default value of 1
assert args.max_batch_size is None
assert args.device is None
Collaborator: make sure these are raised during config class validation
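
For illustration only, one way such checks could be raised during config validation instead of asserted at the call site; the LmEvalArgs class and its fields below are hypothetical stand-ins, not code from this PR:

import dataclasses


@dataclasses.dataclass
class LmEvalArgs:
    # Hypothetical subset of the arguments being asserted above.
    wandb_args: str = ""
    wandb_config_args: str = ""
    model: str = "hf"
    model_args: str = ""
    batch_size: int = 1
    max_batch_size: int | None = None
    device: str | None = None

    def validate(self) -> None:
        # Raise explicit errors during config validation rather than relying on bare asserts.
        if self.wandb_args or self.wandb_config_args:
            raise ValueError("wandb arguments are not supported and must stay empty.")
        if self.model != "hf" or self.model_args:
            raise ValueError("Only the default 'hf' model wrapper without extra model_args is supported.")
        if self.batch_size != 1 or self.max_batch_size is not None:
            raise ValueError("batch_size must stay at its default of 1 and max_batch_size must be unset.")
        if self.device is not None:
            raise ValueError("device must not be set explicitly.")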

continue


def setup_parser() -> argparse.ArgumentParser:

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# update the evaluation tracker args with the output path and the HF token
if args.output_path:
@tscholak (Collaborator), Apr 22, 2025: Please clean this up; we are not pushing anything to the HF hub during eval, and the remainder should be controlled by Fast-LLM.

# utils.setup_logging(args.verbosity)
# eval_logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Collaborator: doesn't apply

@jlamypoirier (Collaborator) commented:
Can we please break down this PR? Otherwise it will make reviewing too difficult. Let's keep this one about the minimalistic generate, and move the rest to the next PR

@tscholak (Collaborator) commented:

> Can we please break down this PR? Otherwise it will make reviewing too difficult. Let's keep this one about the minimalistic generate, and move the rest to the next PR

Sure, eventually we can do that. @bigximik is currently iterating towards an end-to-end solution for running benchmarks, and he's solving issues as they arise. It makes sense for him to operate that way for the time being, but when the time comes to review the changes, we should separate the concerns.

@tscholak (Collaborator) commented:
@jlamypoirier, btw, we need your guidance in determining the best way to distribute generation across ranks.
Concretely, we are looking to implement this lm-eval-harness API:

    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
            context: str
                Context string
            gen_kwargs: dict
                A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
        :return: list[str]
            A list of model generated continuations.
            continuation: str
                The generated continuation.
        """
        pass

where generate_until(requests: list[Instance], ...) is called from rank 0 and should distribute the Instances across the ranks, each of which calls the Fast-LLM model's generate(inputs: torch.Tensor, ...). An Instance is essentially a prompt plus metadata; see https://github.com/EleutherAI/lm-evaluation-harness/blob/e4a7b69fe0fc6cb430e12cf15c4109bf28185124/lm_eval/api/instance.py#L11.
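
Not a proposal for the final design, just a rough sketch of one possible split: rank 0 scatters the request contexts with torch.distributed object collectives, every rank runs generation on its shard, and rank 0 gathers the continuations back in the original order. The model.generate call, the tokenizer, and all names here are hypothetical placeholders rather than Fast-LLM's actual API.

import torch
import torch.distributed as dist


def generate_until_distributed(contexts: list[str] | None, model, tokenizer, **gen_kwargs) -> list[str] | None:
    """Rank 0 passes the list of context strings; the other ranks pass None."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Rank 0 splits the requests into one shard per rank; scatter_object_list sends each shard to its rank.
    shards = None
    if rank == 0:
        shards = [contexts[i::world_size] for i in range(world_size)]
    local_shard = [None]
    dist.scatter_object_list(local_shard, shards, src=0)

    # Each rank generates for its own shard (model.generate / tokenizer are stand-ins, not the real API).
    local_outputs = []
    for context in local_shard[0]:
        inputs = tokenizer(context, return_tensors="pt").input_ids
        output_ids = model.generate(inputs, **gen_kwargs)
        local_outputs.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    # Gather all shards back on rank 0 and restore the original request order.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_outputs, gathered, dst=0)
    if rank != 0:
        return None
    results = [None] * len(contexts)
    for shard_rank, outputs in enumerate(gathered):
        for j, text in enumerate(outputs):
            results[shard_rank + j * world_size] = text
    return results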
