 ######################################################################
 # By padding every underlying tensor to the same shape,
 # a nested tensor can be converted to a regular tensor.
-pt = nt.to_padded_tensor(0.0)
+pt = torch.nested.to_padded_tensor(nt, padding=0.0)
 print(pt)

 ######################################################################
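For reference, a minimal sketch of the new namespaced call next to the removed method-style call (the tensors below are illustrative, not taken from the tutorial):

    import torch

    nt = torch.nested.nested_tensor([torch.arange(3.0), torch.arange(5.0)])

    # previous form used in the tutorial: pt = nt.to_padded_tensor(0.0)
    # updated, namespaced form:
    pt = torch.nested.to_padded_tensor(nt, padding=0.0)
    print(pt)
    # tensor([[0., 1., 2., 0., 0.],
    #         [0., 1., 2., 3., 4.]])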
@@ -400,9 +400,9 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 value = torch.nested_tensor(values )

 # pad input
-padded_query = query.to_padded_tensor(0.0, (N, L_t, E_q))
-padded_key   = key  .to_padded_tensor(0.0, (N, L_s, E_k))
-padded_value = value.to_padded_tensor(0.0, (N, L_s, E_v))
+padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
+padded_key   = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
+padded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))

 # create attention masks
 attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
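The three-argument form takes an explicit output size, as used for the query/key/value padding above. A small sketch with made-up dimensions (N, L_t, L_s, E_q, E_k here are illustrative values, not the tutorial's):

    import torch

    E_q, E_k = 8, 8
    queries = [torch.randn(3, E_q), torch.randn(5, E_q)]
    keys    = [torch.randn(4, E_k), torch.randn(6, E_k)]

    query = torch.nested.nested_tensor(queries)
    key   = torch.nested.nested_tensor(keys)

    # pad to explicit batch shapes; each target dim must cover the largest entry
    N, L_t, L_s = 2, 5, 6
    padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
    padded_key   = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
    print(padded_query.shape, padded_key.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 6, 8])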
@@ -436,7 +436,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
                                           dropout_p=dropout_p)
 t2 = timeit.default_timer()

-print("nested and padded calculations differ by", (out_nested.to_padded_tensor(0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
+print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
 print("nested tensor multi-head attention takes", t1 - t0, "seconds")
 print("padded tensor multi-head attention takes", t2 - t1, "seconds")
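The updated check pads the nested output to (N, L_t, E_out) before comparing it with the padded-path result. A self-contained sketch of the same comparison-and-timing pattern, with F.relu standing in for the multi-head attention call; this assumes pointwise relu is supported on nested tensors, and the shapes are made up:

    import timeit
    import torch
    import torch.nn.functional as F

    nt = torch.nested.nested_tensor([torch.randn(3, 8), torch.randn(5, 8)])
    padded = torch.nested.to_padded_tensor(nt, 0.0, (2, 5, 8))

    t0 = timeit.default_timer()
    out_nested = F.relu(nt)      # assumes pointwise relu works on nested tensors
    t1 = timeit.default_timer()
    out_padded = F.relu(padded)
    t2 = timeit.default_timer()

    # pad the nested result to the padded result's shape before comparing
    diff = (torch.nested.to_padded_tensor(out_nested, 0.0, (2, 5, 8)) - out_padded).abs().max().item()
    print("nested and padded calculations differ by", diff)  # expected 0.0: relu keeps the zero padding at zero
    print("nested relu takes", t1 - t0, "seconds")
    print("padded relu takes", t2 - t1, "seconds")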
@@ -486,7 +486,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
                                           dropout_p=dropout_p)
 t3 = timeit.default_timer()

-print("nested general and library calculations differ by", (out_nested.to_padded_tensor(0.0) - out_lib.to_padded_tensor(0.0)).abs().max().item())
+print("nested general and library calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0) - torch.nested.to_padded_tensor(out_lib, 0.0)).abs().max().item())
 print("nested library multi-head attention takes", t1 - t0, "seconds")
 print("nested general multi-head attention takes", t2 - t1, "seconds")
 print("padded tensor multi-head attention takes", t3 - t2, "seconds")
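When both results are nested tensors, as with out_nested and out_lib here, each one is padded before subtracting. A tiny sketch of that pattern (the inputs are illustrative):

    import torch

    nt1 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4)])
    nt2 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4)])

    # pad both nested results to regular tensors, then compare element-wise
    diff = (torch.nested.to_padded_tensor(nt1, 0.0) - torch.nested.to_padded_tensor(nt2, 0.0)).abs().max().item()
    print("results differ by", diff)  # 0.0 for identical inputs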