Skip to content

Bug: Jax 0.4.13 support #279

@arnon-1

Description

@arnon-1

Hello,

While trying to use kfac_jax in jax 0.4.13 (which is supported if I am not mistaken), I had to fix some errors.
I installed commit a4531e9 which is fairly recent.

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/utils/types.py", line 27, in <module>
    DType = jax.typing.DTypeLike
AttributeError: module 'jax.typing' has no attribute 'DTypeLike'

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/curvature_blocks/curvature_block.py", line 123, in parameters_shapes
    return tuple(jax.tree.map(
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tracer.py", line 790, in forward
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tag_graph_matcher.py", line 68, in eval_jaxpr_eqn
    user_context = jax_extend.source_info_util.user_context
AttributeError: module 'jax.extend' has no attribute 'source_info_util'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions