Skip to content

Commit 1582043

Browse files
author
Jiawen
committed
feat:encryption adds support for SM2, SM3, and SM4 #131
1 parent e6fefa5 commit 1582043

File tree

7 files changed

+1542
-2
lines changed

7 files changed

+1542
-2
lines changed

cryptor/gm_example_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package cryptor_test
2+
3+
import (
4+
"encoding/hex"
5+
"fmt"
6+
7+
"github.com/duke-git/lancet/v2/cryptor"
8+
)
9+
10+
func ExampleSm3() {
11+
data := []byte("hello world")
12+
hash := cryptor.Sm3(data)
13+
14+
fmt.Println(hex.EncodeToString(hash))
15+
16+
// Output:
17+
// 44f0061e69fa6fdfc290c494654a05dc0c053da7e5c52b84ef93a9d67d3fff88
18+
}
19+
20+
func ExampleSm4EcbEncrypt() {
21+
key := []byte("1234567890abcdef") // 16 bytes key
22+
plaintext := []byte("hello world")
23+
24+
encrypted := cryptor.Sm4EcbEncrypt(plaintext, key)
25+
decrypted := cryptor.Sm4EcbDecrypt(encrypted, key)
26+
27+
fmt.Println(string(decrypted))
28+
29+
// Output:
30+
// hello world
31+
}
32+
33+
func ExampleSm4CbcEncrypt() {
34+
key := []byte("1234567890abcdef") // 16 bytes key
35+
plaintext := []byte("hello world")
36+
37+
encrypted := cryptor.Sm4CbcEncrypt(plaintext, key)
38+
decrypted := cryptor.Sm4CbcDecrypt(encrypted, key)
39+
40+
fmt.Println(string(decrypted))
41+
42+
// Output:
43+
// hello world
44+
}
45+
46+
func ExampleGenerateSm2Key() {
47+
// Generate SM2 key pair
48+
privateKey, err := cryptor.GenerateSm2Key()
49+
if err != nil {
50+
return
51+
}
52+
53+
plaintext := []byte("hello world")
54+
55+
// Encrypt with public key
56+
ciphertext, err := cryptor.Sm2Encrypt(&privateKey.PublicKey, plaintext)
57+
if err != nil {
58+
return
59+
}
60+
61+
// Decrypt with private key
62+
decrypted, err := cryptor.Sm2Decrypt(privateKey, ciphertext)
63+
if err != nil {
64+
return
65+
}
66+
67+
fmt.Println(string(decrypted))
68+
69+
// Output:
70+
// hello world
71+
}

