
Commit ed64264

Update nestedtensor to_padded calls (#2036)
1 parent ce59112 commit ed64264

1 file changed (+6, -6 lines)

prototype_source/nestedtensor.py (+6, -6)
@@ -39,7 +39,7 @@
 ######################################################################
 # By padding every underlying tensor to the same shape,
 # a nested tensor can be converted to a regular tensor.
-pt = nt.to_padded_tensor(0.0)
+pt = torch.nested.to_padded_tensor(nt, padding=0.0)
 print(pt)
 
 ######################################################################
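For reference, a minimal sketch of the new call pattern introduced above (assuming a recent PyTorch where the torch.nested namespace is available; the shapes are illustrative, not taken from the tutorial):

    import torch

    # Two entries of different lengths with the same embedding dim (illustrative).
    nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(5, 4)])

    # New-style call: free function in torch.nested instead of the old method form.
    pt = torch.nested.to_padded_tensor(nt, padding=0.0)
    print(pt.shape)  # torch.Size([2, 5, 4]), padded to the longest entry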
@@ -400,9 +400,9 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 value = torch.nested_tensor(values)
 
 # pad input
-padded_query = query.to_padded_tensor(0.0, (N, L_t, E_q))
-padded_key   = key.to_padded_tensor(0.0, (N, L_s, E_k))
-padded_value = value.to_padded_tensor(0.0, (N, L_s, E_v))
+padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
+padded_key   = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
+padded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))
 
 # create attention masks
 attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
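As a hedged sketch of the fixed-shape variant used for the query/key/value padding above, passing an explicit output size as the third argument (N, L_t, and E_q here are illustrative stand-ins for the tutorial's dimensions):

    import torch

    N, L_t, E_q = 2, 6, 4  # illustrative batch size, target length, embedding dim
    query = torch.nested.nested_tensor([torch.randn(3, E_q), torch.randn(5, E_q)])

    # Pad every entry with zeros up to the fixed shape (N, L_t, E_q).
    padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
    print(padded_query.shape)  # torch.Size([2, 6, 4])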
@@ -436,7 +436,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 dropout_p=dropout_p)
 t2 = timeit.default_timer()
 
-print("nested and padded calculations differ by", (out_nested.to_padded_tensor(0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
+print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
 print("nested tensor multi-head attention takes", t1 - t0, "seconds")
 print("padded tensor multi-head attention takes", t2 - t1, "seconds")
 
@@ -486,7 +486,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 dropout_p=dropout_p)
 t3 = timeit.default_timer()
 
-print("nested general and library calculations differ by", (out_nested.to_padded_tensor(0.0) - out_lib.to_padded_tensor(0.0)).abs().max().item())
+print("nested general and library calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0) - torch.nested.to_padded_tensor(out_lib, 0.0)).abs().max().item())
 print("nested library multi-head attention takes", t1 - t0, "seconds")
 print("nested general multi-head attention takes", t2 - t1, "seconds")
 print("padded tensor multi-head attention takes", t3 - t2, "seconds")
