@@ -161,6 +161,33 @@ def instruction_pointers_to_images(instruction_pointer, multidevice: bool):
161161 return jnp .array (instruction_pointer_image_list )
162162
163163
164+ def instruction_pointers_to_entropy (instruction_pointer , multidevice : bool ):
165+ """Converts the given batched instruction pointer to an entropy value.
166+
167+ The entropy value measures the sharpness of the instruction pointer, i.e. how
168+ hard vs soft it is.
169+ """
170+ if multidevice :
171+ # instruction_pointer: device, batch_size / device, timesteps, num_nodes
172+ instruction_pointer = instruction_pointer [0 ]
173+
174+ # instruction_pointer: batch_size / device, timesteps, num_nodes
175+ instruction_pointer = jnp .transpose (instruction_pointer [:, :16 , :],
176+ (1 , 2 , 0 ))
177+ # instruction_pointer: logging_slice_size, num_nodes, timesteps
178+ instruction_pointer_image_list = [
179+ instruction_pointer_to_image (ip )
180+ for ip in instruction_pointer
181+ ]
182+ instruction_pointer_image_leading_dim_max = max (
183+ image .shape [0 ] for image in instruction_pointer_image_list )
184+ instruction_pointer_image_list = [
185+ pad (image , instruction_pointer_image_leading_dim_max )
186+ for image in instruction_pointer_image_list
187+ ]
188+ return jnp .array (instruction_pointer_image_list )
189+
190+
164191def pad (array , leading_dim_size : int ):
165192 """Pad the leading dimension of the given array."""
166193 leading_dim_difference = max (0 , leading_dim_size - array .shape [0 ])
0 commit comments