sm2.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. /*
  2. Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package sm2
  14. // reference to ecdsa
  15. import (
  16. "bytes"
  17. "crypto"
  18. "crypto/elliptic"
  19. "crypto/rand"
  20. "encoding/asn1"
  21. "encoding/binary"
  22. "errors"
  23. "io"
  24. "math/big"
  25. "github.com/tjfoc/gmsm/sm3"
  26. )
  var (
  	// default_uid is the signer identity used when the caller supplies none:
  	// the 16 ASCII bytes "1234567812345678".
  	default_uid = []byte("1234567812345678")

  	// C1C3C2 selects the ciphertext layout 0x04 || C1 || C3 || C2.
  	C1C3C2 = 0
  	// C1C2C3 selects the ciphertext layout 0x04 || C1 || C2 || C3.
  	C1C2C3 = 1
  )
  32. type PublicKey struct {
  33. elliptic.Curve
  34. X, Y *big.Int
  35. }
  36. type PrivateKey struct {
  37. PublicKey
  38. D *big.Int
  39. }
  40. type sm2Signature struct {
  41. R, S *big.Int
  42. }
  43. type sm2Cipher struct {
  44. XCoordinate *big.Int
  45. YCoordinate *big.Int
  46. HASH []byte
  47. CipherText []byte
  48. }
  // Public returns the public half of priv, as required by crypto.Signer and
  // crypto.Decrypter. The returned pointer aliases priv's embedded PublicKey.
  func (priv *PrivateKey) Public() crypto.PublicKey {
  	pub := &priv.PublicKey
  	return pub
  }
  53. var errZeroParam = errors.New("zero parameter")
  54. var one = new(big.Int).SetInt64(1)
  55. var two = new(big.Int).SetInt64(2)
  // Sign signs msg with priv using the default user ID and returns the DER
  // encoding of SEQUENCE { r INTEGER, s INTEGER }, i.e.
  // 30 len(z) 02 len(r) r 02 len(s) s, where z is everything after the length.
  //
  // msg is the raw message, not a digest: Sm2Sign hashes it together with Z_A.
  // signer is ignored.
  func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
  	r, s, err := Sm2Sign(priv, msg, nil, random)
  	if err != nil {
  		return nil, err
  	}
  	sig := sm2Signature{R: r, S: s}
  	return asn1.Marshal(sig)
  }
  // Verify reports whether sign is a valid DER-encoded SM2 signature of msg
  // under pub, using the default user ID.
  //
  // The signature must be exactly one ASN.1 SEQUENCE { r, s }. Trailing bytes
  // are rejected, so a single signature cannot be re-encoded into many
  // distinct byte strings that all verify.
  func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
  	var sig sm2Signature
  	rest, err := asn1.Unmarshal(sign, &sig)
  	if err != nil || len(rest) != 0 {
  		return false
  	}
  	return Sm2Verify(pub, msg, default_uid, sig.R, sig.S)
  }
  72. func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
  73. if len(uid) == 0 {
  74. uid = default_uid
  75. }
  76. za, err := ZA(pub, uid)
  77. if err != nil {
  78. return nil, err
  79. }
  80. e, err := msgHash(za, msg)
  81. if err != nil {
  82. return nil, err
  83. }
  84. return e.Bytes(), nil
  85. }
  //**************************** Encryption algorithm ****************************//

  // EncryptAsn1 encrypts data to pub and returns the ASN.1-encoded ciphertext.
  // It is the method form of the package-level EncryptAsn1.
  func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) {
  	out, err := EncryptAsn1(pub, data, random)
  	return out, err
  }

  // DecryptAsn1 decrypts an ASN.1-encoded ciphertext with priv.
  // It is the method form of the package-level DecryptAsn1.
  func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) {
  	out, err := DecryptAsn1(priv, data)
  	return out, err
  }
  93. //**************************Key agreement algorithm**************************//
  94. // KeyExchangeB 协商第二部,用户B调用, 返回共享密钥k
  95. func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) {
  96. return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false)
  97. }
  98. // KeyExchangeA 协商第二部,用户A调用,返回共享密钥k
  99. func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) {
  100. return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true)
  101. }
  //****************************************************************************//

  // Sm2Sign produces an SM2 signature (r, s) over msg with priv.
  //
  // uid is the signer identity used to derive Z_A, and an empty uid selects
  // default_uid. random supplies the per-signature nonce k; nil falls back to
  // crypto/rand (see randFieldElement). The signature is computed as:
  //
  //	e = SM3(Z_A || msg)
  //	r = (e + x1) mod n, where (x1, y1) = [k]G; retry if r == 0 or r + k == n
  //	s = ((1 + d)^-1 * (k - r*d)) mod n; retry if s == 0
  //
  // An error is returned if priv.D is nil or if 1 + d has no inverse
  // modulo n (d ≡ -1 mod n). Previously that case panicked.
  func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
  	digest, err := priv.PublicKey.Sm3Digest(msg, uid)
  	if err != nil {
  		return nil, nil, err
  	}
  	e := new(big.Int).SetBytes(digest)
  	c := priv.PublicKey.Curve
  	N := c.Params().N
  	if N.Sign() == 0 {
  		return nil, nil, errZeroParam
  	}
  	if priv.D == nil {
  		return nil, nil, errors.New("SM2: invalid private key")
  	}
  	// (1 + d)^-1 mod n does not depend on the nonce, so compute it once.
  	// ModInverse returns nil when no inverse exists; report that instead of
  	// dereferencing it below.
  	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, one), N)
  	if d1Inv == nil {
  		return nil, nil, errors.New("SM2: invalid private key")
  	}
  	var k *big.Int
  	for {
  		k, err = randFieldElement(c, random)
  		if err != nil {
  			return nil, nil, err
  		}
  		x1, _ := priv.Curve.ScalarBaseMult(k.Bytes())
  		r = x1.Add(x1, e)
  		r.Mod(r, N)
  		if r.Sign() == 0 || new(big.Int).Add(r, k).Cmp(N) == 0 {
  			continue // degenerate r: draw a fresh nonce
  		}
  		s = new(big.Int).Mul(r, priv.D)
  		s.Sub(k, s)
  		s.Mul(s, d1Inv)
  		s.Mod(s, N)
  		if s.Sign() != 0 {
  			return r, s, nil
  		}
  		// s == 0 is not a valid signature: retry with a fresh nonce.
  	}
  }
  143. func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
  144. c := pub.Curve
  145. N := c.Params().N
  146. one := new(big.Int).SetInt64(1)
  147. if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
  148. return false
  149. }
  150. if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
  151. return false
  152. }
  153. if len(uid) == 0 {
  154. uid = default_uid
  155. }
  156. za, err := ZA(pub, uid)
  157. if err != nil {
  158. return false
  159. }
  160. e, err := msgHash(za, msg)
  161. if err != nil {
  162. return false
  163. }
  164. t := new(big.Int).Add(r, s)
  165. t.Mod(t, N)
  166. if t.Sign() == 0 {
  167. return false
  168. }
  169. var x *big.Int
  170. x1, y1 := c.ScalarBaseMult(s.Bytes())
  171. x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
  172. x, _ = c.Add(x1, y1, x2, y2)
  173. x.Add(x, e)
  174. x.Mod(x, N)
  175. return x.Cmp(r) == 0
  176. }
  177. /*
  178. za, err := ZA(pub, uid)
  179. if err != nil {
  180. return
  181. }
  182. e, err := msgHash(za, msg)
  183. hash=e.getBytes()
  184. */
  185. func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
  186. c := pub.Curve
  187. N := c.Params().N
  188. if r.Sign() <= 0 || s.Sign() <= 0 {
  189. return false
  190. }
  191. if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
  192. return false
  193. }
  194. // 调整算法细节以实现SM2
  195. t := new(big.Int).Add(r, s)
  196. t.Mod(t, N)
  197. if t.Sign() == 0 {
  198. return false
  199. }
  200. var x *big.Int
  201. x1, y1 := c.ScalarBaseMult(s.Bytes())
  202. x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
  203. x, _ = c.Add(x1, y1, x2, y2)
  204. e := new(big.Int).SetBytes(hash)
  205. x.Add(x, e)
  206. x.Mod(x, N)
  207. return x.Cmp(r) == 0
  208. }
  209. /*
  210. * sm2密文结构如下:
  211. * x
  212. * y
  213. * hash
  214. * CipherText
  215. */
  216. func Encrypt(pub *PublicKey, data []byte, random io.Reader,mode int) ([]byte, error) {
  217. length := len(data)
  218. for {
  219. c := []byte{}
  220. curve := pub.Curve
  221. k, err := randFieldElement(curve, random)
  222. if err != nil {
  223. return nil, err
  224. }
  225. x1, y1 := curve.ScalarBaseMult(k.Bytes())
  226. x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
  227. x1Buf := x1.Bytes()
  228. y1Buf := y1.Bytes()
  229. x2Buf := x2.Bytes()
  230. y2Buf := y2.Bytes()
  231. if n := len(x1Buf); n < 32 {
  232. x1Buf = append(zeroByteSlice()[:32-n], x1Buf...)
  233. }
  234. if n := len(y1Buf); n < 32 {
  235. y1Buf = append(zeroByteSlice()[:32-n], y1Buf...)
  236. }
  237. if n := len(x2Buf); n < 32 {
  238. x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
  239. }
  240. if n := len(y2Buf); n < 32 {
  241. y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
  242. }
  243. c = append(c, x1Buf...) // x分量
  244. c = append(c, y1Buf...) // y分量
  245. tm := []byte{}
  246. tm = append(tm, x2Buf...)
  247. tm = append(tm, data...)
  248. tm = append(tm, y2Buf...)
  249. h := sm3.Sm3Sum(tm)
  250. c = append(c, h...)
  251. ct, ok := kdf(length, x2Buf, y2Buf) // 密文
  252. if !ok {
  253. continue
  254. }
  255. c = append(c, ct...)
  256. for i := 0; i < length; i++ {
  257. c[96+i] ^= data[i]
  258. }
  259. switch mode{
  260. case C1C3C2:
  261. return append([]byte{0x04}, c...), nil
  262. case C1C2C3:
  263. c1 := make([]byte, 64)
  264. c2 := make([]byte, len(c) - 96)
  265. c3 := make([]byte, 32)
  266. copy(c1, c[:64])//x1,y1
  267. copy(c3, c[64:96])//hash
  268. copy(c2, c[96:])//密文
  269. ciphertext := []byte{}
  270. ciphertext = append(ciphertext, c1...)
  271. ciphertext = append(ciphertext, c2...)
  272. ciphertext = append(ciphertext, c3...)
  273. return append([]byte{0x04}, ciphertext...), nil
  274. default:
  275. return append([]byte{0x04}, c...), nil
  276. }
  277. }
  278. }
  // Decrypt decrypts an SM2 ciphertext produced by Encrypt with the same mode.
  //
  // For mode C1C2C3, data must be 0x04 || C1 (64 bytes) || C2 || C3 (32 bytes).
  // For C1C3C2, and for any unrecognised mode, it must be
  // 0x04 || C1 || C3 || C2. The leading byte is skipped, not checked.
  //
  // Errors are returned for:
  //   - input shorter than the 97-byte header;
  //   - a C1 that is not on the curve;
  //   - an all-zero kdf output;
  //   - a C3 mismatch.
  //
  // On a C3 mismatch no plaintext is returned, because the recovered bytes
  // are unauthenticated.
  func Decrypt(priv *PrivateKey, data []byte, mode int) ([]byte, error) {
  	// Smallest acceptable input: 1 prefix byte + 64-byte C1 + 32-byte C3.
  	if len(data) < 1+64+32 {
  		return nil, errors.New("Decrypt: ciphertext too short")
  	}
  	data = data[1:]
  	if mode == C1C2C3 {
  		// Rearrange C1 || C2 || C3 into C1 || C3 || C2 in a fresh buffer, so
  		// the code below handles a single layout and the caller's slice is
  		// left untouched.
  		c3Start := len(data) - 32
  		reordered := make([]byte, 0, len(data))
  		reordered = append(reordered, data[:64]...)
  		reordered = append(reordered, data[c3Start:]...)
  		reordered = append(reordered, data[64:c3Start]...)
  		data = reordered
  	}
  	length := len(data) - 96
  	curve := priv.Curve
  	x := new(big.Int).SetBytes(data[:32])
  	y := new(big.Int).SetBytes(data[32:64])
  	// C1 comes from untrusted input. Reject points off the curve before
  	// multiplying them by the private scalar.
  	if !curve.IsOnCurve(x, y) {
  		return nil, errors.New("Decrypt: C1 is not on the curve")
  	}
  	x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
  	x2Buf := x2.Bytes()
  	y2Buf := y2.Bytes()
  	if n := len(x2Buf); n < 32 {
  		x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
  	}
  	if n := len(y2Buf); n < 32 {
  		y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
  	}
  	c, ok := kdf(length, x2Buf, y2Buf)
  	if !ok {
  		return nil, errors.New("Decrypt: failed to decrypt")
  	}
  	for i := range c {
  		c[i] ^= data[96+i]
  	}
  	tm := make([]byte, 0, len(x2Buf)+len(c)+len(y2Buf))
  	tm = append(tm, x2Buf...)
  	tm = append(tm, c...)
  	tm = append(tm, y2Buf...)
  	h := sm3.Sm3Sum(tm)
  	if !bytes.Equal(h, data[64:96]) {
  		return nil, errors.New("Decrypt: failed to decrypt")
  	}
  	return c, nil
  }
  // keyExchange is the part of SM2 key agreement shared by step 2 (responder B)
  // and step 3 (initiator A). Both parties call it to compute the same values.
  //
  //	klen:    length of the shared key k, in bytes
  //	ida/idb: identities of the initiator (A) and the responder (B)
  //	pri:     the caller's long-term private key
  //	pub:     the peer's long-term public key
  //	rpri:    the caller's ephemeral private key
  //	rpub:    the peer's ephemeral public key
  //	thisISA: true when called by A (step 3), false when called by B (step 2)
  //
  // It returns:
  //   - k = kdf(xV || yV || Z_A || Z_B, klen);
  //   - s1 = SM3(0x02 || yV || H);
  //   - s2 = SM3(0x03 || yV || H).
  //
  // Here H = SM3(xV || Z_A || Z_B || R_B || R_A), where R_A and R_B are the
  // ephemeral points of A and B. An error is returned when:
  //   - the peer's ephemeral point is off the curve;
  //   - V is the point at infinity;
  //   - ZA fails;
  //   - the derived key is all zero.
  func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) {
  	curve := P256Sm2()
  	N := curve.Params().N
  	// t = (d + x̄ * r) mod n, from the caller's long-term and ephemeral keys.
  	x2hat := keXHat(rpri.PublicKey.X)
  	x2rb := new(big.Int).Mul(x2hat, rpri.D)
  	tbt := new(big.Int).Add(pri.D, x2rb)
  	tb := new(big.Int).Mod(tbt, N)
  	if !curve.IsOnCurve(rpub.X, rpub.Y) {
  		return nil, nil, nil, errors.New("Ra not on curve")
  	}
  	// V = [t](P_peer + [x̄_peer] R_peer)
  	x1hat := keXHat(rpub.X)
  	ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes())
  	vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1)
  	vx, vy := curve.ScalarMult(vxt, vyt, tb.Bytes())

  	// Z_A always belongs to the initiator A and Z_B to the responder B,
  	// whichever side is calling.
  	pza := pub
  	if thisISA {
  		pza = &pri.PublicKey
  	}
  	za, err := ZA(pza, ida)
  	if err != nil {
  		return nil, nil, nil, err
  	}
  	// A zero coordinate is treated as the point at infinity. The protocol
  	// must abort here rather than derive a key from it (this error used to be
  	// assigned and then overwritten).
  	if vx.Sign() == 0 || vy.Sign() == 0 {
  		return nil, nil, nil, errors.New("V is infinite")
  	}
  	pzb := pub
  	if !thisISA {
  		pzb = &pri.PublicKey
  	}
  	zb, err := ZA(pzb, idb)
  	if err != nil {
  		return nil, nil, nil, err
  	}

  	// NOTE(review): vx, vy and the ephemeral coordinates are serialised with
  	// big.Int.Bytes(), which drops leading zero bytes. GB/T 32918.3 appears to
  	// require fixed 32-byte encodings and to order R_A before R_B inside H.
  	// Two peers running this code agree with each other, but interop with
  	// other implementations should be confirmed before changing it, because
  	// any change alters k, s1 and s2.
  	k, ok := kdf(klen, vx.Bytes(), vy.Bytes(), za, zb)
  	if !ok {
  		return nil, nil, nil, errors.New("kdf: zero key")
  	}
  	h1 := BytesCombine(vx.Bytes(), za, zb, rpub.X.Bytes(), rpub.Y.Bytes(), rpri.X.Bytes(), rpri.Y.Bytes())
  	if !thisISA {
  		h1 = BytesCombine(vx.Bytes(), za, zb, rpri.X.Bytes(), rpri.Y.Bytes(), rpub.X.Bytes(), rpub.Y.Bytes())
  	}
  	hash := sm3.Sm3Sum(h1)
  	s1 = sm3.Sm3Sum(BytesCombine([]byte{0x02}, vy.Bytes(), hash))
  	s2 = sm3.Sm3Sum(BytesCombine([]byte{0x03}, vy.Bytes(), hash))
  	return k, s1, s2, nil
  }
  // msgHash returns e = SM3(za || msg), read as an unsigned big-endian integer.
  // The error result is always nil.
  func msgHash(za, msg []byte) (*big.Int, error) {
  	h := sm3.New()
  	h.Write(za)
  	h.Write(msg)
  	digest := h.Sum(nil)
  	return new(big.Int).SetBytes(digest[:32]), nil
  }
  // ZA computes Z = SM3(ENTL || ID || a || b || xG || yG || xA || yA) for pub.
  //
  // ENTL is the bit length of uid as a 2-byte big-endian value, so uid must be
  // shorter than 8192 bytes. The public-key coordinates xA and yA are
  // left-padded to 32 bytes.
  func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
  	uidLen := len(uid)
  	if uidLen >= 8192 {
  		return []byte{}, errors.New("SM2: uid too large")
  	}
  	h := sm3.New()
  	entl := uint16(8 * uidLen)
  	h.Write([]byte{byte(entl >> 8), byte(entl)})
  	if uidLen > 0 {
  		h.Write(uid)
  	}
  	// Curve parameters a, b and the base point G.
  	h.Write(sm2P256ToBig(&sm2P256.a).Bytes())
  	h.Write(sm2P256.B.Bytes())
  	h.Write(sm2P256.Gx.Bytes())
  	h.Write(sm2P256.Gy.Bytes())
  	// The signer's public point, each coordinate at a fixed 32-byte width.
  	for _, coord := range []*big.Int{pub.X, pub.Y} {
  		buf := coord.Bytes()
  		if n := len(buf); n < 32 {
  			buf = append(zeroByteSlice()[:32-n], buf...)
  		}
  		h.Write(buf)
  	}
  	return h.Sum(nil)[:32], nil
  }
  // zeroByteSlice returns a freshly allocated 32-byte slice of zeros. It is
  // used as the source of left-padding for fixed-width coordinates.
  func zeroByteSlice() []byte {
  	return make([]byte, 32)
  }
  434. /*
  435. sm2加密,返回asn.1编码格式的密文内容
  436. */
  437. func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) {
  438. cipher, err := Encrypt(pub, data, rand,C1C3C2)
  439. if err != nil {
  440. return nil, err
  441. }
  442. return CipherMarshal(cipher)
  443. }
  444. /*
  445. sm2解密,解析asn.1编码格式的密文内容
  446. */
  447. func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) {
  448. cipher, err := CipherUnmarshal(data)
  449. if err != nil {
  450. return nil, err
  451. }
  452. return Decrypt(pub, cipher,C1C3C2)
  453. }
  454. /*
  455. *sm2密文转asn.1编码格式
  456. *sm2密文结构如下:
  457. * x
  458. * y
  459. * hash
  460. * CipherText
  461. */
  462. func CipherMarshal(data []byte) ([]byte, error) {
  463. data = data[1:]
  464. x := new(big.Int).SetBytes(data[:32])
  465. y := new(big.Int).SetBytes(data[32:64])
  466. hash := data[64:96]
  467. cipherText := data[96:]
  468. return asn1.Marshal(sm2Cipher{x, y, hash, cipherText})
  469. }
  470. /*
  471. sm2密文asn.1编码格式转C1|C3|C2拼接格式
  472. */
  473. func CipherUnmarshal(data []byte) ([]byte, error) {
  474. var cipher sm2Cipher
  475. _, err := asn1.Unmarshal(data, &cipher)
  476. if err != nil {
  477. return nil, err
  478. }
  479. x := cipher.XCoordinate.Bytes()
  480. y := cipher.YCoordinate.Bytes()
  481. hash := cipher.HASH
  482. if err != nil {
  483. return nil, err
  484. }
  485. cipherText := cipher.CipherText
  486. if err != nil {
  487. return nil, err
  488. }
  489. if n := len(x); n < 32 {
  490. x = append(zeroByteSlice()[:32-n], x...)
  491. }
  492. if n := len(y); n < 32 {
  493. y = append(zeroByteSlice()[:32-n], y...)
  494. }
  495. c := []byte{}
  496. c = append(c, x...) // x分量
  497. c = append(c, y...) // y分
  498. c = append(c, hash...) // x分量
  499. c = append(c, cipherText...) // y分
  500. return append([]byte{0x04}, c...), nil
  501. }
  // keXHat computes x̄ = 2^w + (x & (2^w - 1)) with w = 127. It is a helper
  // for SM2 key agreement. Only the magnitude |x| is used, and x itself is
  // not modified.
  func keXHat(x *big.Int) (xul *big.Int) {
  	twoW := new(big.Int).Lsh(big.NewInt(1), 127)
  	mask := new(big.Int).Sub(twoW, big.NewInt(1))
  	low := new(big.Int).Abs(x)
  	low.And(low, mask)
  	return low.Add(low, twoW)
  }
  // BytesCombine concatenates its arguments into a single newly allocated
  // slice. The inputs are never aliased.
  //
  // bytes.Join already copies its input, so there is no need to copy the
  // variadic slice first (or to shadow the builtin len to do so).
  func BytesCombine(pBytes ...[]byte) []byte {
  	return bytes.Join(pBytes, nil)
  }
  // intToBytes encodes the low 32 bits of x as a 4-byte big-endian value. It
  // is used as the kdf counter.
  func intToBytes(x int) []byte {
  	var out [4]byte
  	binary.BigEndian.PutUint32(out[:], uint32(x))
  	return out[:]
  }
  533. func kdf(length int, x ...[]byte) ([]byte, bool) {
  534. var c []byte
  535. ct := 1
  536. h := sm3.New()
  537. for i, j := 0, (length+31)/32; i < j; i++ {
  538. h.Reset()
  539. for _, xx := range x {
  540. h.Write(xx)
  541. }
  542. h.Write(intToBytes(ct))
  543. hash := h.Sum(nil)
  544. if i+1 == j && length%32 != 0 {
  545. c = append(c, hash[:length%32]...)
  546. } else {
  547. c = append(c, hash...)
  548. }
  549. ct++
  550. }
  551. for i := 0; i < length; i++ {
  552. if c[i] != 0 {
  553. return c, true
  554. }
  555. }
  556. return c, false
  557. }
  // randFieldElement returns a random scalar k in [1, n-1] for curve c.
  //
  // It reads BitSize/8 + 8 bytes from random and reduces them modulo n-1; the
  // 8 extra bytes keep the bias of that reduction small. A nil random falls
  // back to crypto/rand.Reader.
  func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
  	if random == nil {
  		// No external trusted randomness supplied: use the system CSPRNG.
  		random = rand.Reader
  	}
  	params := c.Params()
  	buf := make([]byte, params.BitSize/8+8)
  	if _, err = io.ReadFull(random, buf); err != nil {
  		return nil, err
  	}
  	nMinus1 := new(big.Int).Sub(params.N, one)
  	k = new(big.Int).SetBytes(buf)
  	k.Mod(k, nMinus1)
  	k.Add(k, one)
  	return k, nil
  }
  // GenerateKey creates a new SM2 key pair on the P256Sm2 curve.
  //
  // The private scalar d is drawn from [1, n-2]. n-1 is excluded because
  // signing needs (1 + d)^-1 mod n. random supplies the entropy; nil falls
  // back to crypto/rand.Reader.
  func GenerateKey(random io.Reader) (*PrivateKey, error) {
  	curve := P256Sm2()
  	if random == nil {
  		// No external trusted randomness supplied: use the system CSPRNG.
  		random = rand.Reader
  	}
  	params := curve.Params()
  	seed := make([]byte, params.BitSize/8+8)
  	if _, err := io.ReadFull(random, seed); err != nil {
  		return nil, err
  	}
  	d := new(big.Int).SetBytes(seed)
  	d.Mod(d, new(big.Int).Sub(params.N, two))
  	d.Add(d, one)

  	priv := &PrivateKey{D: d}
  	priv.PublicKey.Curve = curve
  	priv.PublicKey.X, priv.PublicKey.Y = curve.ScalarBaseMult(d.Bytes())
  	return priv, nil
  }
  595. type zr struct {
  596. io.Reader
  597. }
  598. func (z *zr) Read(dst []byte) (n int, err error) {
  599. for i := range dst {
  600. dst[i] = 0
  601. }
  602. return len(dst), nil
  603. }
  604. var zeroReader = &zr{}
  // getLastBit returns the least-significant bit of a, as 0 or 1.
  func getLastBit(a *big.Int) uint {
  	lsb := a.Bit(0)
  	return lsb
  }
  // Decrypt implements crypto.Decrypter. It decrypts msg, which must use the
  // 0x04 || C1 || C3 || C2 layout (C1C3C2 mode). The random source and the
  // options are ignored.
  func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) {
  	plaintext, err = Decrypt(priv, msg, C1C3C2)
  	return plaintext, err
  }