xref: /openssl/crypto/sm2/sm2_sign.c (revision e257d3e7)
1 /*
2  * Copyright 2017-2021 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the Apache License 2.0 (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11 
12 #include "internal/deprecated.h"
13 
14 #include "crypto/sm2.h"
15 #include "crypto/sm2err.h"
16 #include "crypto/ec.h" /* ossl_ec_group_do_inverse_ord() */
17 #include "internal/numbers.h"
18 #include <openssl/err.h>
19 #include <openssl/evp.h>
20 #include <openssl/bn.h>
21 #include <string.h>
22 
ossl_sm2_compute_z_digest(uint8_t * out,const EVP_MD * digest,const uint8_t * id,const size_t id_len,const EC_KEY * key)23 int ossl_sm2_compute_z_digest(uint8_t *out,
24                               const EVP_MD *digest,
25                               const uint8_t *id,
26                               const size_t id_len,
27                               const EC_KEY *key)
28 {
29     int rc = 0;
30     const EC_GROUP *group = EC_KEY_get0_group(key);
31     BN_CTX *ctx = NULL;
32     EVP_MD_CTX *hash = NULL;
33     BIGNUM *p = NULL;
34     BIGNUM *a = NULL;
35     BIGNUM *b = NULL;
36     BIGNUM *xG = NULL;
37     BIGNUM *yG = NULL;
38     BIGNUM *xA = NULL;
39     BIGNUM *yA = NULL;
40     int p_bytes = 0;
41     uint8_t *buf = NULL;
42     uint16_t entl = 0;
43     uint8_t e_byte = 0;
44 
45     hash = EVP_MD_CTX_new();
46     ctx = BN_CTX_new_ex(ossl_ec_key_get_libctx(key));
47     if (hash == NULL || ctx == NULL) {
48         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
49         goto done;
50     }
51 
52     p = BN_CTX_get(ctx);
53     a = BN_CTX_get(ctx);
54     b = BN_CTX_get(ctx);
55     xG = BN_CTX_get(ctx);
56     yG = BN_CTX_get(ctx);
57     xA = BN_CTX_get(ctx);
58     yA = BN_CTX_get(ctx);
59 
60     if (yA == NULL) {
61         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
62         goto done;
63     }
64 
65     if (!EVP_DigestInit(hash, digest)) {
66         ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
67         goto done;
68     }
69 
70     /* Z = h(ENTL || ID || a || b || xG || yG || xA || yA) */
71 
72     if (id_len >= (UINT16_MAX / 8)) {
73         /* too large */
74         ERR_raise(ERR_LIB_SM2, SM2_R_ID_TOO_LARGE);
75         goto done;
76     }
77 
78     entl = (uint16_t)(8 * id_len);
79 
80     e_byte = entl >> 8;
81     if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
82         ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
83         goto done;
84     }
85     e_byte = entl & 0xFF;
86     if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
87         ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
88         goto done;
89     }
90 
91     if (id_len > 0 && !EVP_DigestUpdate(hash, id, id_len)) {
92         ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
93         goto done;
94     }
95 
96     if (!EC_GROUP_get_curve(group, p, a, b, ctx)) {
97         ERR_raise(ERR_LIB_SM2, ERR_R_EC_LIB);
98         goto done;
99     }
100 
101     p_bytes = BN_num_bytes(p);
102     buf = OPENSSL_zalloc(p_bytes);
103     if (buf == NULL) {
104         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
105         goto done;
106     }
107 
108     if (BN_bn2binpad(a, buf, p_bytes) < 0
109             || !EVP_DigestUpdate(hash, buf, p_bytes)
110             || BN_bn2binpad(b, buf, p_bytes) < 0
111             || !EVP_DigestUpdate(hash, buf, p_bytes)
112             || !EC_POINT_get_affine_coordinates(group,
113                                                 EC_GROUP_get0_generator(group),
114                                                 xG, yG, ctx)
115             || BN_bn2binpad(xG, buf, p_bytes) < 0
116             || !EVP_DigestUpdate(hash, buf, p_bytes)
117             || BN_bn2binpad(yG, buf, p_bytes) < 0
118             || !EVP_DigestUpdate(hash, buf, p_bytes)
119             || !EC_POINT_get_affine_coordinates(group,
120                                                 EC_KEY_get0_public_key(key),
121                                                 xA, yA, ctx)
122             || BN_bn2binpad(xA, buf, p_bytes) < 0
123             || !EVP_DigestUpdate(hash, buf, p_bytes)
124             || BN_bn2binpad(yA, buf, p_bytes) < 0
125             || !EVP_DigestUpdate(hash, buf, p_bytes)
126             || !EVP_DigestFinal(hash, out, NULL)) {
127         ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
128         goto done;
129     }
130 
131     rc = 1;
132 
133  done:
134     OPENSSL_free(buf);
135     BN_CTX_free(ctx);
136     EVP_MD_CTX_free(hash);
137     return rc;
138 }
139 
sm2_compute_msg_hash(const EVP_MD * digest,const EC_KEY * key,const uint8_t * id,const size_t id_len,const uint8_t * msg,size_t msg_len)140 static BIGNUM *sm2_compute_msg_hash(const EVP_MD *digest,
141                                     const EC_KEY *key,
142                                     const uint8_t *id,
143                                     const size_t id_len,
144                                     const uint8_t *msg, size_t msg_len)
145 {
146     EVP_MD_CTX *hash = EVP_MD_CTX_new();
147     const int md_size = EVP_MD_get_size(digest);
148     uint8_t *z = NULL;
149     BIGNUM *e = NULL;
150     EVP_MD *fetched_digest = NULL;
151     OSSL_LIB_CTX *libctx = ossl_ec_key_get_libctx(key);
152     const char *propq = ossl_ec_key_get0_propq(key);
153 
154     if (md_size < 0) {
155         ERR_raise(ERR_LIB_SM2, SM2_R_INVALID_DIGEST);
156         goto done;
157     }
158 
159     z = OPENSSL_zalloc(md_size);
160     if (hash == NULL || z == NULL) {
161         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
162         goto done;
163     }
164 
165     fetched_digest = EVP_MD_fetch(libctx, EVP_MD_get0_name(digest), propq);
166     if (fetched_digest == NULL) {
167         ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
168         goto done;
169     }
170 
171     if (!ossl_sm2_compute_z_digest(z, fetched_digest, id, id_len, key)) {
172         /* SM2err already called */
173         goto done;
174     }
175 
176     if (!EVP_DigestInit(hash, fetched_digest)
177             || !EVP_DigestUpdate(hash, z, md_size)
178             || !EVP_DigestUpdate(hash, msg, msg_len)
179                /* reuse z buffer to hold H(Z || M) */
180             || !EVP_DigestFinal(hash, z, NULL)) {
181         ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
182         goto done;
183     }
184 
185     e = BN_bin2bn(z, md_size, NULL);
186     if (e == NULL)
187         ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
188 
189  done:
190     EVP_MD_free(fetched_digest);
191     OPENSSL_free(z);
192     EVP_MD_CTX_free(hash);
193     return e;
194 }
195 
sm2_sig_gen(const EC_KEY * key,const BIGNUM * e)196 static ECDSA_SIG *sm2_sig_gen(const EC_KEY *key, const BIGNUM *e)
197 {
198     const BIGNUM *dA = EC_KEY_get0_private_key(key);
199     const EC_GROUP *group = EC_KEY_get0_group(key);
200     const BIGNUM *order = EC_GROUP_get0_order(group);
201     ECDSA_SIG *sig = NULL;
202     EC_POINT *kG = NULL;
203     BN_CTX *ctx = NULL;
204     BIGNUM *k = NULL;
205     BIGNUM *rk = NULL;
206     BIGNUM *r = NULL;
207     BIGNUM *s = NULL;
208     BIGNUM *x1 = NULL;
209     BIGNUM *tmp = NULL;
210     OSSL_LIB_CTX *libctx = ossl_ec_key_get_libctx(key);
211 
212     kG = EC_POINT_new(group);
213     ctx = BN_CTX_new_ex(libctx);
214     if (kG == NULL || ctx == NULL) {
215         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
216         goto done;
217     }
218 
219     BN_CTX_start(ctx);
220     k = BN_CTX_get(ctx);
221     rk = BN_CTX_get(ctx);
222     x1 = BN_CTX_get(ctx);
223     tmp = BN_CTX_get(ctx);
224     if (tmp == NULL) {
225         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
226         goto done;
227     }
228 
229     /*
230      * These values are returned and so should not be allocated out of the
231      * context
232      */
233     r = BN_new();
234     s = BN_new();
235 
236     if (r == NULL || s == NULL) {
237         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
238         goto done;
239     }
240 
241     /*
242      * A3: Generate a random number k in [1,n-1] using random number generators;
243      * A4: Compute (x1,y1)=[k]G, and convert the type of data x1 to be integer
244      *     as specified in clause 4.2.8 of GM/T 0003.1-2012;
245      * A5: Compute r=(e+x1) mod n. If r=0 or r+k=n, then go to A3;
246      * A6: Compute s=(1/(1+dA)*(k-r*dA)) mod n. If s=0, then go to A3;
247      * A7: Convert the type of data (r,s) to be bit strings according to the details
248      *     in clause 4.2.2 of GM/T 0003.1-2012. Then the signature of message M is (r,s).
249      */
250     for (;;) {
251         if (!BN_priv_rand_range_ex(k, order, 0, ctx)) {
252             ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
253             goto done;
254         }
255 
256         if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)
257                 || !EC_POINT_get_affine_coordinates(group, kG, x1, NULL,
258                                                     ctx)
259                 || !BN_mod_add(r, e, x1, order, ctx)) {
260             ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
261             goto done;
262         }
263 
264         /* try again if r == 0 or r+k == n */
265         if (BN_is_zero(r))
266             continue;
267 
268         if (!BN_add(rk, r, k)) {
269             ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
270             goto done;
271         }
272 
273         if (BN_cmp(rk, order) == 0)
274             continue;
275 
276         if (!BN_add(s, dA, BN_value_one())
277                 || !ossl_ec_group_do_inverse_ord(group, s, s, ctx)
278                 || !BN_mod_mul(tmp, dA, r, order, ctx)
279                 || !BN_sub(tmp, k, tmp)
280                 || !BN_mod_mul(s, s, tmp, order, ctx)) {
281             ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
282             goto done;
283         }
284 
285         /* try again if s == 0 */
286         if (BN_is_zero(s))
287             continue;
288 
289         sig = ECDSA_SIG_new();
290         if (sig == NULL) {
291             ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
292             goto done;
293         }
294 
295          /* takes ownership of r and s */
296         ECDSA_SIG_set0(sig, r, s);
297         break;
298     }
299 
300  done:
301     if (sig == NULL) {
302         BN_free(r);
303         BN_free(s);
304     }
305 
306     BN_CTX_free(ctx);
307     EC_POINT_free(kG);
308     return sig;
309 }
310 
sm2_sig_verify(const EC_KEY * key,const ECDSA_SIG * sig,const BIGNUM * e)311 static int sm2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig,
312                           const BIGNUM *e)
313 {
314     int ret = 0;
315     const EC_GROUP *group = EC_KEY_get0_group(key);
316     const BIGNUM *order = EC_GROUP_get0_order(group);
317     BN_CTX *ctx = NULL;
318     EC_POINT *pt = NULL;
319     BIGNUM *t = NULL;
320     BIGNUM *x1 = NULL;
321     const BIGNUM *r = NULL;
322     const BIGNUM *s = NULL;
323     OSSL_LIB_CTX *libctx = ossl_ec_key_get_libctx(key);
324 
325     ctx = BN_CTX_new_ex(libctx);
326     pt = EC_POINT_new(group);
327     if (ctx == NULL || pt == NULL) {
328         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
329         goto done;
330     }
331 
332     BN_CTX_start(ctx);
333     t = BN_CTX_get(ctx);
334     x1 = BN_CTX_get(ctx);
335     if (x1 == NULL) {
336         ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
337         goto done;
338     }
339 
340     /*
341      * B1: verify whether r' in [1,n-1], verification failed if not
342      * B2: verify whether s' in [1,n-1], verification failed if not
343      * B3: set M'~=ZA || M'
344      * B4: calculate e'=Hv(M'~)
345      * B5: calculate t = (r' + s') modn, verification failed if t=0
346      * B6: calculate the point (x1', y1')=[s']G + [t]PA
347      * B7: calculate R=(e'+x1') modn, verification pass if yes, otherwise failed
348      */
349 
350     ECDSA_SIG_get0(sig, &r, &s);
351 
352     if (BN_cmp(r, BN_value_one()) < 0
353             || BN_cmp(s, BN_value_one()) < 0
354             || BN_cmp(order, r) <= 0
355             || BN_cmp(order, s) <= 0) {
356         ERR_raise(ERR_LIB_SM2, SM2_R_BAD_SIGNATURE);
357         goto done;
358     }
359 
360     if (!BN_mod_add(t, r, s, order, ctx)) {
361         ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
362         goto done;
363     }
364 
365     if (BN_is_zero(t)) {
366         ERR_raise(ERR_LIB_SM2, SM2_R_BAD_SIGNATURE);
367         goto done;
368     }
369 
370     if (!EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx)
371             || !EC_POINT_get_affine_coordinates(group, pt, x1, NULL, ctx)) {
372         ERR_raise(ERR_LIB_SM2, ERR_R_EC_LIB);
373         goto done;
374     }
375 
376     if (!BN_mod_add(t, e, x1, order, ctx)) {
377         ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
378         goto done;
379     }
380 
381     if (BN_cmp(r, t) == 0)
382         ret = 1;
383 
384  done:
385     EC_POINT_free(pt);
386     BN_CTX_free(ctx);
387     return ret;
388 }
389 
ossl_sm2_do_sign(const EC_KEY * key,const EVP_MD * digest,const uint8_t * id,const size_t id_len,const uint8_t * msg,size_t msg_len)390 ECDSA_SIG *ossl_sm2_do_sign(const EC_KEY *key,
391                             const EVP_MD *digest,
392                             const uint8_t *id,
393                             const size_t id_len,
394                             const uint8_t *msg, size_t msg_len)
395 {
396     BIGNUM *e = NULL;
397     ECDSA_SIG *sig = NULL;
398 
399     e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
400     if (e == NULL) {
401         /* SM2err already called */
402         goto done;
403     }
404 
405     sig = sm2_sig_gen(key, e);
406 
407  done:
408     BN_free(e);
409     return sig;
410 }
411 
ossl_sm2_do_verify(const EC_KEY * key,const EVP_MD * digest,const ECDSA_SIG * sig,const uint8_t * id,const size_t id_len,const uint8_t * msg,size_t msg_len)412 int ossl_sm2_do_verify(const EC_KEY *key,
413                        const EVP_MD *digest,
414                        const ECDSA_SIG *sig,
415                        const uint8_t *id,
416                        const size_t id_len,
417                        const uint8_t *msg, size_t msg_len)
418 {
419     BIGNUM *e = NULL;
420     int ret = 0;
421 
422     e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
423     if (e == NULL) {
424         /* SM2err already called */
425         goto done;
426     }
427 
428     ret = sm2_sig_verify(key, sig, e);
429 
430  done:
431     BN_free(e);
432     return ret;
433 }
434 
/*
 * Sign an already-computed message representative |dgst| (|dgstlen|
 * bytes, big-endian) with |eckey|. The DER-encoded signature is written to
 * |sig|, and its length is stored in |*siglen|.
 *
 * Returns 1 on success, -1 on failure.
 */
