Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic FP8 Training Support #992

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

hanzhi713
Copy link
Member

This PR adds FP8 training support using either in-batch scaling or delayed scaling. There's a one scaling factor per tensor.

@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners February 13, 2025 23:42
Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

Comment on lines +317 to +330
update_cfg: OverrideInplaceUpdateTransformation.Config = (
OverrideInplaceUpdateTransformation.default_config()
)
update_cfg.rules = [
f".*/{x}"
for x in [
"input_scale",
"kernel_scale",
"output_grad_scale",
"input_amax_history",
"kernel_amax_history",
"output_grad_amax_history",
]
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use update_rules? See

("state_updates_v", [(".*v", UpdateType.STATE_UPDATES)]),
.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. There is a tracer problem when plumbing the FP8 stats through as state updates, because they are produced in the backward pass of a custom vjp. To circumvent this problem, we produce them as gradient and treat them as state updates.

See the internal PR 806 and 816 for context and discussions.

secondary: Nested[Any],
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
) -> Nested[Any]:
"""Merge `secondary` into `primary`. The result contains shallow copies of subtrees from both.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making them deep copies to be consistent with the jax tree utils. Deep copies are also less error prone by decoupling side effects.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if k in primary:
out_tree[k] = tree_merge(primary[k], secondary=secondary[k], override_fn=override_fn)
else:
out_tree[k] = secondary[k]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out_tree[k] = secondary[k]
out_tree[k] = copy.deepcopy(secondary[k])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if isinstance(primary, dict) ^ isinstance(secondary, dict):
raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
if not (isinstance(primary, dict) or isinstance(secondary, dict)):
return override_fn(primary, secondary)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return override_fn(primary, secondary)
return copy.deepcopy(override_fn(primary, secondary))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

primary: Nested[Any],
*,
secondary: Nested[Any],
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
leaf_merge_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

primary: Nested[Any],
*,
secondary: Nested[Any],
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using a simpler function (always choose primary) as the default merge function?

Suggested change
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
override_fn: Callable[[Any, Any], Any] = lambda x, y: x,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current default makes more sense. Always choose primary will discard subtree of secondary even when the primary tree has an empty leaf or None at the corresponding position.

@hanzhi713 hanzhi713 requested a review from ruomingp February 27, 2025 00:17
@hanzhi713
Copy link
Member Author

@ruomingp Could you please take a look again?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants