99c703d95f44d45ea5f28a416fdc6dde84b6e18c
[users/jgh/exim.git] / src / src / pdkim / rsa.c
1 /* $Cambridge: exim/src/src/pdkim/rsa.c,v 1.1.2.1 2009/02/24 13:13:47 tom Exp $ */
2 /*
3  *  The RSA public-key cryptosystem
4  *
5  *  Based on XySSL: Copyright (C) 2006-2008  Christophe Devine
6  *
7  *  Copyright (C) 2009  Paul Bakker <polarssl_maintainer at polarssl dot org>
8  *
9  *  This program is free software; you can redistribute it and/or modify
10  *  it under the terms of the GNU General Public License as published by
11  *  the Free Software Foundation; either version 2 of the License, or
12  *  (at your option) any later version.
13  *
14  *  This program is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17  *  GNU General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License along
20  *  with this program; if not, write to the Free Software Foundation, Inc.,
21  *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
22  */
23 /*
24  *  RSA was designed by Ron Rivest, Adi Shamir and Len Adleman.
25  *
26  *  http://theory.lcs.mit.edu/~rivest/rsapaper.pdf
27  *  http://www.cacr.math.uwaterloo.ca/hac/about/chap8.pdf
28  */
29
30 #include "rsa.h"
31 #include "base64.h"
32
33 #include <stdlib.h>
34 #include <string.h>
35 #include <stdio.h>
36
37 /*
38  * Initialize an RSA context
39  */
40 void rsa_init( rsa_context *ctx,
41                int padding,
42                int hash_id,
43                int (*f_rng)(void *),
44                void *p_rng )
45 {
46     memset( ctx, 0, sizeof( rsa_context ) );
47
48     ctx->padding = padding;
49     ctx->hash_id = hash_id;
50
51     ctx->f_rng = f_rng;
52     ctx->p_rng = p_rng;
53 }
54
55
56 /*
57  * Check a public RSA key
58  */
59 int rsa_check_pubkey( rsa_context *ctx )
60 {
61     if( ( ctx->N.p[0] & 1 ) == 0 ||
62         ( ctx->E.p[0] & 1 ) == 0 )
63         return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED );
64
65     if( mpi_msb( &ctx->N ) < 128 ||
66         mpi_msb( &ctx->N ) > 4096 )
67         return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED );
68
69     if( mpi_msb( &ctx->E ) < 2 ||
70         mpi_msb( &ctx->E ) > 64 )
71         return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED );
72
73     return( 0 );
74 }
75
76 /*
77  * Check a private RSA key
78  */
79 int rsa_check_privkey( rsa_context *ctx )
80 {
81     int ret;
82     mpi PQ, DE, P1, Q1, H, I, G;
83
84     if( ( ret = rsa_check_pubkey( ctx ) ) != 0 )
85         return( ret );
86
87     mpi_init( &PQ, &DE, &P1, &Q1, &H, &I, &G, NULL );
88
89     MPI_CHK( mpi_mul_mpi( &PQ, &ctx->P, &ctx->Q ) );
90     MPI_CHK( mpi_mul_mpi( &DE, &ctx->D, &ctx->E ) );
91     MPI_CHK( mpi_sub_int( &P1, &ctx->P, 1 ) );
92     MPI_CHK( mpi_sub_int( &Q1, &ctx->Q, 1 ) );
93     MPI_CHK( mpi_mul_mpi( &H, &P1, &Q1 ) );
94     MPI_CHK( mpi_mod_mpi( &I, &DE, &H  ) );
95     MPI_CHK( mpi_gcd( &G, &ctx->E, &H  ) );
96
97     if( mpi_cmp_mpi( &PQ, &ctx->N ) == 0 &&
98         mpi_cmp_int( &I, 1 ) == 0 &&
99         mpi_cmp_int( &G, 1 ) == 0 )
100     {
101         mpi_free( &G, &I, &H, &Q1, &P1, &DE, &PQ, NULL );
102         return( 0 );
103     }
104
105 cleanup:
106
107     mpi_free( &G, &I, &H, &Q1, &P1, &DE, &PQ, NULL );
108     return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED | ret );
109 }
110
111 /*
112  * Do an RSA public key operation
113  */
114 int rsa_public( rsa_context *ctx,
115                 unsigned char *input,
116                 unsigned char *output )
117 {
118     int ret, olen;
119     mpi T;
120
121     mpi_init( &T, NULL );
122
123     MPI_CHK( mpi_read_binary( &T, input, ctx->len ) );
124
125     if( mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
126     {
127         mpi_free( &T, NULL );
128         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
129     }
130
131     olen = ctx->len;
132     MPI_CHK( mpi_exp_mod( &T, &T, &ctx->E, &ctx->N, &ctx->RN ) );
133     MPI_CHK( mpi_write_binary( &T, output, olen ) );
134
135 cleanup:
136
137     mpi_free( &T, NULL );
138
139     if( ret != 0 )
140         return( POLARSSL_ERR_RSA_PUBLIC_FAILED | ret );
141
142     return( 0 );
143 }
144
145 /*
146  * Do an RSA private key operation
147  */
148 int rsa_private( rsa_context *ctx,
149                  unsigned char *input,
150                  unsigned char *output )
151 {
152     int ret, olen;
153     mpi T, T1, T2;
154
155     mpi_init( &T, &T1, &T2, NULL );
156
157     MPI_CHK( mpi_read_binary( &T, input, ctx->len ) );
158
159     if( mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
160     {
161         mpi_free( &T, NULL );
162         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
163     }
164
165 #if 0
166     MPI_CHK( mpi_exp_mod( &T, &T, &ctx->D, &ctx->N, &ctx->RN ) );
167 #else
168     /*
169      * faster decryption using the CRT
170      *
171      * T1 = input ^ dP mod P
172      * T2 = input ^ dQ mod Q
173      */
174     MPI_CHK( mpi_exp_mod( &T1, &T, &ctx->DP, &ctx->P, &ctx->RP ) );
175     MPI_CHK( mpi_exp_mod( &T2, &T, &ctx->DQ, &ctx->Q, &ctx->RQ ) );
176
177     /*
178      * T = (T1 - T2) * (Q^-1 mod P) mod P
179      */
180     MPI_CHK( mpi_sub_mpi( &T, &T1, &T2 ) );
181     MPI_CHK( mpi_mul_mpi( &T1, &T, &ctx->QP ) );
182     MPI_CHK( mpi_mod_mpi( &T, &T1, &ctx->P ) );
183
184     /*
185      * output = T2 + T * Q
186      */
187     MPI_CHK( mpi_mul_mpi( &T1, &T, &ctx->Q ) );
188     MPI_CHK( mpi_add_mpi( &T, &T2, &T1 ) );
189 #endif
190
191     olen = ctx->len;
192     MPI_CHK( mpi_write_binary( &T, output, olen ) );
193
194 cleanup:
195
196     mpi_free( &T, &T1, &T2, NULL );
197
198     if( ret != 0 )
199         return( POLARSSL_ERR_RSA_PRIVATE_FAILED | ret );
200
201     return( 0 );
202 }
203
204 /*
205  * Add the message padding, then do an RSA operation
206  */
207 int rsa_pkcs1_encrypt( rsa_context *ctx,
208                        int mode, int  ilen,
209                        unsigned char *input,
210                        unsigned char *output )
211 {
212     int nb_pad, olen;
213     unsigned char *p = output;
214
215     olen = ctx->len;
216
217     switch( ctx->padding )
218     {
219         case RSA_PKCS_V15:
220
221             if( ilen < 0 || olen < ilen + 11 )
222                 return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
223
224             nb_pad = olen - 3 - ilen;
225
226             *p++ = 0;
227             *p++ = RSA_CRYPT;
228
229             while( nb_pad-- > 0 )
230             {
231                 do {
232                     *p = (unsigned char) rand();
233                 } while( *p == 0 );
234                 p++;
235             }
236             *p++ = 0;
237             memcpy( p, input, ilen );
238             break;
239
240         default:
241
242             return( POLARSSL_ERR_RSA_INVALID_PADDING );
243     }
244
245     return( ( mode == RSA_PUBLIC )
246             ? rsa_public(  ctx, output, output )
247             : rsa_private( ctx, output, output ) );
248 }
249
250 /*
251  * Do an RSA operation, then remove the message padding
252  */
253 int rsa_pkcs1_decrypt( rsa_context *ctx,
254                        int mode, int *olen,
255                        unsigned char *input,
256                        unsigned char *output,
257                int output_max_len)
258 {
259     int ret, ilen;
260     unsigned char *p;
261     unsigned char buf[512];
262
263     ilen = ctx->len;
264
265     if( ilen < 16 || ilen > (int) sizeof( buf ) )
266         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
267
268     ret = ( mode == RSA_PUBLIC )
269           ? rsa_public(  ctx, input, buf )
270           : rsa_private( ctx, input, buf );
271
272     if( ret != 0 )
273         return( ret );
274
275     p = buf;
276
277     switch( ctx->padding )
278     {
279         case RSA_PKCS_V15:
280
281             if( *p++ != 0 || *p++ != RSA_CRYPT )
282                 return( POLARSSL_ERR_RSA_INVALID_PADDING );
283
284             while( *p != 0 )
285             {
286                 if( p >= buf + ilen - 1 )
287                     return( POLARSSL_ERR_RSA_INVALID_PADDING );
288                 p++;
289             }
290             p++;
291             break;
292
293         default:
294
295             return( POLARSSL_ERR_RSA_INVALID_PADDING );
296     }
297
298     if (ilen - (int)(p - buf) > output_max_len)
299         return( POLARSSL_ERR_RSA_OUTPUT_TO_LARGE );
300
301     *olen = ilen - (int)(p - buf);
302     memcpy( output, p, *olen );
303
304     return( 0 );
305 }
306
307 /*
308  * Do an RSA operation to sign the message digest
309  */
310 int rsa_pkcs1_sign( rsa_context *ctx,
311                     int mode,
312                     int hash_id,
313                     int hashlen,
314                     unsigned char *hash,
315                     unsigned char *sig )
316 {
317     int nb_pad, olen;
318     unsigned char *p = sig;
319
320     olen = ctx->len;
321
322     switch( ctx->padding )
323     {
324         case RSA_PKCS_V15:
325
326             switch( hash_id )
327             {
328                 case RSA_RAW:
329                     nb_pad = olen - 3 - hashlen;
330                     break;
331
332                 case RSA_MD2:
333                 case RSA_MD4:
334                 case RSA_MD5:
335                     nb_pad = olen - 3 - 16 - 18;
336                     break;
337
338                 case RSA_SHA1:
339                     nb_pad = olen - 3 - 20 - 15;
340                     break;
341
342                 case RSA_SHA256:
343                     nb_pad = olen - 3 - 32 - 19;
344                     break;
345
346                 default:
347                     return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
348             }
349
350             if( nb_pad < 8 )
351                 return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
352
353             *p++ = 0;
354             *p++ = RSA_SIGN;
355             memset( p, 0xFF, nb_pad );
356             p += nb_pad;
357             *p++ = 0;
358             break;
359
360         default:
361
362             return( POLARSSL_ERR_RSA_INVALID_PADDING );
363     }
364
365     switch( hash_id )
366     {
367         case RSA_RAW:
368             memcpy( p, hash, hashlen );
369             break;
370
371         case RSA_MD2:
372             memcpy( p, ASN1_HASH_MDX, 18 );
373             memcpy( p + 18, hash, 16 );
374             p[13] = 2; break;
375
376         case RSA_MD4:
377             memcpy( p, ASN1_HASH_MDX, 18 );
378             memcpy( p + 18, hash, 16 );
379             p[13] = 4; break;
380
381         case RSA_MD5:
382             memcpy( p, ASN1_HASH_MDX, 18 );
383             memcpy( p + 18, hash, 16 );
384             p[13] = 5; break;
385
386         case RSA_SHA1:
387             memcpy( p, ASN1_HASH_SHA1, 15 );
388             memcpy( p + 15, hash, 20 );
389             break;
390
391         case RSA_SHA256:
392             memcpy( p, ASN1_HASH_SHA256, 19 );
393             memcpy( p + 19, hash, 32 );
394             break;
395
396         default:
397             return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
398     }
399
400     return( ( mode == RSA_PUBLIC )
401             ? rsa_public(  ctx, sig, sig )
402             : rsa_private( ctx, sig, sig ) );
403 }
404
405 /*
406  * Do an RSA operation and check the message digest
407  */
408 int rsa_pkcs1_verify( rsa_context *ctx,
409                       int mode,
410                       int hash_id,
411                       int hashlen,
412                       unsigned char *hash,
413                       unsigned char *sig )
414 {
415     int ret, len, siglen;
416     unsigned char *p, c;
417     unsigned char buf[512];
418
419     siglen = ctx->len;
420
421     if( siglen < 16 || siglen > (int) sizeof( buf ) )
422         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
423
424     ret = ( mode == RSA_PUBLIC )
425           ? rsa_public(  ctx, sig, buf )
426           : rsa_private( ctx, sig, buf );
427
428     if( ret != 0 )
429         return( ret );
430
431     p = buf;
432
433     switch( ctx->padding )
434     {
435         case RSA_PKCS_V15:
436
437             if( *p++ != 0 || *p++ != RSA_SIGN )
438                 return( POLARSSL_ERR_RSA_INVALID_PADDING );
439
440             while( *p != 0 )
441             {
442                 if( p >= buf + siglen - 1 || *p != 0xFF )
443                     return( POLARSSL_ERR_RSA_INVALID_PADDING );
444                 p++;
445             }
446             p++;
447             break;
448
449         default:
450
451             return( POLARSSL_ERR_RSA_INVALID_PADDING );
452     }
453
454     len = siglen - (int)( p - buf );
455
456     if( len == 34 )
457     {
458         c = p[13];
459         p[13] = 0;
460
461         if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
462             return( POLARSSL_ERR_RSA_VERIFY_FAILED );
463
464         if( ( c == 2 && hash_id == RSA_MD2 ) ||
465             ( c == 4 && hash_id == RSA_MD4 ) ||
466             ( c == 5 && hash_id == RSA_MD5 ) )
467         {
468             if( memcmp( p + 18, hash, 16 ) == 0 )
469                 return( 0 );
470             else
471                 return( POLARSSL_ERR_RSA_VERIFY_FAILED );
472         }
473     }
474
475     if( len == 35 && hash_id == RSA_SHA1 )
476     {
477         if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
478             memcmp( p + 15, hash, 20 ) == 0 )
479             return( 0 );
480         else
481             return( POLARSSL_ERR_RSA_VERIFY_FAILED );
482     }
483
484     if( len == hashlen && hash_id == RSA_RAW )
485     {
486         if( memcmp( p, hash, hashlen ) == 0 )
487             return( 0 );
488         else
489             return( POLARSSL_ERR_RSA_VERIFY_FAILED );
490     }
491
492     return( POLARSSL_ERR_RSA_INVALID_PADDING );
493 }
494
495 /*
496  * Free the components of an RSA key
497  */
498 void rsa_free( rsa_context *ctx )
499 {
500     mpi_free( &ctx->RQ, &ctx->RP, &ctx->RN,
501               &ctx->QP, &ctx->DQ, &ctx->DP,
502               &ctx->Q,  &ctx->P,  &ctx->D,
503               &ctx->E,  &ctx->N,  NULL );
504 }
505
506 /*
507  * ASN.1 DER decoding routines
508  */
509 static int asn1_get_len( unsigned char **p,
510                          unsigned char *end,
511                          int *len )
512 {
513     if( ( end - *p ) < 1 )
514         return( POLARSSL_ERR_ASN1_OUT_OF_DATA );
515
516     if( ( **p & 0x80 ) == 0 )
517         *len = *(*p)++;
518     else
519     {
520         switch( **p & 0x7F )
521         {
522         case 1:
523             if( ( end - *p ) < 2 )
524                 return( POLARSSL_ERR_ASN1_OUT_OF_DATA );
525
526             *len = (*p)[1];
527             (*p) += 2;
528             break;
529
530         case 2:
531             if( ( end - *p ) < 3 )
532                 return( POLARSSL_ERR_ASN1_OUT_OF_DATA );
533
534             *len = ( (*p)[1] << 8 ) | (*p)[2];
535             (*p) += 3;
536             break;
537
538         default:
539             return( POLARSSL_ERR_ASN1_INVALID_LENGTH );
540             break;
541         }
542     }
543
544     if( *len > (int) ( end - *p ) )
545         return( POLARSSL_ERR_ASN1_OUT_OF_DATA );
546
547     return( 0 );
548 }
549
550 static int asn1_get_tag( unsigned char **p,
551                          unsigned char *end,
552                          int *len, int tag )
553 {
554     if( ( end - *p ) < 1 )
555         return( POLARSSL_ERR_ASN1_OUT_OF_DATA );
556
557     if( **p != tag )
558         return( POLARSSL_ERR_ASN1_UNEXPECTED_TAG );
559
560     (*p)++;
561
562     return( asn1_get_len( p, end, len ) );
563 }
564
565 static int asn1_get_int( unsigned char **p,
566                          unsigned char *end,
567                          int *val )
568 {
569     int ret, len;
570
571     if( ( ret = asn1_get_tag( p, end, &len, ASN1_INTEGER ) ) != 0 )
572         return( ret );
573
574     if( len > (int) sizeof( int ) || ( **p & 0x80 ) != 0 )
575         return( POLARSSL_ERR_ASN1_INVALID_LENGTH );
576
577     *val = 0;
578
579     while( len-- > 0 )
580     {
581         *val = ( *val << 8 ) | **p;
582         (*p)++;
583     }
584
585     return( 0 );
586 }
587
588 static int asn1_get_mpi( unsigned char **p,
589                          unsigned char *end,
590                          mpi *X )
591 {
592     int ret, len;
593
594     if( ( ret = asn1_get_tag( p, end, &len, ASN1_INTEGER ) ) != 0 )
595         return( ret );
596
597     ret = mpi_read_binary( X, *p, len );
598
599     *p += len;
600
601     return( ret );
602 }
603
604
605 /*
606  * Parse a private RSA key
607  */
608 int rsa_parse_key( rsa_context *rsa, unsigned char *buf, int buflen,
609                                      unsigned char *pwd, int pwdlen )
610 {
611     int ret, len, enc;
612     unsigned char *s1, *s2;
613     unsigned char *p, *end;
614
615     s1 = (unsigned char *) strstr( (char *) buf,
616         "-----BEGIN RSA PRIVATE KEY-----" );
617
618     if( s1 != NULL )
619     {
620         s2 = (unsigned char *) strstr( (char *) buf,
621             "-----END RSA PRIVATE KEY-----" );
622
623         if( s2 == NULL || s2 <= s1 )
624             return( POLARSSL_ERR_X509_KEY_INVALID_PEM );
625
626         s1 += 31;
627         if( *s1 == '\r' ) s1++;
628         if( *s1 == '\n' ) s1++;
629             else return( POLARSSL_ERR_X509_KEY_INVALID_PEM );
630
631         enc = 0;
632
633         if( memcmp( s1, "Proc-Type: 4,ENCRYPTED", 22 ) == 0 )
634         {
635             return( POLARSSL_ERR_X509_FEATURE_UNAVAILABLE );
636         }
637
638         len = 0;
639         ret = base64_decode( NULL, &len, s1, s2 - s1 );
640
641         if( ret == POLARSSL_ERR_BASE64_INVALID_CHARACTER )
642             return( ret | POLARSSL_ERR_X509_KEY_INVALID_PEM );
643
644         if( ( buf = (unsigned char *) malloc( len ) ) == NULL )
645             return( 1 );
646
647         if( ( ret = base64_decode( buf, &len, s1, s2 - s1 ) ) != 0 )
648         {
649             free( buf );
650             return( ret | POLARSSL_ERR_X509_KEY_INVALID_PEM );
651         }
652
653         buflen = len;
654
655         if( enc != 0 )
656         {
657             return( POLARSSL_ERR_X509_FEATURE_UNAVAILABLE );
658         }
659     }
660
661     memset( rsa, 0, sizeof( rsa_context ) );
662
663     p = buf;
664     end = buf + buflen;
665
666     /*
667      *  RSAPrivateKey ::= SEQUENCE {
668      *      version           Version,
669      *      modulus           INTEGER,  -- n
670      *      publicExponent    INTEGER,  -- e
671      *      privateExponent   INTEGER,  -- d
672      *      prime1            INTEGER,  -- p
673      *      prime2            INTEGER,  -- q
674      *      exponent1         INTEGER,  -- d mod (p-1)
675      *      exponent2         INTEGER,  -- d mod (q-1)
676      *      coefficient       INTEGER,  -- (inverse of q) mod p
677      *      otherPrimeInfos   OtherPrimeInfos OPTIONAL
678      *  }
679      */
680     if( ( ret = asn1_get_tag( &p, end, &len,
681             ASN1_CONSTRUCTED | ASN1_SEQUENCE ) ) != 0 )
682     {
683         if( s1 != NULL )
684             free( buf );
685
686         rsa_free( rsa );
687         return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT | ret );
688     }
689
690     end = p + len;
691
692     if( ( ret = asn1_get_int( &p, end, &rsa->ver ) ) != 0 )
693     {
694         if( s1 != NULL )
695             free( buf );
696
697         rsa_free( rsa );
698         return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT | ret );
699     }
700
701     if( rsa->ver != 0 )
702     {
703         if( s1 != NULL )
704             free( buf );
705
706         rsa_free( rsa );
707         return( ret | POLARSSL_ERR_X509_KEY_INVALID_VERSION );
708     }
709
710     if( ( ret = asn1_get_mpi( &p, end, &rsa->N  ) ) != 0 ||
711         ( ret = asn1_get_mpi( &p, end, &rsa->E  ) ) != 0 ||
712         ( ret = asn1_get_mpi( &p, end, &rsa->D  ) ) != 0 ||
713         ( ret = asn1_get_mpi( &p, end, &rsa->P  ) ) != 0 ||
714         ( ret = asn1_get_mpi( &p, end, &rsa->Q  ) ) != 0 ||
715         ( ret = asn1_get_mpi( &p, end, &rsa->DP ) ) != 0 ||
716         ( ret = asn1_get_mpi( &p, end, &rsa->DQ ) ) != 0 ||
717         ( ret = asn1_get_mpi( &p, end, &rsa->QP ) ) != 0 )
718     {
719         if( s1 != NULL )
720             free( buf );
721
722         rsa_free( rsa );
723         return( ret | POLARSSL_ERR_X509_KEY_INVALID_FORMAT );
724     }
725
726     rsa->len = mpi_size( &rsa->N );
727
728     if( p != end )
729     {
730         if( s1 != NULL )
731             free( buf );
732
733         rsa_free( rsa );
734         return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT |
735                 POLARSSL_ERR_ASN1_LENGTH_MISMATCH );
736     }
737
738     if( ( ret = rsa_check_privkey( rsa ) ) != 0 )
739     {
740         if( s1 != NULL )
741             free( buf );
742
743         rsa_free( rsa );
744         return( ret );
745     }
746
747     if( s1 != NULL )
748         free( buf );
749
750     return( 0 );
751 }