int ossl_sm2_internal_sign(const unsigned char *dgst, int dgstlen,
                           unsigned char *sig, unsigned int *siglen,
                           EC_KEY *eckey)
{
    BIGNUM *msg_rep = NULL;
    ECDSA_SIG *signature = NULL;
    int der_len;
    int status = -1;

    msg_rep = BN_bin2bn(dgst, dgstlen, NULL);
    if (msg_rep == NULL) {
        ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
        goto cleanup;
    }

    signature = sm2_sig_gen(eckey, msg_rep);
    if (signature == NULL) {
        ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
        goto cleanup;
    }

    /* i2d advances only this function's copy of |sig|, not the caller's */
    der_len = i2d_ECDSA_SIG(signature, &sig);
    if (der_len < 0) {
        ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
        goto cleanup;
    }

    *siglen = (unsigned int)der_len;
    status = 1;

 cleanup:
    ECDSA_SIG_free(signature);
    BN_free(msg_rep);
    return status;
}
470 
/*
 * Verify the DER-encoded signature |sig| (|sig_len| bytes) over an
 * already-computed message representative |dgst| with |eckey|.
 *
 * The signature must be strict DER: it is re-encoded and must match the
 * input byte for byte, which also rejects trailing data.
 *
 * Returns 1 if valid, 0 if the signature does not verify, and -1 on
 * decoding or other errors.
 */
