Skip to content
This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit 5efeb51

Browse files
committed
Add instruction pointer entropy metric.
Work in progress.
1 parent 71deb99 commit 5efeb51

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

core/lib/metrics.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,33 @@ def instruction_pointers_to_images(instruction_pointer, multidevice: bool):
161161
return jnp.array(instruction_pointer_image_list)
162162

163163

164+
def instruction_pointers_to_entropy(instruction_pointer, multidevice: bool):
  """Converts the given batched instruction pointer to an entropy value.

  The entropy value measures the sharpness of the instruction pointer, i.e. how
  hard vs soft it is. A one-hot ("hard") pointer has entropy 0; a pointer spread
  uniformly over n nodes ("soft") has entropy log(n).

  Args:
    instruction_pointer: Array of instruction pointer weights. The trailing axis
      is treated as the distribution over nodes (assumed to sum to 1 -- TODO
      confirm against the model output). When `multidevice` is True there is an
      additional leading device axis.
    multidevice: Whether `instruction_pointer` carries a leading device axis.
      If so, only the first device's slice is used.

  Returns:
    A scalar array: the Shannon entropy (in nats) over the trailing axis,
    averaged over all remaining axes. Suitable for a scalar summary writer.
  """
  if multidevice:
    # Drop the leading device axis by keeping only the first device's shard.
    instruction_pointer = instruction_pointer[0]

  probabilities = jnp.asarray(instruction_pointer)
  # Use the convention 0 * log(0) == 0: swap zero entries for 1 before the log
  # so they contribute exactly 0 rather than NaN.
  safe_probabilities = jnp.where(probabilities > 0, probabilities, 1.0)
  entropy = -jnp.sum(probabilities * jnp.log(safe_probabilities), axis=-1)
  # Reduce to one number so the metric can be logged as a scalar.
  return jnp.mean(entropy)
189+
190+
164191
def pad(array, leading_dim_size: int):
165192
"""Pad the leading dimension of the given array."""
166193
leading_dim_difference = max(0, leading_dim_size - array.shape[0])

core/lib/trainer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ def run_train(self, dataset_path=DEFAULT_DATASET_PATH, split='train', steps=None
378378
transform_fn=functools.partial(
379379
metrics.instruction_pointers_to_images,
380380
multidevice=config.multidevice))
381+
metrics.write_metric(
382+
EvaluationMetric.INSTRUCTION_POINTER_ENTROPY.value,
383+
aux,
384+
train_writer.scalar,
385+
step,
386+
transform_fn=functools.partial(
387+
metrics.instruction_pointers_to_entropy,
388+
multidevice=config.multidevice))
381389

382390
# Write validation metrics.
383391
valid_writer.scalar('loss', valid_loss, step)

0 commit comments

Comments
 (0)