Load target address earlier for tail call interpreter #129976


Open
aconz2 opened this issue Feb 10, 2025 · 7 comments
Labels
interpreter-core (Objects, Python, Grammar, and Parser dirs) performance Performance or resource usage

Comments

@aconz2

aconz2 commented Feb 10, 2025

My working branch is here main...aconz2:cpython:aconz2/early-tail-call-load

I saw the recent merge of the tail call interpreter (#128718), very nice! I have played with this style of interpreter before and one thing that comes up is when to calculate the target address. As it is, the current interpreter does it in DISPATCH() by doing

DEF_TARGET(foo) {
    // ...
    TAIL return INSTRUCTION_TABLE[opcode](ARGS);
}

this results in assembly like:

0000000000289580 <_TAIL_CALL_GET_LEN>:
  289580: 50                           	push	rax
  289581: 89 fb                        	mov	ebx, edi
  289583: 4d 89 7c 24 38               	mov	qword ptr [r12 + 0x38], r15
  289588: 49 83 c7 02                  	add	r15, 0x2
  28958c: 49 8b 7d f8                  	mov	rdi, qword ptr [r13 - 0x8]
  289590: 4d 89 6c 24 40               	mov	qword ptr [r12 + 0x40], r13
  289595: e8 f6 fc ea ff               	call	0x139290 <PyObject_Size>
  28959a: 4d 8b 6c 24 40               	mov	r13, qword ptr [r12 + 0x40]
  28959f: 49 c7 44 24 40 00 00 00 00   	mov	qword ptr [r12 + 0x40], 0x0
  2895a8: 48 85 c0                     	test	rax, rax
  2895ab: 78 2b                        	js	0x2895d8 <_TAIL_CALL_GET_LEN+0x58>
  2895ad: 48 89 c7                     	mov	rdi, rax
  2895b0: e8 eb 54 ec ff               	call	0x14eaa0 <PyLong_FromSsize_t>
  2895b5: 48 85 c0                     	test	rax, rax
  2895b8: 74 1e                        	je	0x2895d8 <_TAIL_CALL_GET_LEN+0x58>
  2895ba: 49 89 45 00                  	mov	qword ptr [r13], rax
  2895be: 49 83 c5 08                  	add	r13, 0x8
  2895c2: 41 0f b7 3f                  	movzx	edi, word ptr [r15]  #<-- Load next_instr
  2895c6: 40 0f b6 c7                  	movzx	eax, dil             #<-- grab opcode
  2895ca: c1 ef 08                     	shr	edi, 0x8
  2895cd: 48 8d 0d 7c 50 1f 00         	lea	rcx, [rip + 0x1f507c]   # 0x47e650 <INSTRUCTION_TABLE>
  2895d4: 5a                           	pop	rdx
  2895d5: ff 24 c1                     	jmp	qword ptr [rcx + 8*rax] #<-- jmp with addr calculation
  2895d8: 89 df                        	mov	edi, ebx
  2895da: 58                           	pop	rax
  2895db: e9 30 dc ff ff               	jmp	0x287210 <_TAIL_CALL_error>

where we jmp to a computed address that depends on the lea and a memory load a few instructions prior.

Another method looks like

DEF_TARGET(foo) {
  // ...
  tail_funcptr next_f = INSTRUCTION_TABLE[next_opcode];
  // ...
  TAIL return next_f(ARGS);
}

where we try to get the compiler to compute the target earlier and then have a jmp reg. We have to pay special attention to places where next_instr is modified and reload the pointer (though hopefully the optimizer will just wait to do the calculation until the latest place).

In this early branch, I was able to get this working enough to see what asm it would generate. For _TAIL_CALL_GET_LEN, the sequence now looks like

00000000002896b0 <_TAIL_CALL_GET_LEN>:
  2896b0: 55                           	push	rbp
  2896b1: 89 fb                        	mov	ebx, edi
  2896b3: 4d 89 7c 24 38               	mov	qword ptr [r12 + 0x38], r15
  2896b8: 41 0f b6 47 02               	movzx	eax, byte ptr [r15 + 0x2]  #<-- Load next instr opcode
  2896bd: 49 83 c7 02                  	add	r15, 0x2
  2896c1: 48 8d 0d 88 5f 1f 00         	lea	rcx, [rip + 0x1f5f88]   # 0x47f650 <INSTRUCTION_TABLE>
  2896c8: 48 8b 2c c1                  	mov	rbp, qword ptr [rcx + 8*rax]  #<-- load next target addr
  2896cc: 49 8b 7d f8                  	mov	rdi, qword ptr [r13 - 0x8]
  2896d0: 4d 89 6c 24 40               	mov	qword ptr [r12 + 0x40], r13
  2896d5: e8 b6 fb ea ff               	call	0x139290 <PyObject_Size>
  2896da: 4d 8b 6c 24 40               	mov	r13, qword ptr [r12 + 0x40]
  2896df: 49 c7 44 24 40 00 00 00 00   	mov	qword ptr [r12 + 0x40], 0x0
  2896e8: 48 85 c0                     	test	rax, rax
  2896eb: 78 20                        	js	0x28970d <_TAIL_CALL_GET_LEN+0x5d>
  2896ed: 48 89 c7                     	mov	rdi, rax
  2896f0: e8 ab 53 ec ff               	call	0x14eaa0 <PyLong_FromSsize_t>
  2896f5: 48 85 c0                     	test	rax, rax
  2896f8: 74 13                        	je	0x28970d <_TAIL_CALL_GET_LEN+0x5d>
  2896fa: 49 89 45 00                  	mov	qword ptr [r13], rax
  2896fe: 49 83 c5 08                  	add	r13, 0x8
  289702: 41 0f b6 7f 01               	movzx	edi, byte ptr [r15 + 0x1]
  289707: 48 89 e8                     	mov	rax, rbp                  #<-- register rename
  28970a: 5d                           	pop	rbp
  28970b: ff e0                        	jmp	rax                       #<-- jmp to target addr
  28970d: 89 df                        	mov	edi, ebx
  28970f: 5d                           	pop	rbp
  289710: e9 fb da ff ff               	jmp	0x287210 <_TAIL_CALL_error>
  289715: 66 66 2e 0f 1f 84 00 00 00 00 00     	nop	word ptr cs:[rax + rax]

Specifically in this case, neither PyObject_Size nor PyLong_FromSsize_t touches rbp, so there isn't any additional register pressure. But I haven't looked extensively, so that may not hold universally.

My theory is that this could be better for the CPU: in this example, once execution returns from PyLong_FromSsize_t, the jump target is already in a register, so the CPU can potentially start fetching the target sooner.

Have not benchmarked anything yet.

Looking at another example _TAIL_CALL_BINARY_OP_SUBSCR_GETITEM, this does a LOAD_IP() towards the end so we have to reload our target address. It does seem like the optimizer is smart enough to avoid double loading, but this just ends up with an almost identical ending:

# main
  28eba5: 41 0f b7 3f                   movzx   edi, word ptr [r15]
  28eba9: 40 0f b6 cf                   movzx   ecx, dil
  28ebad: c1 ef 08                      shr     edi, 0x8
  28ebb0: 48 8d 15 99 fa 1e 00          lea     rdx, [rip + 0x1efa99]   # 0x47e650 <INSTRUCTION_TABLE>
  28ebb7: 49 89 c4                      mov     r12, rax
  28ebba: ff 24 ca                      jmp     qword ptr [rdx + 8*rcx]

# this branch
  28f185: 41 0f b6 0f                   movzx   ecx, byte ptr [r15]
  28f189: 48 8d 15 c0 04 1f 00          lea     rdx, [rip + 0x1f04c0]   # 0x47f650 <INSTRUCTION_TABLE>
  28f190: 41 0f b6 7f 01                movzx   edi, byte ptr [r15 + 0x1]
  28f195: 49 89 c4                      mov     r12, rax
  28f198: ff 24 ca                      jmp     qword ptr [rdx + 8*rcx]

I did this a bit haphazardly, through a combination of modifying macros and making manual changes to anything that assigns to next_instr, plus a few special cases like exit_unwind that didn't fit. I could clean it up with some direction.

One super naive metric is

# should be a tab after jmp
llvm-objdump --x86-asm-syntax=intel -D python | grep 'jmp   r' | wc -l

which is 916 for this modification and 731 originally, so 185 more places where we jmp to a register instead of a computed address.

@picnixz picnixz added the interpreter-core (Objects, Python, Grammar, and Parser dirs) label Feb 10, 2025
@tomasr8 tomasr8 added the performance Performance or resource usage label Feb 10, 2025
@Fidget-Spinner
Member

Thanks for your work on this! Let me know if it passes the PGO test suite (./python -m test --pgo), after that we can bench it!

@aconz2
Author

aconz2 commented Feb 12, 2025

Okay, it is passing ./python -m test and ./python -m test --pgo. It took a bit of trial and error to figure out which ops can call into things that modify the bytecode, where we have to reload next_op_f later than at the entry point. These reloads are unceremoniously scattered through bytecodes.c right now: some as early as I think is possible, and some I just stuck at the end because I couldn't tell what was going on.

An example of one of these is STORE_ATTR: in, for example, ./python -m test test_descr, the next op is initially LOAD_GLOBAL, but the call to PyObject_SetAttr reaches the specializer, which turns the op into LOAD_GLOBAL_BUILTIN. Or at least I think that is what is going on. Some others that I think can do this are PyObject_SetItem, PyObject_GetItem, PyStackRef_CLOSE, and PyStackRef_XCLOSE.

I'm not sure if passing all the tests is conclusive that this is correct; is there a more principled way of knowing who/what could change the value of next_instr between the start of the op and the end?

After fixing bugs the jmp to reg count went up to 933 from 917.

One thing I noticed is that sometimes the target address is loaded early and then stashed on the interpreter stack, like:

00000000001a8920 <_TAIL_CALL_BINARY_OP>:
  1a8920: 55                            push    rbp
  1a8921: 48 83 ec 10                   sub     rsp, 0x10
  1a8925: 41 89 f9                      mov     r9d, edi
  1a8928: 4d 89 7c 24 38                mov     qword ptr [r12 + 0x38], r15
  1a892d: 48 8d 15 1c cd 2d 00          lea     rdx, [rip + 0x2dcd1c]   # 0x485650 <INSTRUCTION_TABLE>
  1a8934: 49 8b 6d f0                   mov     rbp, qword ptr [r13 - 0x10]
  1a8938: 49 8b 5d f8                   mov     rbx, qword ptr [r13 - 0x8]
  1a893c: 41 0f b7 47 02                movzx   eax, word ptr [r15 + 0x2]
  1a8941: 66 83 f8 0e                   cmp     ax, 0xe
  1a8945: 0f 86 97 00 00 00             jbe     0x1a89e2 <_TAIL_CALL_BINARY_OP+0xc2>
  1a894b: 4c 89 34 24                   mov     qword ptr [rsp], r14
  1a894f: 41 0f b6 4f 0c                movzx   ecx, byte ptr [r15 + 0xc]
  1a8954: 48 8b 0c ca                   mov     rcx, qword ptr [rdx + 8*rcx]  #<-- compute target address
  1a8958: 48 89 4c 24 08                mov     qword ptr [rsp + 0x8], rcx    #<-- store to stack
  ...
  ...
  1a89d6: 48 8b 44 24 08                mov     rax, qword ptr [rsp + 0x8]    #<-- load from stack
  1a89db: 48 83 c4 10                   add     rsp, 0x10
  1a89df: 5d                            pop     rbp
  1a89e0: ff e0                         jmp     rax                           #<-- jmp
  1a89e2: 4d 89 6c 24 40                mov     qword ptr [r12 + 0x40], r13
  1a89e7: 4d 8d 44 24 50                lea     r8, [r12 + 0x50]
  1a89ec: 48 89 ef                      mov     rdi, rbp
  1a89ef: 48 89 de                      mov     rsi, rbx
  1a89f2: 48 89 d5                      mov     rbp, rdx                      #<-- rdx still has INSTRUCTION_TABLE base, "save" to rbp
  1a89f5: 4c 89 fa                      mov     rdx, r15
  1a89f8: 44 89 c9                      mov     ecx, r9d
  1a89fb: 44 89 cb                      mov     ebx, r9d
  1a89fe: e8 5d ea 11 00                call    0x2c7460 <_Py_Specialize_BinaryOp>
  1a8a03: 4d 8b 6c 24 40                mov     r13, qword ptr [r12 + 0x40]
  1a8a08: 49 c7 44 24 40 00 00 00 00    mov     qword ptr [r12 + 0x40], 0x0
  1a8a11: 41 0f b6 07                   movzx   eax, byte ptr [r15]
  1a8a15: 89 df                         mov     edi, ebx
  1a8a17: 48 89 e9                      mov     rcx, rbp                    #<-- rename to get INSTRUCTION_TABLE
  1a8a1a: 48 83 c4 10                   add     rsp, 0x10
  1a8a1e: 5d                            pop     rbp
  1a8a1f: ff 24 c1                      jmp     qword ptr [rcx + 8*rax]     #<-- jump to computed address
  ...

and then there is another variant, seen above, where the INSTRUCTION_TABLE base address is loaded early and then used later with a computed index. Not sure if these are still a win.

And this is all without the jit in mind right now.

@Fidget-Spinner
Member

I've requested a benchmarking run for this. Will report back soon.

@Fidget-Spinner
Member

Ok we have received the benchmark results (thanks to Mike):

https://github.com/faster-cpython/benchmarking-public/tree/main/results/bm-20250212-3.14.0a4+-e9c43a0-CLANG

Roughly 1% faster on AMD64 Xeon W-2255

No speedup on macOS M1.

This makes sense: the Xeon W-2255 presumably has weaker out-of-order execution capabilities than the M1. So there's a benefit in doing this for older processors, or those with weaker OOO execution.

@Fidget-Spinner
Member

Note that a 1% win is pretty significant. So great job! If we were to upstream this, though, it would need to be less likely to trip up CPython contributors, so let me think about how to approach this in a less error-prone way using the code generator.

@aconz2
Author

aconz2 commented Feb 27, 2025

One thought I've had to eliminate the special casing is something like:

DEF_TARGET() {
    u8 next_opcode = load_next_opcode();
    tail_funcptr next_f = INSTRUCTION_TABLE[next_opcode];
    // ...
    u8 next_opcode2 = load_next_opcode();
    if (next_opcode == next_opcode2) TAIL return next_f(ARGS);
    next_f = INSTRUCTION_TABLE[next_opcode2];  // opcode changed underneath us; reload target
    TAIL return next_f(ARGS);
}

But it's not clear without testing whether this would still be a win. I can try this out and post a branch when I get some time.

@aconz2
Author

aconz2 commented Mar 1, 2025

I tried it out and it is passing ./python -m test --pgo. Branch is main...aconz2:cpython:aconz2/early-tail-call-load-2

The diff is much smaller, only exception_unwind and start_frame needed changes beyond the generator.

Here is one op for comparison

0000000000296870 <_TAIL_CALL_GET_LEN>:
  296870: 55                            push    rbp
  296871: 4d 89 7c 24 38                mov     qword ptr [r12 + 0x38], r15
  296876: 41 0f b6 5f 02                movzx   ebx, byte ptr [r15 + 0x2]
  29687b: 49 83 c7 02                   add     r15, 0x2
  29687f: 48 8d 05 ca fd 1e 00          lea     rax, [rip + 0x1efdca]   # 0x486650 <INSTRUCTION_TABLE>
  296886: 48 8b 2c d8                   mov     rbp, qword ptr [rax + 8*rbx]
  29688a: 49 8b 7d f8                   mov     rdi, qword ptr [r13 - 0x8]
  29688e: 4d 89 6c 24 40                mov     qword ptr [r12 + 0x40], r13
  296893: e8 98 27 ea ff                call    0x139030 <PyObject_Size>
  296898: 4d 8b 6c 24 40                mov     r13, qword ptr [r12 + 0x40]
  29689d: 49 c7 44 24 40 00 00 00 00    mov     qword ptr [r12 + 0x40], 0x0
  2968a6: 48 85 c0                      test    rax, rax
  2968a9: 78 28                         js      0x2968d3 <_TAIL_CALL_GET_LEN+0x63>
  2968ab: 48 89 c7                      mov     rdi, rax
  2968ae: e8 cd 7e eb ff                call    0x14e780 <PyLong_FromSsize_t>
  2968b3: 48 85 c0                      test    rax, rax
  2968b6: 74 1b                         je      0x2968d3 <_TAIL_CALL_GET_LEN+0x63>
  2968b8: 49 89 45 00                   mov     qword ptr [r13], rax
  2968bc: 49 83 c5 08                   add     r13, 0x8
  2968c0: 41 0f b7 07                   movzx   eax, word ptr [r15]
  2968c4: 89 c7                         mov     edi, eax
  2968c6: c1 ef 08                      shr     edi, 0x8
  2968c9: 38 c3                         cmp     bl, al
  2968cb: 75 0c                         jne     0x2968d9 <_TAIL_CALL_GET_LEN+0x69>
  2968cd: 48 89 e8                      mov     rax, rbp
  2968d0: 5d                            pop     rbp
  2968d1: ff e0                         jmp     rax
  2968d3: 5d                            pop     rbp
  2968d4: e9 c7 ce f0 ff                jmp     0x1a37a0 <_TAIL_CALL_error>
  2968d9: 0f b6 c0                      movzx   eax, al
  2968dc: 48 8d 0d 6d fd 1e 00          lea     rcx, [rip + 0x1efd6d]   # 0x486650 <INSTRUCTION_TABLE>
  2968e3: 5d                            pop     rbp
  2968e4: ff 24 c1                      jmp     qword ptr [rcx + 8*rax]
  2968e7: 66 0f 1f 84 00 00 00 00 00    nop     word ptr [rax + rax]

and bigger ops are stashing the original next opcode on the stack for later comparison.
