Skip to content

Commit 322e5a2

Browse files
authored
Update base64 string handling and adjust LUT (#51)
* add internal string to bytes conversion * use a slice for lookup tables * use pointer for lut * add test to verify issue #50
1 parent 86603cd commit 322e5a2

10 files changed

+77
-52
lines changed

base64/base64_amd64.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@ package base64
22

33
import (
44
"encoding/base64"
5-
"unsafe"
65

76
"github.com/segmentio/asm/cpu"
7+
"github.com/segmentio/asm/internal/unsafebytes"
88
)
99

1010
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
1111
// 64-character alphabet.
1212
type Encoding struct {
13-
enc func(dst []byte, src []byte, lut [16]int8) (int, int)
14-
enclut [16]int8
13+
enc func(dst []byte, src []byte, lut *int8) (int, int)
14+
enclut [32]int8
1515

16-
dec func(dst []byte, src []byte, lut [32]int8) (int, int)
17-
declut [32]int8
16+
dec func(dst []byte, src []byte, lut *int8) (int, int)
17+
declut [48]int8
1818

1919
base *base64.Encoding
2020
}
@@ -42,7 +42,7 @@ func (e *Encoding) enableEncodeAVX2(encoder string) {
4242
// [52..61] [48..57] -4 [2..11] 0123456789
4343
// [62] [43] -19 12 +
4444
// [63] [47] -16 13 /
45-
tab := [16]int8{int8(encoder[0]), int8(encoder[letterRange]) - letterRange}
45+
tab := [32]int8{int8(encoder[0]), int8(encoder[letterRange]) - letterRange}
4646
for i, ch := range encoder[2*letterRange:] {
4747
tab[2+i] = int8(ch) - 2*letterRange - int8(i)
4848
}
@@ -67,7 +67,7 @@ func (e *Encoding) enableDecodeAVX2(encoder string) {
6767
// [48..57] [52..61] +4 3 0123456789
6868
// [65..90] [0..25] -65 4,5 ABCDEFGHIJKLMNOPQRSTUVWXYZ
6969
// [97..122] [26..51] -71 6,7 abcdefghijklmnopqrstuvwxyz
70-
tab := [32]int8{
70+
tab := [48]int8{
7171
0, 63 - c63, 62 - c62, 4, -65, -65, -71, -71,
7272
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
7373
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
@@ -104,7 +104,7 @@ func (enc Encoding) Strict() *Encoding {
104104
// This will write EncodedLen(len(src)) bytes to dst.
105105
func (enc *Encoding) Encode(dst, src []byte) {
106106
if len(src) >= minEncodeLen && enc.enc != nil {
107-
d, s := enc.enc(dst, src, enc.enclut)
107+
d, s := enc.enc(dst, src, &enc.enclut[0])
108108
dst = dst[d:]
109109
src = src[s:]
110110
}
@@ -131,7 +131,7 @@ func (enc *Encoding) EncodedLen(n int) int {
131131
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
132132
var d, s int
133133
if len(src) >= minDecodeLen && enc.dec != nil {
134-
d, s = enc.dec(dst, src, enc.declut)
134+
d, s = enc.dec(dst, src, &enc.declut[0])
135135
dst = dst[d:]
136136
src = src[s:]
137137
}
@@ -143,7 +143,7 @@ func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
143143
// DecodeString decodes the base64 encoded string s, returns the decoded
144144
// value as bytes.
145145
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
146-
src := *(*[]byte)(unsafe.Pointer(&s))
146+
src := unsafebytes.BytesOf(s)
147147
dst := make([]byte, enc.base.DecodedLen(len(s)))
148148
n, err := enc.Decode(dst, src)
149149
return dst[:n], err

base64/base64_amd64_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestEncodeAVX2(t *testing.T) {
4646
}
4747
defer dst.Release()
4848

49-
_, ns := enc.candidate.enc(dst.ProtectTail(), buf, enc.candidate.enclut)
49+
_, ns := enc.candidate.enc(dst.ProtectTail(), buf, &enc.candidate.enclut[0])
5050

5151
if len(buf)-ns >= 32 {
5252
t.Errorf("encode remain should be less than 32, but is %d", len(buf)-ns)
@@ -86,7 +86,7 @@ func TestDecodeAVX2(t *testing.T) {
8686

8787
enc.candidate.Encode(src, buf)
8888

89-
_, ns := enc.candidate.dec(dst.ProtectTail(), src, enc.candidate.declut)
89+
_, ns := enc.candidate.dec(dst.ProtectTail(), src, &enc.candidate.declut[0])
9090

9191
if len(buf)-ns >= 45 {
9292
t.Errorf("decode remain should be less than 45, but is %d", len(buf)-ns)

base64/base64_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ func TestEncoding(t *testing.T) {
7979
if !bytes.Equal(decExpect, decActual) {
8080
t.Fatalf("failed decode:\n\texpect = %v\n\tactual = %v", decExpect, decActual)
8181
}
82+
83+
encString := enc.control.EncodeToString(src)
84+
decExpect, errControl = enc.control.DecodeString(encString)
85+
decActual, errCandidate = enc.candidate.DecodeString(encString)
86+
87+
if errControl != nil {
88+
t.Fatalf("control decode error: %v", errControl)
89+
}
90+
91+
if errCandidate != nil {
92+
t.Fatalf("candidate decode error: %v", errCandidate)
93+
}
94+
95+
if !bytes.Equal(decExpect, decActual) {
96+
t.Fatalf("failed decode:\n\texpect = %v\n\tactual = %v", decExpect, decActual)
97+
}
8298
}
8399
})
84100
}

base64/decode_amd64.go

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

base64/decode_amd64.s

+20-18
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,21 @@ DATA b64_dec_shuf<>+16(SB)/8, $0x0c0d0e08090a0405
3232
DATA b64_dec_shuf<>+24(SB)/8, $0x0000000000000000
3333
GLOBL b64_dec_shuf<>(SB), RODATA|NOPTR, $32
3434

35-
// func decodeAVX2(dst []byte, src []byte, lut [32]int8) (int, int)
35+
// func decodeAVX2(dst []byte, src []byte, lut *int8) (int, int)
3636
// Requires: AVX, AVX2, SSE4.1
37-
TEXT ·decodeAVX2(SB), NOSPLIT, $0-96
37+
TEXT ·decodeAVX2(SB), NOSPLIT, $0-72
3838
MOVQ dst_base+0(FP), AX
3939
MOVQ src_base+24(FP), DX
40-
MOVQ src_len+32(FP), SI
40+
MOVQ lut+48(FP), SI
41+
MOVQ src_len+32(FP), DI
4142
MOVB $0x2f, CL
4243
PINSRB $0x00, CX, X8
4344
VPBROADCASTB X8, Y8
4445
XORQ CX, CX
4546
XORQ BX, BX
4647
VPXOR Y7, Y7, Y7
47-
VPERMQ $0x44, lut_0+48(FP), Y6
48-
VPERMQ $0x44, lut_0+64(FP), Y4
48+
VPERMQ $0x44, (SI), Y6
49+
VPERMQ $0x44, 16(SI), Y4
4950
VMOVDQA b64_dec_lut_hi<>+0(SB), Y5
5051

5152
loop:
@@ -71,20 +72,20 @@ loop:
7172
VMOVDQU Y1, (AX)(CX*1)
7273
ADDQ $0x18, CX
7374
ADDQ $0x20, BX
74-
SUBQ $0x20, SI
75-
CMPQ SI, $0x2d
75+
SUBQ $0x20, DI
76+
CMPQ DI, $0x2d
7677
JB done
7778
JMP loop
7879

7980
done:
80-
MOVQ CX, ret+80(FP)
81-
MOVQ BX, ret1+88(FP)
81+
MOVQ CX, ret+56(FP)
82+
MOVQ BX, ret1+64(FP)
8283
VZEROUPPER
8384
RET
8485

85-
// func decodeAVX2URI(dst []byte, src []byte, lut [32]int8) (int, int)
86+
// func decodeAVX2URI(dst []byte, src []byte, lut *int8) (int, int)
8687
// Requires: AVX, AVX2, SSE4.1
87-
TEXT ·decodeAVX2URI(SB), NOSPLIT, $0-96
88+
TEXT ·decodeAVX2URI(SB), NOSPLIT, $0-72
8889
MOVB $0x2f, AL
8990
PINSRB $0x00, AX, X0
9091
VPBROADCASTB X0, Y0
@@ -93,15 +94,16 @@ TEXT ·decodeAVX2URI(SB), NOSPLIT, $0-96
9394
VPBROADCASTB X1, Y1
9495
MOVQ dst_base+0(FP), AX
9596
MOVQ src_base+24(FP), DX
96-
MOVQ src_len+32(FP), SI
97+
MOVQ lut+48(FP), SI
98+
MOVQ src_len+32(FP), DI
9799
MOVB $0x2f, CL
98100
PINSRB $0x00, CX, X10
99101
VPBROADCASTB X10, Y10
100102
XORQ CX, CX
101103
XORQ BX, BX
102104
VPXOR Y9, Y9, Y9
103-
VPERMQ $0x44, lut_0+48(FP), Y8
104-
VPERMQ $0x44, lut_0+64(FP), Y6
105+
VPERMQ $0x44, (SI), Y8
106+
VPERMQ $0x44, 16(SI), Y6
105107
VMOVDQA b64_dec_lut_hi<>+0(SB), Y7
106108

107109
loop:
@@ -129,13 +131,13 @@ loop:
129131
VMOVDQU Y3, (AX)(CX*1)
130132
ADDQ $0x18, CX
131133
ADDQ $0x20, BX
132-
SUBQ $0x20, SI
133-
CMPQ SI, $0x2d
134+
SUBQ $0x20, DI
135+
CMPQ DI, $0x2d
134136
JB done
135137
JMP loop
136138

137139
done:
138-
MOVQ CX, ret+80(FP)
139-
MOVQ BX, ret1+88(FP)
140+
MOVQ CX, ret+56(FP)
141+
MOVQ BX, ret1+64(FP)
140142
VZEROUPPER
141143
RET

base64/encode_amd64.go

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

base64/encode_amd64.s

+9-8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
#include "textflag.h"
66

7-
// func encodeAVX2(dst []byte, src []byte, lut [16]int8) (int, int)
7+
// func encodeAVX2(dst []byte, src []byte, lut *int8) (int, int)
88
// Requires: AVX, AVX2, SSE4.1
9-
TEXT ·encodeAVX2(SB), NOSPLIT, $0-80
9+
TEXT ·encodeAVX2(SB), NOSPLIT, $0-72
1010
MOVQ dst_base+0(FP), AX
1111
MOVQ src_base+24(FP), DX
12-
MOVQ src_len+32(FP), SI
12+
MOVQ lut+48(FP), SI
13+
MOVQ src_len+32(FP), DI
1314
MOVB $0x33, CL
1415
PINSRB $0x00, CX, X4
1516
VPBROADCASTB X4, Y4
@@ -20,7 +21,7 @@ TEXT ·encodeAVX2(SB), NOSPLIT, $0-80
2021
XORQ BX, BX
2122

2223
// Load the 16-byte LUT into both lanes of the register
23-
VPERMQ $0x44, lut_0+48(FP), Y3
24+
VPERMQ $0x44, (SI), Y3
2425

2526
// Load the first block using a mask to avoid potential fault
2627
VMOVDQU b64_enc_load<>+0(SB), Y0
@@ -43,15 +44,15 @@ loop:
4344
VMOVDQU Y0, (AX)(CX*1)
4445
ADDQ $0x20, CX
4546
ADDQ $0x18, BX
46-
SUBQ $0x18, SI
47-
CMPQ SI, $0x20
47+
SUBQ $0x18, DI
48+
CMPQ DI, $0x20
4849
JB done
4950
VMOVDQU -4(DX)(BX*1), Y0
5051
JMP loop
5152

5253
done:
53-
MOVQ CX, ret+64(FP)
54-
MOVQ BX, ret1+72(FP)
54+
MOVQ CX, ret+56(FP)
55+
MOVQ BX, ret1+64(FP)
5556
VZEROUPPER
5657
RET
5758

build/base64/decode_asm.go

+5-8
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ func init() {
4747
}
4848

4949
func main() {
50-
TEXT("decodeAVX2", NOSPLIT, "func(dst, src []byte, lut [32]int8) (int, int)")
50+
TEXT("decodeAVX2", NOSPLIT, "func(dst, src []byte, lut *int8) (int, int)")
5151
createDecode(Param("dst"), Param("src"), Param("lut"), func(m Mem, r VecVirtual) {
5252
VMOVDQU(m, r)
5353
})
5454

55-
TEXT("decodeAVX2URI", NOSPLIT, "func(dst, src []byte, lut [32]int8) (int, int)")
55+
TEXT("decodeAVX2URI", NOSPLIT, "func(dst, src []byte, lut *int8) (int, int)")
5656
slash := VecBroadcast(U8('/'), YMM())
5757
underscore := VecBroadcast(U8('_'), YMM())
5858
createDecode(Param("dst"), Param("src"), Param("lut"), func(m Mem, r VecVirtual) {
@@ -68,11 +68,8 @@ func main() {
6868
func createDecode(pdst, psrc, plut Component, load func(m Mem, r VecVirtual)) {
6969
dst := Mem{Base: Load(pdst.Base(), GP64()), Index: GP64(), Scale: 1}
7070
src := Mem{Base: Load(psrc.Base(), GP64()), Index: GP64(), Scale: 1}
71+
lut := Mem{Base: Load(plut, GP64())}
7172
rem := Load(psrc.Len(), GP64())
72-
lut, err := plut.Index(0).Resolve()
73-
if err != nil {
74-
panic(err)
75-
}
7673

7774
rsrc := YMM()
7875
rdst := YMM()
@@ -93,8 +90,8 @@ func createDecode(pdst, psrc, plut Component, load func(m Mem, r VecVirtual)) {
9390
XORQ(src.Index, src.Index)
9491
VPXOR(zero, zero, zero)
9592

96-
VPERMQ(Imm(1<<6|1<<2), lut.Addr, lutr)
97-
VPERMQ(Imm(1<<6|1<<2), lut.Addr.Offset(16), lutl)
93+
VPERMQ(Imm(1<<6|1<<2), lut, lutr)
94+
VPERMQ(Imm(1<<6|1<<2), lut.Offset(16), lutl)
9895
VMOVDQA(lutHi, luth)
9996

10097
Label("loop")

build/base64/encode_asm.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ func init() {
2121
}
2222

2323
func main() {
24-
TEXT("encodeAVX2", NOSPLIT, "func(dst, src []byte, lut [16]int8) (int, int)")
24+
TEXT("encodeAVX2", NOSPLIT, "func(dst, src []byte, lut *int8) (int, int)")
2525

2626
dst := Mem{Base: Load(Param("dst").Base(), GP64()), Index: GP64(), Scale: 1}
2727
src := Mem{Base: Load(Param("src").Base(), GP64()), Index: GP64(), Scale: 1}
28+
lut := Mem{Base: Load(Param("lut"), GP64())}
2829
rem := Load(Param("src").Len(), GP64())
29-
lut, _ := Param("lut").Index(0).Resolve()
3030

3131
rsrc := YMM()
3232
rdst := YMM()
@@ -47,7 +47,7 @@ func main() {
4747
XORQ(src.Index, src.Index)
4848

4949
Comment("Load the 16-byte LUT into both lanes of the register")
50-
VPERMQ(Imm(1<<6|1<<2), lut.Addr, xtab)
50+
VPERMQ(Imm(1<<6|1<<2), lut, xtab)
5151

5252
Comment("Load the first block using a mask to avoid potential fault")
5353
VMOVDQU(ConstLoadMask32("b64_enc_load",

internal/unsafebytes/unsafebytes.go

+9
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,12 @@ func Pointer(b []byte) *byte {
99
func String(b []byte) string {
1010
return *(*string)(unsafe.Pointer(&b))
1111
}
12+
13+
func BytesOf(s string) []byte {
14+
return *(*[]byte)(unsafe.Pointer(&sliceHeader{str: s, cap: len(s)}))
15+
}
16+
17+
type sliceHeader struct {
18+
str string
19+
cap int
20+
}

0 commit comments

Comments
 (0)