int ossl_sm2_internal_verify(const unsigned char *dgst, int dgstlen,
                             const unsigned char *sig, int sig_len,
                             EC_KEY *eckey)
{
    ECDSA_SIG *parsed = NULL;
    BIGNUM *msg_rep = NULL;
    const unsigned char *cursor = sig;
    unsigned char *reencoded = NULL;
    int reencoded_len = -1;
    int status = -1;

    parsed = ECDSA_SIG_new();
    if (parsed == NULL) {
        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
        goto cleanup;
    }

    if (d2i_ECDSA_SIG(&parsed, &cursor, sig_len) == NULL) {
        ERR_raise(ERR_LIB_SM2, SM2_R_INVALID_ENCODING);
        goto cleanup;
    }

    /* round-trip to reject non-DER encodings and trailing garbage */
    reencoded_len = i2d_ECDSA_SIG(parsed, &reencoded);
    if (reencoded_len != sig_len
            || memcmp(sig, reencoded, reencoded_len) != 0) {
        ERR_raise(ERR_LIB_SM2, SM2_R_INVALID_ENCODING);
        goto cleanup;
    }

    msg_rep = BN_bin2bn(dgst, dgstlen, NULL);
    if (msg_rep == NULL) {
        ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
        goto cleanup;
    }

    status = sm2_sig_verify(eckey, parsed, msg_rep);

 cleanup:
    OPENSSL_free(reencoded);
    BN_free(msg_rep);
    ECDSA_SIG_free(parsed);
    return status;
}
512