cryptor/gm_sm2.go

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
// Copyright 2021 dudaodong@gmail.com. All rights reserved.
2+
// Use of this source code is governed by MIT license
3+
4+
//nolint:staticcheck // crypto/elliptic methods are deprecated in Go 1.20+ but still functional
5+
package cryptor
6+
7+
import (
8+
"crypto/elliptic"
9+
"crypto/rand"
10+
"encoding/binary"
11+
"errors"
12+
"io"
13+
"math/big"
14+
)
15+
16+
// SM2 implements the Chinese SM2 elliptic curve public key algorithm.
17+
// SM2 is based on elliptic curve cryptography and provides encryption, decryption, signing and verification.
18+
//
19+
// Note: This implementation uses crypto/elliptic package methods (GenerateKey, ScalarBaseMult, ScalarMult, IsOnCurve)
20+
// which are marked as deprecated in Go 1.20+. These methods still work correctly and are widely used.
21+
// The //nolint:staticcheck directive suppresses deprecation warnings.
22+
// A future version may replace these with a custom elliptic curve implementation.
23+
24+
var (
25+
sm2P256 *sm2Curve
26+
sm2P256Params = &elliptic.CurveParams{Name: "sm2p256v1"}
27+
)
28+
29+
func init() {
30+
// SM2 curve parameters
31+
sm2P256Params.P, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16)
32+
sm2P256Params.N, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16)
33+
sm2P256Params.B, _ = new(big.Int).SetString("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16)
34+
sm2P256Params.Gx, _ = new(big.Int).SetString("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16)
35+
sm2P256Params.Gy, _ = new(big.Int).SetString("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16)
36+
sm2P256Params.BitSize = 256
37+
38+
sm2P256 = &sm2Curve{sm2P256Params}
39+
}
40+
41+
type sm2Curve struct {
42+
*elliptic.CurveParams
43+
}
44+
45+
// Sm2PrivateKey represents an SM2 private key.
46+
type Sm2PrivateKey struct {
47+
D *big.Int
48+
PublicKey Sm2PublicKey
49+
}
50+
51+
// Sm2PublicKey represents an SM2 public key.
52+
type Sm2PublicKey struct {
53+
X, Y *big.Int
54+
}
55+
56+
// GenerateSm2Key generates a new SM2 private/public key pair.
57+
// Play: https://go.dev/play/p/bKYMqRLvIx3
58+
func GenerateSm2Key() (*Sm2PrivateKey, error) {
59+
priv, x, y, err := elliptic.GenerateKey(sm2P256, rand.Reader)
60+
if err != nil {
61+
return nil, err
62+
}
63+
64+
privateKey := &Sm2PrivateKey{
65+
D: new(big.Int).SetBytes(priv),
66+
PublicKey: Sm2PublicKey{
67+
X: x,
68+
Y: y,
69+
},
70+
}
71+
72+
return privateKey, nil
73+
}
74+
75+
// Sm2Encrypt encrypts plaintext using SM2 public key.
76+
// Returns ciphertext in the format: C1 || C3 || C2
77+
// C1 = kG (65 bytes in uncompressed format)
78+
// C3 = Hash(x2 || M || y2) (32 bytes for SM3)
79+
// C2 = M xor t (same length as plaintext)
80+
// Play: https://go.dev/play/p/bKYMqRLvIx3
81+
func Sm2Encrypt(pub *Sm2PublicKey, plaintext []byte) ([]byte, error) {
82+
if pub == nil || pub.X == nil || pub.Y == nil {
83+
return nil, errors.New("sm2: invalid public key")
84+
}
85+
86+
for {
87+
// Generate random k
88+
k, err := randFieldElement(sm2P256, rand.Reader)
89+
if err != nil {
90+
return nil, err
91+
}
92+
93+
// C1 = kG
94+
c1x, c1y := sm2P256.ScalarBaseMult(k.Bytes())
95+
96+
// kP = (x2, y2)
97+
x2, y2 := sm2P256.ScalarMult(pub.X, pub.Y, k.Bytes())
98+
99+
// Derive key using KDF
100+
kdfLen := len(plaintext)
101+
t := sm2KDF(append(toBytes(sm2P256, x2), toBytes(sm2P256, y2)...), kdfLen)
102+
103+
// Check if t is all zeros
104+
allZero := true
105+
for _, b := range t {
106+
if b != 0 {
107+
allZero = false
108+
break
109+
}
110+
}
111+
if allZero {
112+
continue
113+
}
114+
115+
// C2 = M xor t
116+
c2 := make([]byte, len(plaintext))
117+
for i := 0; i < len(plaintext); i++ {
118+
c2[i] = plaintext[i] ^ t[i]
119+
}
120+
121+
// C3 = Hash(x2 || M || y2)
122+
c3Input := append(toBytes(sm2P256, x2), plaintext...)
123+
c3Input = append(c3Input, toBytes(sm2P256, y2)...)
124+
c3 := Sm3(c3Input)
125+
126+
// Return C1 || C3 || C2
127+
c1 := sm2MarshalUncompressed(sm2P256, c1x, c1y)
128+
result := append(c1, c3...)
129+
result = append(result, c2...)
130+
131+
return result, nil
132+
}
133+
}
134+
135+
// Sm2Decrypt decrypts ciphertext using SM2 private key.
136+
// Expects ciphertext in the format: C1 || C3 || C2
137+
// Play: https://go.dev/play/p/bKYMqRLvIx3
138+
func Sm2Decrypt(priv *Sm2PrivateKey, ciphertext []byte) ([]byte, error) {
139+
if priv == nil || priv.D == nil {
140+
return nil, errors.New("sm2: invalid private key")
141+
}
142+
143+
// Parse C1 (65 bytes), C3 (32 bytes), C2 (remaining)
144+
if len(ciphertext) < 97 {
145+
return nil, errors.New("sm2: ciphertext too short")
146+
}
147+
148+
c1 := ciphertext[:65]
149+
c3 := ciphertext[65:97]
150+
c2 := ciphertext[97:]
151+
152+
// Parse C1
153+
c1x, c1y := sm2UnmarshalUncompressed(sm2P256, c1)
154+
if c1x == nil {
155+
return nil, errors.New("sm2: invalid C1 point")
156+
}
157+
158+
// Verify C1 is on curve
159+
if !sm2P256.IsOnCurve(c1x, c1y) {
160+
return nil, errors.New("sm2: C1 not on curve")
161+
}
162+
163+
// dC1 = (x2, y2)
164+
x2, y2 := sm2P256.ScalarMult(c1x, c1y, priv.D.Bytes())
165+
166+
// Derive key using KDF
167+
kdfLen := len(c2)
168+
t := sm2KDF(append(toBytes(sm2P256, x2), toBytes(sm2P256, y2)...), kdfLen)
169+
170+
// M = C2 xor t
171+
plaintext := make([]byte, len(c2))
172+
for i := 0; i < len(c2); i++ {
173+
plaintext[i] = c2[i] ^ t[i]
174+
}
175+
176+
// Verify C3 = Hash(x2 || M || y2)
177+
u := append(toBytes(sm2P256, x2), plaintext...)
178+
u = append(u, toBytes(sm2P256, y2)...)
179+
hash := Sm3(u)
180+
181+
for i := 0; i < len(c3); i++ {
182+
if c3[i] != hash[i] {
183+
return nil, errors.New("sm2: hash verification failed")
184+
}
185+
}
186+
187+
return plaintext, nil
188+
}
189+
190+
// SM2 KDF (Key Derivation Function)
191+
func sm2KDF(z []byte, klen int) []byte {
192+
limit := (klen + 31) / 32
193+
result := make([]byte, 0, limit*32)
194+
195+
for i := 1; i <= limit; i++ {
196+
counter := make([]byte, 4)
197+
binary.BigEndian.PutUint32(counter, uint32(i))
198+
hash := Sm3(append(z, counter...))
199+
result = append(result, hash...)
200+
}
201+
202+
return result[:klen]
203+
}
204+
205+
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
206+
byteLen := (curve.Params().BitSize + 7) / 8
207+
buf := make([]byte, byteLen)
208+
b := value.Bytes()
209+
copy(buf[byteLen-len(b):], b)
210+
return buf
211+
}
212+
213+
func sm2MarshalUncompressed(curve *sm2Curve, x, y *big.Int) []byte {
214+
byteLen := (curve.BitSize + 7) / 8
215+
ret := make([]byte, 1+2*byteLen)
216+
ret[0] = 4 // uncompressed point
217+
218+
xBytes := x.Bytes()
219+
copy(ret[1+byteLen-len(xBytes):], xBytes)
220+
yBytes := y.Bytes()
221+
copy(ret[1+2*byteLen-len(yBytes):], yBytes)
222+
223+
return ret
224+
}
225+
226+
func sm2UnmarshalUncompressed(curve *sm2Curve, data []byte) (*big.Int, *big.Int) {
227+
byteLen := (curve.BitSize + 7) / 8
228+
if len(data) != 1+2*byteLen {
229+
return nil, nil
230+
}
231+
if data[0] != 4 {
232+
return nil, nil
233+
}
234+
235+
x := new(big.Int).SetBytes(data[1 : 1+byteLen])
236+
y := new(big.Int).SetBytes(data[1+byteLen:])
237+
238+
return x, y
239+
}
240+
241+
func randFieldElement(c elliptic.Curve, rand io.Reader) (*big.Int, error) {
242+
params := c.Params()
243+
b := make([]byte, params.BitSize/8+8)
244+
_, err := io.ReadFull(rand, b)
245+
if err != nil {
246+
return nil, err
247+
}
248+
249+
k := new(big.Int).SetBytes(b)
250+
n := new(big.Int).Sub(params.N, big.NewInt(1))
251+
k.Mod(k, n)
252+
k.Add(k, big.NewInt(1))
253+
254+
return k, nil
255+
}

0 commit comments

Comments
 (0)