Basic FP8 Training Support #992
base: main
Thanks.
update_cfg: OverrideInplaceUpdateTransformation.Config = (
    OverrideInplaceUpdateTransformation.default_config()
)
update_cfg.rules = [
    f".*/{x}"
    for x in [
        "input_scale",
        "kernel_scale",
        "output_grad_scale",
        "input_amax_history",
        "kernel_amax_history",
        "output_grad_amax_history",
    ]
]
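To illustrate how rules like these could select the FP8 statistics, here is a minimal sketch. It assumes each rule is matched against the full slash-separated parameter path (full-match semantics); the example paths and the `matches_any_rule` helper are hypothetical and not taken from axlearn.

```python
import re

FP8_STAT_NAMES = [
    "input_scale",
    "kernel_scale",
    "output_grad_scale",
    "input_amax_history",
    "kernel_amax_history",
    "output_grad_amax_history",
]
# Same rule construction as in the snippet above.
rules = [f".*/{x}" for x in FP8_STAT_NAMES]


def matches_any_rule(path: str) -> bool:
    """Hypothetical helper: True if any rule matches the whole path."""
    return any(re.fullmatch(rule, path) is not None for rule in rules)


# Made-up parameter paths, for illustration only.
example_paths = [
    "decoder/layer0/linear/kernel_scale",
    "decoder/layer0/linear/output_grad_amax_history",
    "decoder/layer0/linear/weight",
]
overridden = [p for p in example_paths if matches_any_rule(p)]
```

Under these assumptions only the two FP8 statistics end up in `overridden`; the regular `weight` parameter is left to the normal optimizer update.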
Can we use update_rules? See
axlearn/axlearn/common/learner_test.py
Line 343 in daec8c5
("state_updates_v", [(".*v", UpdateType.STATE_UPDATES)]),
No. Plumbing the FP8 stats through as state updates runs into a tracer problem, because they are produced in the backward pass of a custom VJP. To work around it, we emit them as gradients and then treat them as state updates.
See internal PRs 806 and 816 for context and discussion.
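The workaround described in this reply can be sketched in JAX as follows. This is a toy illustration, not the PR's implementation: `track_grad_amax`, the parameter names, and the manual update step are all hypothetical. The statistic is passed in as an ordinary differentiable input, and the backward rule returns its new value in the slot where that input's cotangent would normally go, so the caller has to overwrite the stored value instead of applying a gradient step.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def track_grad_amax(x, grad_amax):
    # The forward pass is the identity; `grad_amax` is only carried along.
    del grad_amax
    return x


def _fwd(x, grad_amax):
    del grad_amax
    return x, None  # No residuals are needed for the backward pass.


def _bwd(residuals, g):
    del residuals
    # Pass the cotangent for `x` through unchanged and, in the slot for
    # `grad_amax`, return the new statistic (NOT a true gradient).
    return g, jnp.max(jnp.abs(g))


track_grad_amax.defvjp(_fwd, _bwd)


def loss_fn(params, x):
    y = track_grad_amax(params["w"] * x, params["output_grad_amax"])
    return jnp.sum(y**2)


params = {"w": jnp.float32(2.0), "output_grad_amax": jnp.float32(0.0)}
grads = jax.grad(loss_fn)(params, jnp.array([1.0, -3.0]))

# "w" takes a regular gradient step; the statistic is overridden in place.
new_params = {
    "w": params["w"] - 0.1 * grads["w"],
    "output_grad_amax": grads["output_grad_amax"],
}
```

Because the value only exists inside the backward pass, returning it through the gradient tree avoids leaking a tracer out of the custom VJP. The price is that the optimizer must know which "gradients" are really replacement values, which is what the override rules are for.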
axlearn/common/utils.py
Outdated
    secondary: Nested[Any],
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
) -> Nested[Any]:
    """Merge `secondary` into `primary`. The result contains shallow copies of subtrees from both.
Consider making them deep copies to be consistent with the jax tree utils. Deep copies are also less error prone by decoupling side effects.
Done.
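A small, self-contained illustration of the side effect the reviewer is concerned about, using plain dicts. The tree contents here are made up for the example.

```python
import copy

secondary = {"fp8": {"input_amax_history": [1.0]}}

# A shallow copy shares the nested dict and list with `secondary`.
shallow = dict(secondary)
shallow["fp8"]["input_amax_history"].append(2.0)
after_shallow_edit = list(secondary["fp8"]["input_amax_history"])

# A deep copy owns its own nested containers, so edits stay local.
deep = copy.deepcopy(secondary)
deep["fp8"]["input_amax_history"].append(3.0)
after_deep_edit = list(secondary["fp8"]["input_amax_history"])
```

Editing the shallow copy changes `secondary` as well, while editing the deep copy does not.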
axlearn/common/utils.py
Outdated
if k in primary:
    out_tree[k] = tree_merge(primary[k], secondary=secondary[k], override_fn=override_fn)
else:
    out_tree[k] = secondary[k]
- out_tree[k] = secondary[k]
+ out_tree[k] = copy.deepcopy(secondary[k])
Done.
axlearn/common/utils.py
Outdated
if isinstance(primary, dict) ^ isinstance(secondary, dict):
    raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
if not (isinstance(primary, dict) or isinstance(secondary, dict)):
    return override_fn(primary, secondary)
- return override_fn(primary, secondary)
+ return copy.deepcopy(override_fn(primary, secondary))
Done.
axlearn/common/utils.py
Outdated
    primary: Nested[Any],
    *,
    secondary: Nested[Any],
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
- override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
+ leaf_merge_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
Done.
axlearn/common/utils.py
Outdated
    primary: Nested[Any],
    *,
    secondary: Nested[Any],
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
How about using a simpler function (always choose primary) as the default merge function?
- override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
+ override_fn: Callable[[Any, Any], Any] = lambda x, y: x,
I think the current default makes more sense. Always choosing primary would discard the subtree of secondary even when the primary tree has an empty leaf or None at the corresponding position.
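The trade-off in this exchange can be made concrete with a simplified sketch of the merge, assembled from the snippets quoted in this review. The body of `tree_merge_default_override_fn` is not shown in the thread, so `prefer_primary_unless_empty` below is a hypothetical stand-in that falls back to `secondary` when `primary` is None or empty; the real default may differ.

```python
import copy
from typing import Any, Callable


def prefer_primary_unless_empty(primary: Any, secondary: Any) -> Any:
    """Hypothetical default: keep primary unless it is None or empty."""
    is_empty = hasattr(primary, "__len__") and len(primary) == 0
    return secondary if primary is None or is_empty else primary


def tree_merge(
    primary: Any,
    *,
    secondary: Any,
    leaf_merge_fn: Callable[[Any, Any], Any] = prefer_primary_unless_empty,
) -> Any:
    """Simplified sketch: merge `secondary` into `primary`, returning deep copies."""
    if isinstance(primary, dict) ^ isinstance(secondary, dict):
        raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
    if not (isinstance(primary, dict) or isinstance(secondary, dict)):
        return copy.deepcopy(leaf_merge_fn(primary, secondary))
    out_tree = copy.deepcopy(primary)
    for k in secondary:
        if k in primary:
            out_tree[k] = tree_merge(primary[k], secondary=secondary[k], leaf_merge_fn=leaf_merge_fn)
        else:
            out_tree[k] = copy.deepcopy(secondary[k])
    return out_tree


primary = {"a": None, "b": 1}
secondary = {"a": 5, "c": {"d": 2}}

merged_default = tree_merge(primary, secondary=secondary)
merged_primary_only = tree_merge(primary, secondary=secondary, leaf_merge_fn=lambda x, y: x)
```

With the stand-in default, the None at `"a"` is filled from `secondary`; with `lambda x, y: x` it stays None and the value 5 is dropped, which is the behavior the author is arguing against.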
@ruomingp Could you please take a look again?
This PR adds FP8 training support using either in-batch scaling or delayed scaling. There is one scaling factor per tensor.
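As a rough sketch of the difference between the two modes, the per-tensor scale can be derived either from the current tensor or from a history of past maxima. This is illustrative only and not the PR's code: the helper names are hypothetical, the convention `quantized = x / scale` is an assumption, and 448 is the largest finite value of the E4M3 format.

```python
E4M3_MAX = 448.0  # Largest finite value representable in FP8 E4M3.


def scale_from_amax(amax: float) -> float:
    """One scale per tensor, chosen so that amax / scale == E4M3_MAX."""
    return amax / E4M3_MAX if amax > 0 else 1.0


def in_batch_scale(tensor: list[float]) -> float:
    # In-batch scaling: use the absolute maximum of the current tensor.
    return scale_from_amax(max(abs(v) for v in tensor))


def delayed_scale(amax_history: list[float]) -> float:
    # Delayed scaling: use the maximum over amax values from earlier steps.
    return scale_from_amax(max(amax_history))


def update_history(amax_history: list[float], tensor: list[float]) -> list[float]:
    # Push the newest amax to the front and drop the oldest entry.
    return [max(abs(v) for v in tensor)] + amax_history[:-1]


x = [0.5, -896.0, 3.0]
history = [448.0, 224.0, 0.0]

current_step_scale = in_batch_scale(x)
stale_scale = delayed_scale(history)
new_history = update_history(history, x)
```

In this toy case the delayed scale lags behind the outlier in `x` until the history has been updated, which is the usual trade-off for not having to reduce over the tensor before quantizing it.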