123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674 |
- /*
- Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package sm2
- // reference to ecdsa
- import (
- "bytes"
- "crypto"
- "crypto/elliptic"
- "crypto/rand"
- "encoding/asn1"
- "encoding/binary"
- "errors"
- "io"
- "math/big"
- "github.com/tjfoc/gmsm/sm3"
- )
// Package-level defaults and ciphertext layout selectors.
var (
	// default_uid is the signer identifier "1234567812345678" that is used
	// whenever a caller supplies an empty uid.
	default_uid = []byte("1234567812345678")

	// C1C3C2 selects the C1 || C3 || C2 ciphertext layout.
	C1C3C2 = 0
	// C1C2C3 selects the C1 || C2 || C3 ciphertext layout.
	C1C2C3 = 1
)
- type PublicKey struct {
- elliptic.Curve
- X, Y *big.Int
- }
- type PrivateKey struct {
- PublicKey
- D *big.Int
- }
- type sm2Signature struct {
- R, S *big.Int
- }
- type sm2Cipher struct {
- XCoordinate *big.Int
- YCoordinate *big.Int
- HASH []byte
- CipherText []byte
- }
// Public returns the public half of priv, as required by crypto.Signer and
// crypto.Decrypter. The returned value points into priv; it is not a copy.
func (priv *PrivateKey) Public() crypto.PublicKey {
	pub := &priv.PublicKey
	return pub
}
- var errZeroParam = errors.New("zero parameter")
- var one = new(big.Int).SetInt64(1)
- var two = new(big.Int).SetInt64(2)
- // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
- func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
- r, s, err := Sm2Sign(priv, msg, nil, random)
- if err != nil {
- return nil, err
- }
- return asn1.Marshal(sm2Signature{r, s})
- }
// Verify reports whether sign is a valid DER-encoded SM2 signature
// (SEQUENCE { r, s }) of msg under pub, using the default user ID.
//
// Any bytes following the encoded SEQUENCE cause rejection, so that extra
// data cannot be appended to a valid signature.
func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
	var sig sm2Signature
	rest, err := asn1.Unmarshal(sign, &sig)
	if err != nil || len(rest) != 0 {
		return false
	}
	return Sm2Verify(pub, msg, default_uid, sig.R, sig.S)
}
- func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
- if len(uid) == 0 {
- uid = default_uid
- }
- za, err := ZA(pub, uid)
- if err != nil {
- return nil, err
- }
- e, err := msgHash(za, msg)
- if err != nil {
- return nil, err
- }
- return e.Bytes(), nil
- }
- //****************************Encryption algorithm****************************//
// EncryptAsn1 encrypts data for pub and returns the ASN.1-encoded
// ciphertext. It is the method form of the package-level EncryptAsn1.
func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) {
	ciphertext, err := EncryptAsn1(pub, data, random)
	return ciphertext, err
}
- func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) {
- return DecryptAsn1(priv, data)
- }
- //**************************Key agreement algorithm**************************//
- // KeyExchangeB 协商第二部,用户B调用, 返回共享密钥k
- func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) {
- return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false)
- }
- // KeyExchangeA 协商第二部,用户A调用,返回共享密钥k
- func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) {
- return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true)
- }
- //****************************************************************************//
// Sm2Sign signs msg with priv and returns the signature (r, s).
//
// uid is the signer identifier hashed into ZA; an empty uid selects the
// package default. random supplies the per-signature nonce k; nil selects
// crypto/rand.Reader.
//
// The signature follows
//
//	r = (e + x1) mod N, where (x1, _) = [k]G, retried while r == 0 or r+k == N
//	s = (1+d)^-1 * (k - r*d) mod N, retried while s == 0
//
// An error is returned if hashing fails, random fails, N is zero, or
// 1+d has no inverse modulo N (an invalid private key).
func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
	digest, err := priv.PublicKey.Sm3Digest(msg, uid)
	if err != nil {
		return nil, nil, err
	}
	e := new(big.Int).SetBytes(digest)
	c := priv.PublicKey.Curve
	N := c.Params().N
	if N.Sign() == 0 {
		return nil, nil, errZeroParam
	}
	// (1+d)^-1 does not depend on the nonce, so compute it once. ModInverse
	// returns nil when no inverse exists; using that would panic below.
	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, one), N)
	if d1Inv == nil {
		return nil, nil, errors.New("sm2: invalid private key")
	}
	for {
		var k *big.Int
		for {
			if k, err = randFieldElement(c, random); err != nil {
				return nil, nil, err
			}
			r, _ = c.ScalarBaseMult(k.Bytes())
			r.Add(r, e)
			r.Mod(r, N)
			// SM2 rejects r == 0 and r + k == N.
			if r.Sign() != 0 && new(big.Int).Add(r, k).Cmp(N) != 0 {
				break
			}
		}
		s = new(big.Int).Mul(priv.D, r)
		s.Sub(k, s)
		s.Mul(s, d1Inv)
		s.Mod(s, N)
		if s.Sign() != 0 {
			return r, s, nil
		}
	}
}
- func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
- c := pub.Curve
- N := c.Params().N
- one := new(big.Int).SetInt64(1)
- if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
- return false
- }
- if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
- return false
- }
- if len(uid) == 0 {
- uid = default_uid
- }
- za, err := ZA(pub, uid)
- if err != nil {
- return false
- }
- e, err := msgHash(za, msg)
- if err != nil {
- return false
- }
- t := new(big.Int).Add(r, s)
- t.Mod(t, N)
- if t.Sign() == 0 {
- return false
- }
- var x *big.Int
- x1, y1 := c.ScalarBaseMult(s.Bytes())
- x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
- x, _ = c.Add(x1, y1, x2, y2)
- x.Add(x, e)
- x.Mod(x, N)
- return x.Cmp(r) == 0
- }
- /*
- za, err := ZA(pub, uid)
- if err != nil {
- return
- }
- e, err := msgHash(za, msg)
- hash=e.getBytes()
- */
- func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
- c := pub.Curve
- N := c.Params().N
- if r.Sign() <= 0 || s.Sign() <= 0 {
- return false
- }
- if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
- return false
- }
- // 调整算法细节以实现SM2
- t := new(big.Int).Add(r, s)
- t.Mod(t, N)
- if t.Sign() == 0 {
- return false
- }
- var x *big.Int
- x1, y1 := c.ScalarBaseMult(s.Bytes())
- x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
- x, _ = c.Add(x1, y1, x2, y2)
- e := new(big.Int).SetBytes(hash)
- x.Add(x, e)
- x.Mod(x, N)
- return x.Cmp(r) == 0
- }
// Encrypt encrypts data for pub with SM2 public-key encryption.
//
// The result is 0x04 followed by the components in the layout chosen by mode:
//
//	C1 = x1 || y1 (32 bytes each), the ephemeral point [k]G
//	C3 = SM3(x2 || data || y2) (32 bytes)
//	C2 = data XOR KDF(x2 || y2, len(data))
//
// where (x2, y2) = [k]pub. mode C1C2C3 emits C1||C2||C3. C1C3C2 and any other
// value emit C1||C3||C2. A nil random selects crypto/rand.Reader. A fresh k is
// drawn whenever the KDF output is all zero.
//
// Empty data is rejected. KDF output of length zero never counts as non-zero,
// so the retry loop would otherwise spin forever.
func Encrypt(pub *PublicKey, data []byte, random io.Reader, mode int) ([]byte, error) {
	length := len(data)
	if length == 0 {
		return nil, errors.New("sm2: plaintext is empty")
	}
	// pad32 left-pads a big-endian coordinate to the fixed 32-byte width.
	pad32 := func(b []byte) []byte {
		if n := len(b); n < 32 {
			b = append(zeroByteSlice()[:32-n], b...)
		}
		return b
	}
	curve := pub.Curve
	for {
		k, err := randFieldElement(curve, random)
		if err != nil {
			return nil, err
		}
		x1, y1 := curve.ScalarBaseMult(k.Bytes())
		x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
		x2Buf, y2Buf := pad32(x2.Bytes()), pad32(y2.Bytes())
		t, ok := kdf(length, x2Buf, y2Buf)
		if !ok {
			continue // all-zero key stream: retry with a new k
		}

		c1 := make([]byte, 0, 64)
		c1 = append(c1, pad32(x1.Bytes())...)
		c1 = append(c1, pad32(y1.Bytes())...)

		tm := make([]byte, 0, len(x2Buf)+length+len(y2Buf))
		tm = append(tm, x2Buf...)
		tm = append(tm, data...)
		tm = append(tm, y2Buf...)
		c3 := sm3.Sm3Sum(tm)

		c2 := make([]byte, length)
		for i := range c2 {
			c2[i] = data[i] ^ t[i]
		}

		out := make([]byte, 0, 1+len(c1)+len(c3)+len(c2))
		out = append(out, 0x04)
		out = append(out, c1...)
		switch mode {
		case C1C3C2:
			out = append(append(out, c3...), c2...)
		case C1C2C3:
			out = append(append(out, c2...), c3...)
		default:
			out = append(append(out, c3...), c2...)
		}
		return out, nil
	}
}
- func Decrypt(priv *PrivateKey, data []byte,mode int) ([]byte, error) {
- switch mode {
- case C1C3C2:
- data = data[1:]
- case C1C2C3:
- data = data[1:]
- c1 := make([]byte, 64)
- c2 := make([]byte, len(data) - 96)
- c3 := make([]byte, 32)
- copy(c1, data[:64])//x1,y1
- copy(c2, data[64:len(data) - 32])//密文
- copy(c3, data[len(data) - 32:])//hash
- c := []byte{}
- c = append(c, c1...)
- c = append(c, c3...)
- c = append(c, c2...)
- data = c
- default:
- data = data[1:]
- }
- length := len(data) - 96
- curve := priv.Curve
- x := new(big.Int).SetBytes(data[:32])
- y := new(big.Int).SetBytes(data[32:64])
- x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
- x2Buf := x2.Bytes()
- y2Buf := y2.Bytes()
- if n := len(x2Buf); n < 32 {
- x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
- }
- if n := len(y2Buf); n < 32 {
- y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
- }
- c, ok := kdf(length, x2Buf, y2Buf)
- if !ok {
- return nil, errors.New("Decrypt: failed to decrypt")
- }
- for i := 0; i < length; i++ {
- c[i] ^= data[i+96]
- }
- tm := []byte{}
- tm = append(tm, x2Buf...)
- tm = append(tm, c...)
- tm = append(tm, y2Buf...)
- h := sm3.Sm3Sum(tm)
- if bytes.Compare(h, data[64:96]) != 0 {
- return c, errors.New("Decrypt: failed to decrypt")
- }
- return c, nil
- }
// keyExchange is the part of the SM2 key agreement shared by step two
// (responder B) and step three (initiator A). Both parties call it and
// obtain the same byte strings.
//
//	klen:    length of the derived key k, in bytes
//	ida:     identifier of the initiator A
//	idb:     identifier of the responder B
//	pri:     the caller's static private key
//	pub:     the peer's static public key
//	rpri:    the caller's ephemeral key pair
//	rpub:    the peer's ephemeral public key
//	thisISA: true when called by A (step three), false when called by B
//
// It returns the shared key k and the confirmation values
//
//	s1 = SM3(0x02 || yV || H), s2 = SM3(0x03 || yV || H)
//	H  = SM3(xV || ZA || ZB || xRA || yRA || xRB || yRB)
//
// Every coordinate is encoded as a fixed 32-byte big-endian string, and the
// ephemeral keys are hashed as RA then RB, following GB/T 32918.3. Earlier
// revisions hashed RB before RA and used minimal-length coordinates, so their
// s1/s2 (and, rarely, k) differ from this function's output.
func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) {
	curve := P256Sm2()
	N := curve.Params().N

	// pad32 encodes a coordinate as a fixed-width 32-byte big-endian string.
	pad32 := func(v *big.Int) []byte {
		b := v.Bytes()
		if n := len(b); n < 32 {
			b = append(zeroByteSlice()[:32-n], b...)
		}
		return b
	}

	if !curve.IsOnCurve(rpub.X, rpub.Y) {
		return nil, nil, nil, errors.New("Ra not on curve")
	}

	// t = (d + x̄ * r) mod N, from our static key d and our ephemeral key r.
	t := new(big.Int).Mul(keXHat(rpri.PublicKey.X), rpri.D)
	t.Add(t, pri.D)
	t.Mod(t, N)

	// V = [t](P' + [x̄'] R'), from the peer's static key P' and ephemeral key R'.
	x1hat := keXHat(rpub.X)
	ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes())
	vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1)
	vx, vy := curve.ScalarMult(vxt, vyt, t.Bytes())
	// Fail immediately. Previously this error was assigned and then
	// overwritten, so the exchange silently continued.
	if vx.Sign() == 0 || vy.Sign() == 0 {
		return nil, nil, nil, errors.New("V is infinite")
	}

	// ZA always belongs to the initiator A and ZB to the responder B.
	pza, pzb := pub, &pri.PublicKey
	if thisISA {
		pza, pzb = &pri.PublicKey, pub
	}
	za, err := ZA(pza, ida)
	if err != nil {
		return nil, nil, nil, err
	}
	zb, err := ZA(pzb, idb)
	if err != nil {
		return nil, nil, nil, err
	}

	vxBuf, vyBuf := pad32(vx), pad32(vy)
	k, ok := kdf(klen, vxBuf, vyBuf, za, zb)
	if !ok {
		return nil, nil, nil, errors.New("kdf: zero key")
	}

	// RA is the initiator's ephemeral key and RB the responder's.
	ra, rb := rpub, &rpri.PublicKey
	if thisISA {
		ra, rb = &rpri.PublicKey, rpub
	}
	inner := sm3.Sm3Sum(BytesCombine(vxBuf, za, zb, pad32(ra.X), pad32(ra.Y), pad32(rb.X), pad32(rb.Y)))
	s1 = sm3.Sm3Sum(BytesCombine([]byte{0x02}, vyBuf, inner))
	s2 = sm3.Sm3Sum(BytesCombine([]byte{0x03}, vyBuf, inner))
	return k, s1, s2, nil
}
// msgHash computes e = SM3(za || msg) and returns it as a non-negative
// integer. The error result is always nil.
func msgHash(za, msg []byte) (*big.Int, error) {
	h := sm3.New()
	h.Write(za)
	h.Write(msg)
	sum := h.Sum(nil)
	return new(big.Int).SetBytes(sum[:32]), nil
}
// ZA computes the SM2 user hash
//
//	ZA = SM3(ENTLA || IDA || a || b || xG || yG || xA || yA)
//
// where ENTLA is the bit length of uid as a 2-byte big-endian value and
// (xA, yA) is pub, each coordinate left-padded to 32 bytes. uid must be
// shorter than 8192 bytes so that its bit length fits in 16 bits.
func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
	idLen := len(uid)
	if idLen >= 8192 {
		return []byte{}, errors.New("SM2: uid too large")
	}
	fixed32 := func(b []byte) []byte {
		if n := len(b); n < 32 {
			b = append(zeroByteSlice()[:32-n], b...)
		}
		return b
	}

	h := sm3.New()
	entl := uint16(8 * idLen)
	h.Write([]byte{byte(entl >> 8), byte(entl)})
	if idLen > 0 {
		h.Write(uid)
	}
	h.Write(sm2P256ToBig(&sm2P256.a).Bytes())
	h.Write(sm2P256.B.Bytes())
	h.Write(sm2P256.Gx.Bytes())
	h.Write(sm2P256.Gy.Bytes())
	h.Write(fixed32(pub.X.Bytes()))
	h.Write(fixed32(pub.Y.Bytes()))
	return h.Sum(nil)[:32], nil
}
// zeroByteSlice returns a freshly allocated 32-byte slice of zeros. Callers
// slice it to build left padding for fixed-width coordinates.
func zeroByteSlice() []byte {
	return make([]byte, 32)
}
// EncryptAsn1 encrypts data for pub and returns the ciphertext ASN.1
// DER-encoded as SEQUENCE { x, y, hash, cipherText }. A nil rand selects
// crypto/rand.Reader.
func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) {
	raw, err := Encrypt(pub, data, rand, C1C3C2)
	if err == nil {
		return CipherMarshal(raw)
	}
	return nil, err
}
- /*
- sm2解密,解析asn.1编码格式的密文内容
- */
- func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) {
- cipher, err := CipherUnmarshal(data)
- if err != nil {
- return nil, err
- }
- return Decrypt(pub, cipher,C1C3C2)
- }
// CipherMarshal converts a raw C1 || C3 || C2 ciphertext into its ASN.1 DER
// form SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, cipherText OCTET STRING }.
//
// data is a one-byte prefix (skipped), then x (32 bytes), y (32 bytes),
// hash (32 bytes) and the remaining bytes as the cipher text. Input shorter
// than 97 bytes is rejected with an error instead of panicking on slicing.
func CipherMarshal(data []byte) ([]byte, error) {
	if len(data) < 97 {
		return nil, errors.New("CipherMarshal: invalid sm2 ciphertext")
	}
	data = data[1:]
	x := new(big.Int).SetBytes(data[:32])
	y := new(big.Int).SetBytes(data[32:64])
	hash := data[64:96]
	cipherText := data[96:]
	return asn1.Marshal(sm2Cipher{x, y, hash, cipherText})
}
// CipherUnmarshal converts an ASN.1 DER-encoded SM2 ciphertext back into the
// raw 0x04 || x || y || hash || cipherText (C1 || C3 || C2) form.
//
// x and y are left-padded to 32 bytes. Input is rejected if a coordinate is
// negative or longer than 32 bytes, or if the hash is not exactly 32 bytes.
// Any of these would misalign the fixed-offset layout that Decrypt relies on.
func CipherUnmarshal(data []byte) ([]byte, error) {
	var cipher sm2Cipher
	if _, err := asn1.Unmarshal(data, &cipher); err != nil {
		return nil, err
	}
	errInvalid := errors.New("CipherUnmarshal: invalid sm2 ciphertext")
	if cipher.XCoordinate.Sign() < 0 || cipher.YCoordinate.Sign() < 0 {
		return nil, errInvalid
	}
	x := cipher.XCoordinate.Bytes()
	y := cipher.YCoordinate.Bytes()
	if len(x) > 32 || len(y) > 32 || len(cipher.HASH) != 32 {
		return nil, errInvalid
	}

	c := make([]byte, 0, 1+64+32+len(cipher.CipherText))
	c = append(c, 0x04)
	c = append(c, make([]byte, 32-len(x))...) // left padding for x
	c = append(c, x...)
	c = append(c, make([]byte, 32-len(y))...) // left padding for y
	c = append(c, y...)
	c = append(c, cipher.HASH...)
	c = append(c, cipher.CipherText...)
	return c, nil
}
- // keXHat 计算 x = 2^w + (x & (2^w-1))
- // 密钥协商算法辅助函数
- func keXHat(x *big.Int) (xul *big.Int) {
- buf := x.Bytes()
- for i := 0; i < len(buf)-16; i++ {
- buf[i] = 0
- }
- if len(buf) >= 16 {
- c := buf[len(buf)-16]
- buf[len(buf)-16] = c & 0x7f
- }
- r := new(big.Int).SetBytes(buf)
- _2w := new(big.Int).SetBytes([]byte{
- 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
- return r.Add(r, _2w)
- }
// BytesCombine concatenates its arguments into a newly allocated slice and
// returns an empty (non-nil) slice when called with no arguments.
//
// bytes.Join never retains or mutates its input, so the argument list is
// passed straight through. This replaces a needless copy and a local that
// shadowed the builtin len.
func BytesCombine(pBytes ...[]byte) []byte {
	return bytes.Join(pBytes, nil)
}
// intToBytes encodes the low 32 bits of x as a 4-byte big-endian slice,
// used as the KDF counter.
func intToBytes(x int) []byte {
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], uint32(x))
	return out[:]
}
- func kdf(length int, x ...[]byte) ([]byte, bool) {
- var c []byte
- ct := 1
- h := sm3.New()
- for i, j := 0, (length+31)/32; i < j; i++ {
- h.Reset()
- for _, xx := range x {
- h.Write(xx)
- }
- h.Write(intToBytes(ct))
- hash := h.Sum(nil)
- if i+1 == j && length%32 != 0 {
- c = append(c, hash[:length%32]...)
- } else {
- c = append(c, hash...)
- }
- ct++
- }
- for i := 0; i < length; i++ {
- if c[i] != 0 {
- return c, true
- }
- }
- return c, false
- }
// randFieldElement returns a random scalar k in [1, N-1] for curve c.
//
// It reads BitSize/8 + 8 bytes, with the extra bytes keeping the bias of the
// modular reduction small. A nil random falls back to crypto/rand.Reader; any
// read error is returned with a nil k.
func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	if random == nil {
		random = rand.Reader // no trusted source supplied: use the system CSPRNG
	}
	params := c.Params()
	buf := make([]byte, params.BitSize/8+8)
	if _, err = io.ReadFull(random, buf); err != nil {
		return nil, err
	}
	nMinusOne := new(big.Int).Sub(params.N, one)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinusOne)
	k.Add(k, one)
	return k, nil
}
// GenerateKey creates a new SM2 key pair on the P256Sm2 curve.
//
// The private scalar D is drawn from [1, N-2] and the public point is [D]G.
// A nil random falls back to crypto/rand.Reader.
func GenerateKey(random io.Reader) (*PrivateKey, error) {
	curve := P256Sm2()
	if random == nil {
		random = rand.Reader // no trusted source supplied: use the system CSPRNG
	}
	params := curve.Params()
	seed := make([]byte, params.BitSize/8+8)
	if _, err := io.ReadFull(random, seed); err != nil {
		return nil, err
	}
	// d = seed mod (N-2) + 1
	d := new(big.Int).SetBytes(seed)
	d.Mod(d, new(big.Int).Sub(params.N, two))
	d.Add(d, one)

	priv := &PrivateKey{D: d}
	priv.PublicKey.Curve = curve
	priv.PublicKey.X, priv.PublicKey.Y = curve.ScalarBaseMult(d.Bytes())
	return priv, nil
}
- type zr struct {
- io.Reader
- }
- func (z *zr) Read(dst []byte) (n int, err error) {
- for i := range dst {
- dst[i] = 0
- }
- return len(dst), nil
- }
- var zeroReader = &zr{}
// getLastBit returns the least significant bit of a (0 or 1).
func getLastBit(a *big.Int) uint {
	if a.Bit(0) == 0 {
		return 0
	}
	return 1
}
// Decrypt implements crypto.Decrypter. msg must be a raw C1 || C3 || C2
// ciphertext including its leading prefix byte. The random source and the
// options are ignored.
func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) {
	plaintext, err = Decrypt(priv, msg, C1C3C2)
	return plaintext, err
}
|