
Conversation

@neevparikh

Adds AdamW 8Bit from TorchAO.
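
For context, a minimal sketch of how torchao's 8-bit AdamW is typically constructed as a drop-in replacement for `torch.optim.AdamW` (the model and hyperparameters below are placeholders, not this repo's actual config):

```python
# Sketch only: assumes a CUDA device and a recent torchao where the
# low-bit optimizers live under torchao.optim.
import torch
from torchao.optim import AdamW8bit

model = torch.nn.Linear(1024, 1024).cuda()

# Drop-in replacement for torch.optim.AdamW; the optimizer states
# (exp_avg / exp_avg_sq) are stored as 8-bit tensor subclasses.
optimizer = AdamW8bit(model.parameters(), lr=2e-4, betas=(0.9, 0.999), eps=1e-8)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```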

@neevparikh
Copy link
Author

Unfortunately, I think this has some issues with SFT checkpointing, but I lack the bandwidth to dig into why. Leaving this draft PR up in case I get more time or someone else wants to take a stab.

  File "/home/neev/host-dir/prime-rl/src/prime_rl/utils/tensor_hashing.py", line 49, in get_optimizer_signature
    state_dict_sig = unwrap_tensor(optimizer.state_dict())
                     │             │         └ <function Optimizer.state_dict at 0x796c23f2bd80>
                     │             └ AdamW8bit (
                     │               Parameter Group 0
                     │                   amsgrad: False
                     │                   betas: (0.9, 0.999)
                     │                   eps: 1e-08
                     │                   initial_lr: 1.9999999494757503e-0...
                     └ <function get_optimizer_signature.<locals>.unwrap_tensor at 0x7968741f28e0>

  File "/home/neev/host-dir/prime-rl/src/prime_rl/utils/tensor_hashing.py", line 42, in unwrap_tensor
    new_dict[key] = unwrap_tensor(value)
    │        │      │             └ {0: {'step': tensor(10.), 'exp_avg': DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(151936, 1024), d...
    │        │      └ <function get_optimizer_signature.<locals>.unwrap_tensor at 0x7968741f28e0>
    │        └ 'state'
    └ {}

  File "/home/neev/host-dir/prime-rl/src/prime_rl/utils/tensor_hashing.py", line 42, in unwrap_tensor
    new_dict[key] = unwrap_tensor(value)
    │        │      │             └ {'step': tensor(10.), 'exp_avg': DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(151936, 1024), devic...
    │        │      └ <function get_optimizer_signature.<locals>.unwrap_tensor at 0x7968741f28e0>
    │        └ 0
    └ {}

  File "/home/neev/host-dir/prime-rl/src/prime_rl/utils/tensor_hashing.py", line 44, in unwrap_tensor
    new_dict[key] = get_tensor_signature(value)
    │        │      │                    └ DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(151936, 1024), device=cuda:0, requires_grad=False), d...
    │        │      └ <function get_tensor_signature at 0x796b5b7723e0>
    │        └ 'exp_avg'
    └ {'step': 'torch.float32torch.Size([])()<0eed1ad063119fbe58c00b34dfa2959a>'}

  File "/home/neev/host-dir/prime-rl/src/prime_rl/utils/tensor_hashing.py", line 24, in get_tensor_signature
    b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,))
        │ │                │                                 └ 155582
        │ │                └ 1000
        │ └ <method 'as_strided' of 'torch._C.TensorBase' objects>
        └ OptimState8bit(signed=True, block_size=256, shape=(151936, 1024), device=cuda:0, requires_grad=False)

  File "/home/neev/host-dir/prime-rl/.venv/lib/python3.12/site-packages/torchao/utils.py", line 638, in _dispatch__torch_function__
    return func(*args, **kwargs)
           │     │       └ {'size': (1000,), 'stride': (155582,)}
           │     └ (OptimState8bit(signed=True, block_size=256, shape=(151936, 1024), device=cuda:0, requires_grad=False),)
           └ <method 'as_strided' of 'torch._C.TensorBase' objects>
  File "/home/neev/host-dir/prime-rl/.venv/lib/python3.12/site-packages/torchao/utils.py", line 658, in _dispatch__torch_dispatch__
    raise NotImplementedError(

NotImplementedError: OptimState8bit dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.as_strided', overload='default')>, types=(<class 'torchao.optim.subclass_8bit.OptimState8bit'>,), arg_types=(<class 'torchao.optim.subclass_8bit.OptimState8bit'>, <class 'list'>, <class 'list'>), kwarg_types={}
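
The failure is in the tensor-hashing helper, not the optimizer itself: `get_tensor_signature` samples a tensor via `aten.as_strided`, which torchao's `OptimState8bit` subclass does not implement. One possible (untested) direction is to materialize wrapper subclasses into plain tensors before sampling. The helper below is hypothetical, not part of this PR; it assumes `DTensor.to_local()` and torchao's `dequantize()` behave as in their public APIs:

```python
# Untested sketch of a normalization step for get_tensor_signature.
import torch
from torch.distributed.tensor import DTensor

def to_plain_tensor(a: torch.Tensor) -> torch.Tensor:
    # Unwrap a DTensor to its local shard on this rank.
    if isinstance(a, DTensor):
        a = a.to_local()
    # Materialize quantized optim-state subclasses (e.g. torchao's
    # OptimState8bit) into a regular tensor; assumes the subclass
    # exposes a dequantize() method, as torchao's optim states do.
    if type(a) is not torch.Tensor and hasattr(a, "dequantize"):
        a = a.dequantize()
    return a
```

`get_tensor_signature` could then call `a = to_plain_tensor(a)` before the `as_strided` sampling, at the cost of dequantizing the state just to hash it.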

@neevparikh
Author

neevparikh commented Oct 8, 2025

It does work though:
RL: [training-run screenshot]
SFT: [training-run screenshot]

@samsja
Member

samsja commented Oct 13, 2025

Hey, this is a really useful PR, thank you for the work.

Really weird that the DCP checkpointing is failing; I will look into it when I have more bandwidth.

@neevparikh
Author

Glad to hear it! I found it very useful for reducing memory requirements on this project: https://x.com/neev_parikh/status/1967767438243876924
