@@ -139,7 +139,7 @@ def __init__(self, rngs: nnx.Rngs):
           4,
           kernel_init=nnx.with_metadata(
             nnx.initializers.lecun_normal(),
-            sharding_names=('din', 'dout'),
+            sharding_metadata=('din', 'dout'),
             nickname=('in', 'out'),
             on_add_axis=lambda _, idx, name: kadds.append((idx, name)),
             on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)),
@@ -160,7 +160,7 @@ def __call__(self, x: jax.Array):
         x = self.linear(x)
         # test sharding layer axes is not present inside scan
         test.assertEqual(self.linear.kernel.shape, (4, 4))
-        test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout'))
+        test.assertEqual(self.linear.kernel.sharding_metadata, ('din', 'dout'))
         # at least a remove_axis was already called to remove the layer axis
         test.assertEqual(kremoves[-1], (0, 'layers'))
         test.assertEqual(bremoves[-1], (0, 'layers'))
@@ -175,7 +175,7 @@ def __call__(self, x: jax.Array):
     with jax.set_mesh(mesh):
       m = MLP(rngs=nnx.Rngs(0))
       self.assertEqual(m.linear.kernel.shape, (5, 4, 4))
-      self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout'))
+      self.assertEqual(m.linear.kernel.sharding_metadata, ('layers', 'din', 'dout'))
       self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out'))
       self.assertEqual(m.linear.bias.shape, (5, 4))
       # One add_axis called to add the `nnx.vmap` dimension
@@ -201,7 +201,7 @@ def test_eager_sharding_context(self, use_eager_sharding):
     with jax.set_mesh(mesh):
       w = nnx.Param(
         rngs.lecun_normal()((4, 8)),
-        sharding_names=(None, 'model'))
+        sharding_metadata=(None, 'model'))
       if use_eager_sharding:
         assert has_sharding_spec(w)
       else:
@@ -273,7 +273,7 @@ def test_explicit_sharding(self):
     )
     v = nnx.Variable(
       jnp.ones((4, 4)),
-      sharding_names=('row', 'col'),
+      sharding_metadata=('row', 'col'),
       mesh=mesh,
     )
     self.assertEqual(v.sharding.mesh, mesh)
@@ -291,7 +291,7 @@ def test_explicit_sharding_disable_jit(self):
     with jax.disable_jit(True):
       v = nnx.Variable(
         jnp.ones((4, 4)),
-        sharding_names=('row', 'col'),
+        sharding_metadata=('row', 'col'),
         mesh=mesh,
       )
       self.assertEqual(v.sharding.mesh, mesh)
@@ -309,7 +309,7 @@ def test_explicit_sharding_mesh_context(self):
     with jax.set_mesh(mesh):
       v = nnx.Variable(
         jnp.ones((4, 4)),
-        sharding_names=('row', 'col'),
+        sharding_metadata=('row', 'col'),
       )
       self.assertEqual(v.sharding.mesh, mesh)
       self.assertEqual(
0 commit comments