Remove most non-TLS-1.2 stuff and most configure options.
[ntbtls.git] / src / ssl_tls.c
1 /*
2  *  SSLv3/TLSv1 shared functions
3  *
4  *  Copyright (C) 2006-2014, Brainspark B.V.
5  *
6  *  This file is part of PolarSSL (http://www.polarssl.org)
7  *  Lead Maintainer: Paul Bakker <polarssl_maintainer at polarssl.org>
8  *
9  *  All rights reserved.
10  *
11  *  This program is free software; you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License as published by
13  *  the Free Software Foundation; either version 2 of the License, or
14  *  (at your option) any later version.
15  *
16  *  This program is distributed in the hope that it will be useful,
17  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19  *  GNU General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License along
22  *  with this program; if not, write to the Free Software Foundation, Inc.,
23  *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24  */
25
26 #include <config.h>
27
28 #include "polarssl/debug.h"
29 #include "polarssl/ssl.h"
30
31 #include "polarssl/oid.h"
32
33 #include "polarssl/platform.h"
34
35 #include <stdlib.h>
36
37
38 /* Implementation that should never be optimized out by the compiler */
39 static void
40 polarssl_zeroize (void *v, size_t n)
41 {
42   volatile unsigned char *p = v;
43   while (n--)
44     *p++ = 0;
45 }
46
47
48 /*
49  * Convert max_fragment_length codes to length.
50  * RFC 6066 says:
51  *    enum{
52  *        2^9(1), 2^10(2), 2^11(3), 2^12(4), (255)
53  *    } MaxFragmentLength;
54  * and we add 0 -> extension unused
55  */
56 static unsigned int mfl_code_to_length[SSL_MAX_FRAG_LEN_INVALID] = {
57   SSL_MAX_CONTENT_LEN,          /* SSL_MAX_FRAG_LEN_NONE */
58   512,                          /* SSL_MAX_FRAG_LEN_512  */
59   1024,                         /* SSL_MAX_FRAG_LEN_1024 */
60   2048,                         /* SSL_MAX_FRAG_LEN_2048 */
61   4096,                         /* SSL_MAX_FRAG_LEN_4096 */
62 };
63
64
65 static int
66 ssl_session_copy (ssl_session * dst, const ssl_session * src)
67 {
68   ssl_session_free (dst);
69   memcpy (dst, src, sizeof (ssl_session));
70
71   if (src->peer_cert != NULL)
72     {
73       int ret;
74
75       dst->peer_cert = (x509_crt *) polarssl_malloc (sizeof (x509_crt));
76       if (dst->peer_cert == NULL)
77         return (POLARSSL_ERR_SSL_MALLOC_FAILED);
78
79       x509_crt_init (dst->peer_cert);
80
81       if ((ret = x509_crt_parse_der (dst->peer_cert, src->peer_cert->raw.p,
82                                      src->peer_cert->raw.len)) != 0)
83         {
84           polarssl_free (dst->peer_cert);
85           dst->peer_cert = NULL;
86           return (ret);
87         }
88     }
89
90   if (src->ticket != NULL)
91     {
92       dst->ticket = (unsigned char *) polarssl_malloc (src->ticket_len);
93       if (dst->ticket == NULL)
94         return (POLARSSL_ERR_SSL_MALLOC_FAILED);
95
96       memcpy (dst->ticket, src->ticket, src->ticket_len);
97     }
98
99   return (0);
100 }
101
102
103 /*
104  * Key material generation
105  */
106
107 static int
108 tls_prf_sha256 (const unsigned char *secret, size_t slen,
109                 const char *label,
110                 const unsigned char *random, size_t rlen,
111                 unsigned char *dstbuf, size_t dlen)
112 {
113   size_t nb;
114   size_t i, j, k;
115   unsigned char tmp[128];
116   unsigned char h_i[32];
117
118   if (sizeof (tmp) < 32 + strlen (label) + rlen)
119     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
120
121   nb = strlen (label);
122   memcpy (tmp + 32, label, nb);
123   memcpy (tmp + 32 + nb, random, rlen);
124   nb += rlen;
125
126   /*
127    * Compute P_<hash>(secret, label + random)[0..dlen]
128    */
129   sha256_hmac (secret, slen, tmp + 32, nb, tmp, 0);
130
131   for (i = 0; i < dlen; i += 32)
132     {
133       sha256_hmac (secret, slen, tmp, 32 + nb, h_i, 0);
134       sha256_hmac (secret, slen, tmp, 32, tmp, 0);
135
136       k = (i + 32 > dlen) ? dlen % 32 : 32;
137
138       for (j = 0; j < k; j++)
139         dstbuf[i + j] = h_i[j];
140     }
141
142   polarssl_zeroize (tmp, sizeof (tmp));
143   polarssl_zeroize (h_i, sizeof (h_i));
144
145   return (0);
146 }
147
148
149 static int
150 tls_prf_sha384 (const unsigned char *secret, size_t slen,
151                 const char *label,
152                 const unsigned char *random, size_t rlen,
153                 unsigned char *dstbuf, size_t dlen)
154 {
155   size_t nb;
156   size_t i, j, k;
157   unsigned char tmp[128];
158   unsigned char h_i[48];
159
160   if (sizeof (tmp) < 48 + strlen (label) + rlen)
161     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
162
163   nb = strlen (label);
164   memcpy (tmp + 48, label, nb);
165   memcpy (tmp + 48 + nb, random, rlen);
166   nb += rlen;
167
168   /*
169    * Compute P_<hash>(secret, label + random)[0..dlen]
170    */
171   sha512_hmac (secret, slen, tmp + 48, nb, tmp, 1);
172
173   for (i = 0; i < dlen; i += 48)
174     {
175       sha512_hmac (secret, slen, tmp, 48 + nb, h_i, 1);
176       sha512_hmac (secret, slen, tmp, 48, tmp, 1);
177
178       k = (i + 48 > dlen) ? dlen % 48 : 48;
179
180       for (j = 0; j < k; j++)
181         dstbuf[i + j] = h_i[j];
182     }
183
184   polarssl_zeroize (tmp, sizeof (tmp));
185   polarssl_zeroize (h_i, sizeof (h_i));
186
187   return (0);
188 }
189
190
191 static void ssl_update_checksum_start (ssl_context *, const unsigned char *,
192                                        size_t);
193 static void ssl_update_checksum_sha256 (ssl_context *, const unsigned char *,
194                                         size_t);
195 static void ssl_calc_verify_tls_sha256 (ssl_context *, unsigned char *);
196 static void ssl_calc_finished_tls_sha256 (ssl_context *, unsigned char *,
197                                           int);
198 static void ssl_update_checksum_sha384 (ssl_context *, const unsigned char *,
199                                         size_t);
200 static void ssl_calc_verify_tls_sha384 (ssl_context *, unsigned char *);
201 static void ssl_calc_finished_tls_sha384 (ssl_context *, unsigned char *,
202                                           int);
203
204
205 int
206 ssl_derive_keys (ssl_context * ssl)
207 {
208   int ret = 0;
209   unsigned char tmp[64];
210   unsigned char keyblk[256];
211   unsigned char *key1;
212   unsigned char *key2;
213   unsigned char *mac_enc;
214   unsigned char *mac_dec;
215   size_t iv_copy_len;
216   const cipher_info_t *cipher_info;
217   const md_info_t *md_info;
218
219   ssl_session *session = ssl->session_negotiate;
220   ssl_transform *transform = ssl->transform_negotiate;
221   ssl_handshake_params *handshake = ssl->handshake;
222
223   SSL_DEBUG_MSG (2, ("=> derive keys"));
224
225   cipher_info = cipher_info_from_type (transform->ciphersuite_info->cipher);
226   if (cipher_info == NULL)
227     {
228       SSL_DEBUG_MSG (1, ("cipher info for %d not found",
229                          transform->ciphersuite_info->cipher));
230       return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
231     }
232
233   md_info = md_info_from_type (transform->ciphersuite_info->mac);
234   if (md_info == NULL)
235     {
236       SSL_DEBUG_MSG (1, ("md info for %d not found",
237                          transform->ciphersuite_info->mac));
238       return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
239     }
240
241   /*
242    * Set appropriate PRF function and other TLS functions
243    */
244   if (ssl->minor_ver == SSL_MINOR_VERSION_3 &&
245         transform->ciphersuite_info->mac == POLARSSL_MD_SHA384)
246     {
247       handshake->tls_prf = tls_prf_sha384;
248       handshake->calc_verify = ssl_calc_verify_tls_sha384;
249       handshake->calc_finished = ssl_calc_finished_tls_sha384;
250     }
251   else if (ssl->minor_ver == SSL_MINOR_VERSION_3)
252     {
253       handshake->tls_prf = tls_prf_sha256;
254       handshake->calc_verify = ssl_calc_verify_tls_sha256;
255       handshake->calc_finished = ssl_calc_finished_tls_sha256;
256     }
257   else
258     {
259       SSL_DEBUG_MSG (1, ("should never happen"));
260       return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
261     }
262
263   /*
264    * TLSv1+:
265    *   master = PRF( premaster, "master secret", randbytes )[0..47]
266    */
267   if (handshake->resume == 0)
268     {
269       SSL_DEBUG_BUF (3, "premaster secret", handshake->premaster,
270                      handshake->pmslen);
271
272       handshake->tls_prf (handshake->premaster, handshake->pmslen,
273                           "master secret",
274                           handshake->randbytes, 64, session->master, 48);
275
276       polarssl_zeroize (handshake->premaster, sizeof (handshake->premaster));
277     }
278   else
279     SSL_DEBUG_MSG (3, ("no premaster (session resumed)"));
280
281   /*
282    * Swap the client and server random values.
283    */
284   memcpy (tmp, handshake->randbytes, 64);
285   memcpy (handshake->randbytes, tmp + 32, 32);
286   memcpy (handshake->randbytes + 32, tmp, 32);
287   polarssl_zeroize (tmp, sizeof (tmp));
288
289   /*
290    *  TLSv1:
291    *    key block = PRF( master, "key expansion", randbytes )
292    */
293   handshake->tls_prf (session->master, 48, "key expansion",
294                       handshake->randbytes, 64, keyblk, 256);
295
296   SSL_DEBUG_MSG (3, ("ciphersuite = %s",
297                      ssl_get_ciphersuite_name (session->ciphersuite)));
298   SSL_DEBUG_BUF (3, "master secret", session->master, 48);
299   SSL_DEBUG_BUF (4, "random bytes", handshake->randbytes, 64);
300   SSL_DEBUG_BUF (4, "key block", keyblk, 256);
301
302   polarssl_zeroize (handshake->randbytes, sizeof (handshake->randbytes));
303
304   /*
305    * Determine the appropriate key, IV and MAC length.
306    */
307
308   transform->keylen = cipher_info->key_length / 8;
309
310   if (cipher_info->mode == POLARSSL_MODE_GCM ||
311       cipher_info->mode == POLARSSL_MODE_CCM)
312     {
313       transform->maclen = 0;
314
315       transform->ivlen = 12;
316       transform->fixed_ivlen = 4;
317
318       /* Minimum length is expicit IV + tag */
319       transform->minlen = transform->ivlen - transform->fixed_ivlen
320         + (transform->ciphersuite_info->flags &
321            POLARSSL_CIPHERSUITE_SHORT_TAG ? 8 : 16);
322     }
323   else
324     {
325       int ret;
326
327       /* Initialize HMAC contexts */
328       if ((ret = md_init_ctx (&transform->md_ctx_enc, md_info)) != 0 ||
329           (ret = md_init_ctx (&transform->md_ctx_dec, md_info)) != 0)
330         {
331           SSL_DEBUG_RET (1, "md_init_ctx", ret);
332           return (ret);
333         }
334
335       /* Get MAC length */
336       transform->maclen = md_get_size (md_info);
337
338       /*
339        * If HMAC is to be truncated, we shall keep the leftmost bytes,
340        * (rfc 6066 page 13 or rfc 2104 section 4),
341        * so we only need to adjust the length here.
342        */
343       if (session->trunc_hmac == SSL_TRUNC_HMAC_ENABLED)
344         transform->maclen = SSL_TRUNCATED_HMAC_LEN;
345
346       /* IV length */
347       transform->ivlen = cipher_info->iv_size;
348
349       /* Minimum length */
350       if (cipher_info->mode == POLARSSL_MODE_STREAM)
351         transform->minlen = transform->maclen;
352       else
353         {
354           /*
355            * GenericBlockCipher:
356            * first multiple of blocklen greater than maclen
357            * + IV except for SSL3 and TLS 1.0
358            */
359           transform->minlen = (transform->maclen
360                                + cipher_info->block_size
361                                - transform->maclen % cipher_info->block_size);
362
363           if (ssl->minor_ver == SSL_MINOR_VERSION_2
364               || ssl->minor_ver == SSL_MINOR_VERSION_3)
365             {
366               transform->minlen += transform->ivlen;
367             }
368           else
369             {
370               SSL_DEBUG_MSG (1, ("should never happen"));
371               return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
372             }
373         }
374     }
375
376   SSL_DEBUG_MSG (3, ("keylen: %d, minlen: %d, ivlen: %d, maclen: %d",
377                      transform->keylen, transform->minlen, transform->ivlen,
378                      transform->maclen));
379
380   /*
381    * Finally setup the cipher contexts, IVs and MAC secrets.
382    */
383   if (ssl->endpoint == SSL_IS_CLIENT)
384     {
385       key1 = keyblk + transform->maclen * 2;
386       key2 = keyblk + transform->maclen * 2 + transform->keylen;
387
388       mac_enc = keyblk;
389       mac_dec = keyblk + transform->maclen;
390
391       /*
392        * This is not used in TLS v1.1.
393        */
394       iv_copy_len = (transform->fixed_ivlen) ?
395         transform->fixed_ivlen : transform->ivlen;
396       memcpy (transform->iv_enc, key2 + transform->keylen, iv_copy_len);
397       memcpy (transform->iv_dec, key2 + transform->keylen + iv_copy_len,
398               iv_copy_len);
399     }
400   else
401     {
402       key1 = keyblk + transform->maclen * 2 + transform->keylen;
403       key2 = keyblk + transform->maclen * 2;
404
405       mac_enc = keyblk + transform->maclen;
406       mac_dec = keyblk;
407
408       /*
409        * This is not used in TLS v1.1.
410        */
411       iv_copy_len = (transform->fixed_ivlen) ?
412         transform->fixed_ivlen : transform->ivlen;
413       memcpy (transform->iv_dec, key1 + transform->keylen, iv_copy_len);
414       memcpy (transform->iv_enc, key1 + transform->keylen + iv_copy_len,
415               iv_copy_len);
416     }
417
418   if (ssl->minor_ver >= SSL_MINOR_VERSION_1)
419     {
420       md_hmac_starts (&transform->md_ctx_enc, mac_enc, transform->maclen);
421       md_hmac_starts (&transform->md_ctx_dec, mac_dec, transform->maclen);
422     }
423   else
424     {
425       SSL_DEBUG_MSG (1, ("should never happen"));
426       return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
427     }
428
429   if ((ret = cipher_init_ctx (&transform->cipher_ctx_enc, cipher_info)) != 0)
430     {
431       SSL_DEBUG_RET (1, "cipher_init_ctx", ret);
432       return (ret);
433     }
434
435   if ((ret = cipher_init_ctx (&transform->cipher_ctx_dec, cipher_info)) != 0)
436     {
437       SSL_DEBUG_RET (1, "cipher_init_ctx", ret);
438       return (ret);
439     }
440
441   if ((ret = cipher_setkey (&transform->cipher_ctx_enc, key1,
442                             cipher_info->key_length, POLARSSL_ENCRYPT)) != 0)
443     {
444       SSL_DEBUG_RET (1, "cipher_setkey", ret);
445       return (ret);
446     }
447
448   if ((ret = cipher_setkey (&transform->cipher_ctx_dec, key2,
449                             cipher_info->key_length, POLARSSL_DECRYPT)) != 0)
450     {
451       SSL_DEBUG_RET (1, "cipher_setkey", ret);
452       return (ret);
453     }
454
455   if (cipher_info->mode == POLARSSL_MODE_CBC)
456     {
457       if ((ret = cipher_set_padding_mode (&transform->cipher_ctx_enc,
458                                           POLARSSL_PADDING_NONE)) != 0)
459         {
460           SSL_DEBUG_RET (1, "cipher_set_padding_mode", ret);
461           return (ret);
462         }
463
464       if ((ret = cipher_set_padding_mode (&transform->cipher_ctx_dec,
465                                           POLARSSL_PADDING_NONE)) != 0)
466         {
467           SSL_DEBUG_RET (1, "cipher_set_padding_mode", ret);
468           return (ret);
469         }
470     }
471
472   polarssl_zeroize (keyblk, sizeof (keyblk));
473
474   /* Initialize compression.  */
475   if (session->compression == SSL_COMPRESS_DEFLATE)
476     {
477       if (ssl->compress_buf == NULL)
478         {
479           SSL_DEBUG_MSG (3, ("Allocating compression buffer"));
480           ssl->compress_buf = polarssl_malloc (SSL_BUFFER_LEN);
481           if (ssl->compress_buf == NULL)
482             {
483               SSL_DEBUG_MSG (1, ("malloc(%d bytes) failed", SSL_BUFFER_LEN));
484               return (POLARSSL_ERR_SSL_MALLOC_FAILED);
485             }
486         }
487
488       SSL_DEBUG_MSG (3, ("Initializing zlib states"));
489
490       memset (&transform->ctx_deflate, 0, sizeof (transform->ctx_deflate));
491       memset (&transform->ctx_inflate, 0, sizeof (transform->ctx_inflate));
492
493       if (deflateInit (&transform->ctx_deflate,
494                        Z_DEFAULT_COMPRESSION) != Z_OK ||
495           inflateInit (&transform->ctx_inflate) != Z_OK)
496         {
497           SSL_DEBUG_MSG (1, ("Failed to initialize compression"));
498           return (POLARSSL_ERR_SSL_COMPRESSION_FAILED);
499         }
500     }
501
502   SSL_DEBUG_MSG (2, ("<= derive keys"));
503
504   return (0);
505 }
506
507
508 void
509 ssl_calc_verify_tls_sha256 (ssl_context * ssl, unsigned char hash[32])
510 {
511   sha256_context sha256;
512
513   SSL_DEBUG_MSG (2, ("=> calc verify sha256"));
514
515   memcpy (&sha256, &ssl->handshake->fin_sha256, sizeof (sha256_context));
516   sha256_finish (&sha256, hash);
517
518   SSL_DEBUG_BUF (3, "calculated verify result", hash, 32);
519   SSL_DEBUG_MSG (2, ("<= calc verify"));
520
521   sha256_free (&sha256);
522
523   return;
524 }
525
526
527 void
528 ssl_calc_verify_tls_sha384 (ssl_context * ssl, unsigned char hash[48])
529 {
530   sha512_context sha512;
531
532   SSL_DEBUG_MSG (2, ("=> calc verify sha384"));
533
534   memcpy (&sha512, &ssl->handshake->fin_sha512, sizeof (sha512_context));
535   sha512_finish (&sha512, hash);
536
537   SSL_DEBUG_BUF (3, "calculated verify result", hash, 48);
538   SSL_DEBUG_MSG (2, ("<= calc verify"));
539
540   sha512_free (&sha512);
541
542   return;
543 }
544
545
546 int
547 ssl_psk_derive_premaster (ssl_context * ssl, key_exchange_type_t key_ex)
548 {
549   unsigned char *p = ssl->handshake->premaster;
550   unsigned char *end = p + sizeof (ssl->handshake->premaster);
551
552   /*
553    * PMS = struct {
554    *     opaque other_secret<0..2^16-1>;
555    *     opaque psk<0..2^16-1>;
556    * };
557    * with "other_secret" depending on the particular key exchange
558    */
559   if (key_ex == POLARSSL_KEY_EXCHANGE_PSK)
560     {
561       if (end - p < 2 + (int) ssl->psk_len)
562         return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
563
564       *(p++) = (unsigned char) (ssl->psk_len >> 8);
565       *(p++) = (unsigned char) (ssl->psk_len);
566       p += ssl->psk_len;
567     }
568   else if (key_ex == POLARSSL_KEY_EXCHANGE_RSA_PSK)
569     {
570       /*
571        * other_secret already set by the ClientKeyExchange message,
572        * and is 48 bytes long
573        */
574       *p++ = 0;
575       *p++ = 48;
576       p += 48;
577     }
578   else if (key_ex == POLARSSL_KEY_EXCHANGE_DHE_PSK)
579     {
580       int ret;
581       size_t len = end - (p + 2);
582
583       /* Write length only when we know the actual value */
584       if ((ret = dhm_calc_secret (&ssl->handshake->dhm_ctx,
585                                   p + 2, &len, ssl->f_rng, ssl->p_rng)) != 0)
586         {
587           SSL_DEBUG_RET (1, "dhm_calc_secret", ret);
588           return (ret);
589         }
590       *(p++) = (unsigned char) (len >> 8);
591       *(p++) = (unsigned char) (len);
592       p += len;
593
594       SSL_DEBUG_MPI (3, "DHM: K ", &ssl->handshake->dhm_ctx.K);
595     }
596   else if (key_ex == POLARSSL_KEY_EXCHANGE_ECDHE_PSK)
597     {
598       int ret;
599       size_t zlen;
600
601       if ((ret = ecdh_calc_secret (&ssl->handshake->ecdh_ctx, &zlen,
602                                    p + 2, end - (p + 2),
603                                    ssl->f_rng, ssl->p_rng)) != 0)
604         {
605           SSL_DEBUG_RET (1, "ecdh_calc_secret", ret);
606           return (ret);
607         }
608
609       *(p++) = (unsigned char) (zlen >> 8);
610       *(p++) = (unsigned char) (zlen);
611       p += zlen;
612
613       SSL_DEBUG_MPI (3, "ECDH: z", &ssl->handshake->ecdh_ctx.z);
614     }
615   else
616     {
617       SSL_DEBUG_MSG (1, ("should never happen"));
618       return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
619     }
620
621   /* opaque psk<0..2^16-1>; */
622   if (end - p < 2 + (int) ssl->psk_len)
623     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
624
625   *(p++) = (unsigned char) (ssl->psk_len >> 8);
626   *(p++) = (unsigned char) (ssl->psk_len);
627   memcpy (p, ssl->psk, ssl->psk_len);
628   p += ssl->psk_len;
629
630   ssl->handshake->pmslen = p - ssl->handshake->premaster;
631
632   return (0);
633 }
634
635
636 /*
637  * Encryption/decryption functions
638  */
639 static int
640 ssl_encrypt_buf (ssl_context * ssl)
641 {
642   size_t i;
643   const cipher_mode_t mode =
644     cipher_get_cipher_mode (&ssl->transform_out->cipher_ctx_enc);
645
646   SSL_DEBUG_MSG (2, ("=> encrypt buf"));
647
648   /*
649    * Add MAC before encrypt, except for AEAD modes
650    */
651   if (mode != POLARSSL_MODE_GCM && mode != POLARSSL_MODE_CCM)
652     {
653       if (ssl->minor_ver >= SSL_MINOR_VERSION_1)
654         {
655           md_hmac_update (&ssl->transform_out->md_ctx_enc, ssl->out_ctr, 13);
656           md_hmac_update (&ssl->transform_out->md_ctx_enc,
657                           ssl->out_msg, ssl->out_msglen);
658           md_hmac_finish (&ssl->transform_out->md_ctx_enc,
659                           ssl->out_msg + ssl->out_msglen);
660           md_hmac_reset (&ssl->transform_out->md_ctx_enc);
661         }
662       else
663         {
664           SSL_DEBUG_MSG (1, ("should never happen"));
665           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
666         }
667
668       SSL_DEBUG_BUF (4, "computed mac",
669                      ssl->out_msg + ssl->out_msglen,
670                      ssl->transform_out->maclen);
671
672       ssl->out_msglen += ssl->transform_out->maclen;
673     }
674
675   /*
676    * Encrypt
677    */
678   if (mode == POLARSSL_MODE_STREAM)
679     {
680       int ret;
681       size_t olen = 0;
682
683       SSL_DEBUG_MSG (3, ("before encrypt: msglen = %d, "
684                          "including %d bytes of padding",
685                          ssl->out_msglen, 0));
686
687       SSL_DEBUG_BUF (4, "before encrypt: output payload",
688                      ssl->out_msg, ssl->out_msglen);
689
690       if ((ret = cipher_crypt (&ssl->transform_out->cipher_ctx_enc,
691                                ssl->transform_out->iv_enc,
692                                ssl->transform_out->ivlen,
693                                ssl->out_msg, ssl->out_msglen,
694                                ssl->out_msg, &olen)) != 0)
695         {
696           SSL_DEBUG_RET (1, "cipher_crypt", ret);
697           return (ret);
698         }
699
700       if (ssl->out_msglen != olen)
701         {
702           SSL_DEBUG_MSG (1, ("should never happen"));
703           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
704         }
705     }
706   else if (mode == POLARSSL_MODE_GCM || mode == POLARSSL_MODE_CCM)
707     {
708       int ret;
709       size_t enc_msglen, olen;
710       unsigned char *enc_msg;
711       unsigned char add_data[13];
712       unsigned char taglen = ssl->transform_out->ciphersuite_info->flags &
713         POLARSSL_CIPHERSUITE_SHORT_TAG ? 8 : 16;
714
715       memcpy (add_data, ssl->out_ctr, 8);
716       add_data[8] = ssl->out_msgtype;
717       add_data[9] = ssl->major_ver;
718       add_data[10] = ssl->minor_ver;
719       add_data[11] = (ssl->out_msglen >> 8) & 0xFF;
720       add_data[12] = ssl->out_msglen & 0xFF;
721
722       SSL_DEBUG_BUF (4, "additional data used for AEAD", add_data, 13);
723
724       /*
725        * Generate IV
726        */
727       ret = ssl->f_rng (ssl->p_rng,
728                         ssl->transform_out->iv_enc +
729                         ssl->transform_out->fixed_ivlen,
730                         ssl->transform_out->ivlen -
731                         ssl->transform_out->fixed_ivlen);
732       if (ret != 0)
733         return (ret);
734
735       memcpy (ssl->out_iv,
736               ssl->transform_out->iv_enc + ssl->transform_out->fixed_ivlen,
737               ssl->transform_out->ivlen - ssl->transform_out->fixed_ivlen);
738
739       SSL_DEBUG_BUF (4, "IV used", ssl->out_iv,
740                      ssl->transform_out->ivlen -
741                      ssl->transform_out->fixed_ivlen);
742
743       /*
744        * Fix pointer positions and message length with added IV
745        */
746       enc_msg = ssl->out_msg;
747       enc_msglen = ssl->out_msglen;
748       ssl->out_msglen += ssl->transform_out->ivlen -
749         ssl->transform_out->fixed_ivlen;
750
751       SSL_DEBUG_MSG (3, ("before encrypt: msglen = %d, "
752                          "including %d bytes of padding",
753                          ssl->out_msglen, 0));
754
755       SSL_DEBUG_BUF (4, "before encrypt: output payload",
756                      ssl->out_msg, ssl->out_msglen);
757
758       /*
759        * Encrypt and authenticate
760        */
761       if ((ret = cipher_auth_encrypt (&ssl->transform_out->cipher_ctx_enc,
762                                       ssl->transform_out->iv_enc,
763                                       ssl->transform_out->ivlen,
764                                       add_data, 13,
765                                       enc_msg, enc_msglen,
766                                       enc_msg, &olen,
767                                       enc_msg + enc_msglen, taglen)) != 0)
768         {
769           SSL_DEBUG_RET (1, "cipher_auth_encrypt", ret);
770           return (ret);
771         }
772
773       if (olen != enc_msglen)
774         {
775           SSL_DEBUG_MSG (1, ("should never happen"));
776           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
777         }
778
779       ssl->out_msglen += taglen;
780
781       SSL_DEBUG_BUF (4, "after encrypt: tag", enc_msg + enc_msglen, taglen);
782     }
783   else if (mode == POLARSSL_MODE_CBC)
784     {
785       int ret;
786       unsigned char *enc_msg;
787       size_t enc_msglen, padlen, olen = 0;
788
789       padlen = ssl->transform_out->ivlen - (ssl->out_msglen + 1) %
790         ssl->transform_out->ivlen;
791       if (padlen == ssl->transform_out->ivlen)
792         padlen = 0;
793
794       for (i = 0; i <= padlen; i++)
795         ssl->out_msg[ssl->out_msglen + i] = (unsigned char) padlen;
796
797       ssl->out_msglen += padlen + 1;
798
799       enc_msglen = ssl->out_msglen;
800       enc_msg = ssl->out_msg;
801
802       /*
803        * Prepend per-record IV for block cipher in TLS v1.1 and up as per
804        * Method 1 (6.2.3.2. in RFC4346 and RFC5246)
805        */
806       if (ssl->minor_ver >= SSL_MINOR_VERSION_2)
807         {
808           /*
809            * Generate IV
810            */
811           int ret = ssl->f_rng (ssl->p_rng, ssl->transform_out->iv_enc,
812                                 ssl->transform_out->ivlen);
813           if (ret != 0)
814             return (ret);
815
816           memcpy (ssl->out_iv, ssl->transform_out->iv_enc,
817                   ssl->transform_out->ivlen);
818
819           /*
820            * Fix pointer positions and message length with added IV
821            */
822           enc_msg = ssl->out_msg;
823           enc_msglen = ssl->out_msglen;
824           ssl->out_msglen += ssl->transform_out->ivlen;
825         }
826
827       SSL_DEBUG_MSG (3, ("before encrypt: msglen = %d, "
828                          "including %d bytes of IV and %d bytes of padding",
829                          ssl->out_msglen, ssl->transform_out->ivlen,
830                          padlen + 1));
831
832       SSL_DEBUG_BUF (4, "before encrypt: output payload",
833                      ssl->out_iv, ssl->out_msglen);
834
835       if ((ret = cipher_crypt (&ssl->transform_out->cipher_ctx_enc,
836                                ssl->transform_out->iv_enc,
837                                ssl->transform_out->ivlen,
838                                enc_msg, enc_msglen, enc_msg, &olen)) != 0)
839         {
840           SSL_DEBUG_RET (1, "cipher_crypt", ret);
841           return (ret);
842         }
843
844       if (enc_msglen != olen)
845         {
846           SSL_DEBUG_MSG (1, ("should never happen"));
847           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
848         }
849
850     }
851   else
852     {
853       SSL_DEBUG_MSG (1, ("should never happen"));
854       return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
855     }
856
857   for (i = 8; i > 0; i--)
858     if (++ssl->out_ctr[i - 1] != 0)
859       break;
860
861   /* The loops goes to its end iff the counter is wrapping */
862   if (i == 0)
863     {
864       SSL_DEBUG_MSG (1, ("outgoing message counter would wrap"));
865       return (POLARSSL_ERR_SSL_COUNTER_WRAPPING);
866     }
867
868   SSL_DEBUG_MSG (2, ("<= encrypt buf"));
869
870   return (0);
871 }
872
873 #define POLARSSL_SSL_MAX_MAC_SIZE   48
874
875 static int
876 ssl_decrypt_buf (ssl_context * ssl)
877 {
878   size_t i;
879   const cipher_mode_t mode =
880     cipher_get_cipher_mode (&ssl->transform_in->cipher_ctx_dec);
881   size_t padlen = 0, correct = 1;
882
883   SSL_DEBUG_MSG (2, ("=> decrypt buf"));
884
885   if (ssl->in_msglen < ssl->transform_in->minlen)
886     {
887       SSL_DEBUG_MSG (1, ("in_msglen (%d) < minlen (%d)",
888                          ssl->in_msglen, ssl->transform_in->minlen));
889       return (POLARSSL_ERR_SSL_INVALID_MAC);
890     }
891
892   if (mode == POLARSSL_MODE_STREAM)
893     {
894       int ret;
895       size_t olen = 0;
896
897       padlen = 0;
898
899       if ((ret = cipher_crypt (&ssl->transform_in->cipher_ctx_dec,
900                                ssl->transform_in->iv_dec,
901                                ssl->transform_in->ivlen,
902                                ssl->in_msg, ssl->in_msglen,
903                                ssl->in_msg, &olen)) != 0)
904         {
905           SSL_DEBUG_RET (1, "cipher_crypt", ret);
906           return (ret);
907         }
908
909       if (ssl->in_msglen != olen)
910         {
911           SSL_DEBUG_MSG (1, ("should never happen"));
912           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
913         }
914     }
915   else if (mode == POLARSSL_MODE_GCM || mode == POLARSSL_MODE_CCM)
916     {
917       int ret;
918       size_t dec_msglen, olen;
919       unsigned char *dec_msg;
920       unsigned char *dec_msg_result;
921       unsigned char add_data[13];
922       unsigned char taglen = ssl->transform_in->ciphersuite_info->flags &
923         POLARSSL_CIPHERSUITE_SHORT_TAG ? 8 : 16;
924       unsigned char explicit_iv_len = ssl->transform_in->ivlen -
925         ssl->transform_in->fixed_ivlen;
926
927       if (ssl->in_msglen < explicit_iv_len + taglen)
928         {
929           SSL_DEBUG_MSG (1, ("msglen (%d) < explicit_iv_len (%d) "
930                              "+ taglen (%d)", ssl->in_msglen,
931                              explicit_iv_len, taglen));
932           return (POLARSSL_ERR_SSL_INVALID_MAC);
933         }
934       dec_msglen = ssl->in_msglen - explicit_iv_len - taglen;
935
936       dec_msg = ssl->in_msg;
937       dec_msg_result = ssl->in_msg;
938       ssl->in_msglen = dec_msglen;
939
940       memcpy (add_data, ssl->in_ctr, 8);
941       add_data[8] = ssl->in_msgtype;
942       add_data[9] = ssl->major_ver;
943       add_data[10] = ssl->minor_ver;
944       add_data[11] = (ssl->in_msglen >> 8) & 0xFF;
945       add_data[12] = ssl->in_msglen & 0xFF;
946
947       SSL_DEBUG_BUF (4, "additional data used for AEAD", add_data, 13);
948
949       memcpy (ssl->transform_in->iv_dec + ssl->transform_in->fixed_ivlen,
950               ssl->in_iv,
951               ssl->transform_in->ivlen - ssl->transform_in->fixed_ivlen);
952
953       SSL_DEBUG_BUF (4, "IV used", ssl->transform_in->iv_dec,
954                      ssl->transform_in->ivlen);
955       SSL_DEBUG_BUF (4, "TAG used", dec_msg + dec_msglen, taglen);
956
957       /*
958        * Decrypt and authenticate
959        */
960       if ((ret = cipher_auth_decrypt (&ssl->transform_in->cipher_ctx_dec,
961                                       ssl->transform_in->iv_dec,
962                                       ssl->transform_in->ivlen,
963                                       add_data, 13,
964                                       dec_msg, dec_msglen,
965                                       dec_msg_result, &olen,
966                                       dec_msg + dec_msglen, taglen)) != 0)
967         {
968           SSL_DEBUG_RET (1, "cipher_auth_decrypt", ret);
969
970           if (ret == POLARSSL_ERR_CIPHER_AUTH_FAILED)
971             return (POLARSSL_ERR_SSL_INVALID_MAC);
972
973           return (ret);
974         }
975
976       if (olen != dec_msglen)
977         {
978           SSL_DEBUG_MSG (1, ("should never happen"));
979           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
980         }
981     }
982   else if (mode == POLARSSL_MODE_CBC)
983     {
984       /*
985        * Decrypt and check the padding
986        */
987       int ret;
988       unsigned char *dec_msg;
989       unsigned char *dec_msg_result;
990       size_t dec_msglen;
991       size_t minlen = 0;
992       size_t olen = 0;
993
994       /*
995        * Check immediate ciphertext sanity
996        */
997       if (ssl->in_msglen % ssl->transform_in->ivlen != 0)
998         {
999           SSL_DEBUG_MSG (1, ("msglen (%d) %% ivlen (%d) != 0",
1000                              ssl->in_msglen, ssl->transform_in->ivlen));
1001           return (POLARSSL_ERR_SSL_INVALID_MAC);
1002         }
1003
1004       if (ssl->minor_ver >= SSL_MINOR_VERSION_2)
1005         minlen += ssl->transform_in->ivlen;
1006
1007       if (ssl->in_msglen < minlen + ssl->transform_in->ivlen ||
1008           ssl->in_msglen < minlen + ssl->transform_in->maclen + 1)
1009         {
1010           SSL_DEBUG_MSG (1, ("msglen (%d) < max( ivlen(%d), maclen (%d) "
1011                              "+ 1 ) ( + expl IV )", ssl->in_msglen,
1012                              ssl->transform_in->ivlen,
1013                              ssl->transform_in->maclen));
1014           return (POLARSSL_ERR_SSL_INVALID_MAC);
1015         }
1016
1017       dec_msglen = ssl->in_msglen;
1018       dec_msg = ssl->in_msg;
1019       dec_msg_result = ssl->in_msg;
1020
1021       /*
1022        * Initialize for prepended IV for block cipher in TLS v1.1 and up
1023        */
1024       if (ssl->minor_ver >= SSL_MINOR_VERSION_2)
1025         {
1026           dec_msglen -= ssl->transform_in->ivlen;
1027           ssl->in_msglen -= ssl->transform_in->ivlen;
1028
1029           for (i = 0; i < ssl->transform_in->ivlen; i++)
1030             ssl->transform_in->iv_dec[i] = ssl->in_iv[i];
1031         }
1032
1033       if ((ret = cipher_crypt (&ssl->transform_in->cipher_ctx_dec,
1034                                ssl->transform_in->iv_dec,
1035                                ssl->transform_in->ivlen,
1036                                dec_msg, dec_msglen,
1037                                dec_msg_result, &olen)) != 0)
1038         {
1039           SSL_DEBUG_RET (1, "cipher_crypt", ret);
1040           return (ret);
1041         }
1042
1043       if (dec_msglen != olen)
1044         {
1045           SSL_DEBUG_MSG (1, ("should never happen"));
1046           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
1047         }
1048
1049       padlen = 1 + ssl->in_msg[ssl->in_msglen - 1];
1050
1051       if (ssl->in_msglen < ssl->transform_in->maclen + padlen)
1052         {
1053           SSL_DEBUG_MSG (1, ("msglen (%d) < maclen (%d) + padlen (%d)",
1054                              ssl->in_msglen, ssl->transform_in->maclen,
1055                              padlen));
1056           padlen = 0;
1057           correct = 0;
1058         }
1059
1060       if (ssl->minor_ver > SSL_MINOR_VERSION_0)
1061         {
1062           /*
1063            * TLSv1+: always check the padding up to the first failure
1064            * and fake check up to 256 bytes of padding
1065            */
1066           size_t pad_count = 0, real_count = 1;
1067           size_t padding_idx = ssl->in_msglen - padlen - 1;
1068
1069           /*
1070            * Padding is guaranteed to be incorrect if:
1071            *   1. padlen >= ssl->in_msglen
1072            *
1073            *   2. padding_idx >= SSL_MAX_CONTENT_LEN +
1074            *                     ssl->transform_in->maclen
1075            *
1076            * In both cases we reset padding_idx to a safe value (0) to
1077            * prevent out-of-buffer reads.
1078            */
1079           correct &= (ssl->in_msglen >= padlen + 1);
1080           correct &= (padding_idx < SSL_MAX_CONTENT_LEN +
1081                       ssl->transform_in->maclen);
1082
1083           padding_idx *= correct;
1084
1085           for (i = 1; i <= 256; i++)
1086             {
1087               real_count &= (i <= padlen);
1088               pad_count += real_count *
1089                 (ssl->in_msg[padding_idx + i] == padlen - 1);
1090             }
1091
1092           correct &= (pad_count == padlen);     /* Only 1 on correct padding */
1093
1094           if (padlen > 0 && correct == 0)
1095             SSL_DEBUG_MSG (1, ("bad padding byte detected"));
1096
1097           padlen &= correct * 0x1FF;
1098         }
1099       else
1100         {
1101           SSL_DEBUG_MSG (1, ("should never happen"));
1102           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
1103         }
1104     }
1105   else
1106     {
1107       SSL_DEBUG_MSG (1, ("should never happen"));
1108       return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
1109     }
1110
1111   SSL_DEBUG_BUF (4, "raw buffer after decryption",
1112                  ssl->in_msg, ssl->in_msglen);
1113
1114   /*
1115    * Always compute the MAC (RFC4346, CBCTIME), except for AEAD of course
1116    */
1117   if (mode != POLARSSL_MODE_GCM && mode != POLARSSL_MODE_CCM)
1118     {
1119       unsigned char tmp[POLARSSL_SSL_MAX_MAC_SIZE];
1120
1121       ssl->in_msglen -= (ssl->transform_in->maclen + padlen);
1122
1123       ssl->in_hdr[3] = (unsigned char) (ssl->in_msglen >> 8);
1124       ssl->in_hdr[4] = (unsigned char) (ssl->in_msglen);
1125
1126       memcpy (tmp, ssl->in_msg + ssl->in_msglen, ssl->transform_in->maclen);
1127
1128       if (ssl->minor_ver > SSL_MINOR_VERSION_0)
1129         {
1130           /*
1131            * Process MAC and always update for padlen afterwards to make
1132            * total time independent of padlen
1133            *
1134            * extra_run compensates MAC check for padlen
1135            *
1136            * Known timing attacks:
1137            *  - Lucky Thirteen (http://www.isg.rhul.ac.uk/tls/TLStiming.pdf)
1138            *
1139            * We use ( ( Lx + 8 ) / 64 ) to handle 'negative Lx' values
1140            * correctly. (We round down instead of up, so -56 is the correct
1141            * value for our calculations instead of -55)
1142            */
1143           size_t j, extra_run = 0;
1144           extra_run = (13 + ssl->in_msglen + padlen + 8) / 64 -
1145             (13 + ssl->in_msglen + 8) / 64;
1146
1147           extra_run &= correct * 0xFF;
1148
1149           md_hmac_update (&ssl->transform_in->md_ctx_dec, ssl->in_ctr, 13);
1150           md_hmac_update (&ssl->transform_in->md_ctx_dec, ssl->in_msg,
1151                           ssl->in_msglen);
1152           md_hmac_finish (&ssl->transform_in->md_ctx_dec,
1153                           ssl->in_msg + ssl->in_msglen);
1154           for (j = 0; j < extra_run; j++)
1155             md_process (&ssl->transform_in->md_ctx_dec, ssl->in_msg);
1156
1157           md_hmac_reset (&ssl->transform_in->md_ctx_dec);
1158         }
1159       else
1160         {
1161           SSL_DEBUG_MSG (1, ("should never happen"));
1162           return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
1163         }
1164
1165       SSL_DEBUG_BUF (4, "message  mac", tmp, ssl->transform_in->maclen);
1166       SSL_DEBUG_BUF (4, "computed mac", ssl->in_msg + ssl->in_msglen,
1167                      ssl->transform_in->maclen);
1168
1169       if (safer_memcmp (tmp, ssl->in_msg + ssl->in_msglen,
1170                         ssl->transform_in->maclen) != 0)
1171         {
1172           SSL_DEBUG_MSG (1, ("message mac does not match"));
1173           correct = 0;
1174         }
1175
1176       /*
1177        * Finally check the correct flag
1178        */
1179       if (correct == 0)
1180         return (POLARSSL_ERR_SSL_INVALID_MAC);
1181     }
1182
1183   if (ssl->in_msglen == 0)
1184     {
1185       ssl->nb_zero++;
1186
1187       /*
1188        * Three or more empty messages may be a DoS attack
1189        * (excessive CPU consumption).
1190        */
1191       if (ssl->nb_zero > 3)
1192         {
1193           SSL_DEBUG_MSG (1, ("received four consecutive empty "
1194                              "messages, possible DoS attack"));
1195           return (POLARSSL_ERR_SSL_INVALID_MAC);
1196         }
1197     }
1198   else
1199     ssl->nb_zero = 0;
1200
1201   for (i = 8; i > 0; i--)
1202     if (++ssl->in_ctr[i - 1] != 0)
1203       break;
1204
1205   /* The loops goes to its end iff the counter is wrapping */
1206   if (i == 0)
1207     {
1208       SSL_DEBUG_MSG (1, ("incoming message counter would wrap"));
1209       return (POLARSSL_ERR_SSL_COUNTER_WRAPPING);
1210     }
1211
1212   SSL_DEBUG_MSG (2, ("<= decrypt buf"));
1213
1214   return (0);
1215 }
1216
1217
1218 /*
1219  * Compression/decompression functions
1220  */
1221 static int
1222 ssl_compress_buf (ssl_context * ssl)
1223 {
1224   int ret;
1225   unsigned char *msg_post = ssl->out_msg;
1226   size_t len_pre = ssl->out_msglen;
1227   unsigned char *msg_pre = ssl->compress_buf;
1228
1229   SSL_DEBUG_MSG (2, ("=> compress buf"));
1230
1231   if (len_pre == 0)
1232     return (0);
1233
1234   memcpy (msg_pre, ssl->out_msg, len_pre);
1235
1236   SSL_DEBUG_MSG (3, ("before compression: msglen = %d, ", ssl->out_msglen));
1237
1238   SSL_DEBUG_BUF (4, "before compression: output payload",
1239                  ssl->out_msg, ssl->out_msglen);
1240
1241   ssl->transform_out->ctx_deflate.next_in = msg_pre;
1242   ssl->transform_out->ctx_deflate.avail_in = len_pre;
1243   ssl->transform_out->ctx_deflate.next_out = msg_post;
1244   ssl->transform_out->ctx_deflate.avail_out = SSL_BUFFER_LEN;
1245
1246   ret = deflate (&ssl->transform_out->ctx_deflate, Z_SYNC_FLUSH);
1247   if (ret != Z_OK)
1248     {
1249       SSL_DEBUG_MSG (1, ("failed to perform compression (%d)", ret));
1250       return (POLARSSL_ERR_SSL_COMPRESSION_FAILED);
1251     }
1252
1253   ssl->out_msglen = SSL_BUFFER_LEN -
1254     ssl->transform_out->ctx_deflate.avail_out;
1255
1256   SSL_DEBUG_MSG (3, ("after compression: msglen = %d, ", ssl->out_msglen));
1257
1258   SSL_DEBUG_BUF (4, "after compression: output payload",
1259                  ssl->out_msg, ssl->out_msglen);
1260
1261   SSL_DEBUG_MSG (2, ("<= compress buf"));
1262
1263   return (0);
1264 }
1265
1266 static int
1267 ssl_decompress_buf (ssl_context * ssl)
1268 {
1269   int ret;
1270   unsigned char *msg_post = ssl->in_msg;
1271   size_t len_pre = ssl->in_msglen;
1272   unsigned char *msg_pre = ssl->compress_buf;
1273
1274   SSL_DEBUG_MSG (2, ("=> decompress buf"));
1275
1276   if (len_pre == 0)
1277     return (0);
1278
1279   memcpy (msg_pre, ssl->in_msg, len_pre);
1280
1281   SSL_DEBUG_MSG (3, ("before decompression: msglen = %d, ", ssl->in_msglen));
1282
1283   SSL_DEBUG_BUF (4, "before decompression: input payload",
1284                  ssl->in_msg, ssl->in_msglen);
1285
1286   ssl->transform_in->ctx_inflate.next_in = msg_pre;
1287   ssl->transform_in->ctx_inflate.avail_in = len_pre;
1288   ssl->transform_in->ctx_inflate.next_out = msg_post;
1289   ssl->transform_in->ctx_inflate.avail_out = SSL_MAX_CONTENT_LEN;
1290
1291   ret = inflate (&ssl->transform_in->ctx_inflate, Z_SYNC_FLUSH);
1292   if (ret != Z_OK)
1293     {
1294       SSL_DEBUG_MSG (1, ("failed to perform decompression (%d)", ret));
1295       return (POLARSSL_ERR_SSL_COMPRESSION_FAILED);
1296     }
1297
1298   ssl->in_msglen = SSL_MAX_CONTENT_LEN -
1299     ssl->transform_in->ctx_inflate.avail_out;
1300
1301   SSL_DEBUG_MSG (3, ("after decompression: msglen = %d, ", ssl->in_msglen));
1302
1303   SSL_DEBUG_BUF (4, "after decompression: input payload",
1304                  ssl->in_msg, ssl->in_msglen);
1305
1306   SSL_DEBUG_MSG (2, ("<= decompress buf"));
1307
1308   return (0);
1309 }
1310
1311
1312 /*
1313  * Fill the input message buffer
1314  */
1315 int
1316 ssl_fetch_input (ssl_context * ssl, size_t nb_want)
1317 {
1318   int ret;
1319   size_t len;
1320
1321   SSL_DEBUG_MSG (2, ("=> fetch input"));
1322
1323   if (nb_want > SSL_BUFFER_LEN - 8)
1324     {
1325       SSL_DEBUG_MSG (1, ("requesting more data than fits"));
1326       return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
1327     }
1328
1329   while (ssl->in_left < nb_want)
1330     {
1331       len = nb_want - ssl->in_left;
1332       ret = ssl->f_recv (ssl->p_recv, ssl->in_hdr + ssl->in_left, len);
1333
1334       SSL_DEBUG_MSG (2, ("in_left: %d, nb_want: %d", ssl->in_left, nb_want));
1335       SSL_DEBUG_RET (2, "ssl->f_recv", ret);
1336
1337       if (ret == 0)
1338         return (POLARSSL_ERR_SSL_CONN_EOF);
1339
1340       if (ret < 0)
1341         return (ret);
1342
1343       ssl->in_left += ret;
1344     }
1345
1346   SSL_DEBUG_MSG (2, ("<= fetch input"));
1347
1348   return (0);
1349 }
1350
1351 /*
1352  * Flush any data not yet written
1353  */
1354 int
1355 ssl_flush_output (ssl_context * ssl)
1356 {
1357   int ret;
1358   unsigned char *buf;
1359
1360   SSL_DEBUG_MSG (2, ("=> flush output"));
1361
1362   while (ssl->out_left > 0)
1363     {
1364       SSL_DEBUG_MSG (2, ("message length: %d, out_left: %d",
1365                          5 + ssl->out_msglen, ssl->out_left));
1366
1367       buf = ssl->out_hdr + 5 + ssl->out_msglen - ssl->out_left;
1368       ret = ssl->f_send (ssl->p_send, buf, ssl->out_left);
1369
1370       SSL_DEBUG_RET (2, "ssl->f_send", ret);
1371
1372       if (ret <= 0)
1373         return (ret);
1374
1375       ssl->out_left -= ret;
1376     }
1377
1378   SSL_DEBUG_MSG (2, ("<= flush output"));
1379
1380   return (0);
1381 }
1382
1383
1384 /*
1385  * Record layer functions
1386  */
1387 int
1388 ssl_write_record (ssl_context * ssl)
1389 {
1390   int ret, done = 0;
1391   size_t len = ssl->out_msglen;
1392
1393   SSL_DEBUG_MSG (2, ("=> write record"));
1394
1395   if (ssl->out_msgtype == SSL_MSG_HANDSHAKE)
1396     {
1397       ssl->out_msg[1] = (unsigned char) ((len - 4) >> 16);
1398       ssl->out_msg[2] = (unsigned char) ((len - 4) >> 8);
1399       ssl->out_msg[3] = (unsigned char) ((len - 4));
1400
1401       if (ssl->out_msg[0] != SSL_HS_HELLO_REQUEST)
1402         ssl->handshake->update_checksum (ssl, ssl->out_msg, len);
1403     }
1404
1405   if (ssl->transform_out != NULL &&
1406       ssl->session_out->compression == SSL_COMPRESS_DEFLATE)
1407     {
1408       if ((ret = ssl_compress_buf (ssl)) != 0)
1409         {
1410           SSL_DEBUG_RET (1, "ssl_compress_buf", ret);
1411           return (ret);
1412         }
1413
1414       len = ssl->out_msglen;
1415     }
1416
1417   if (!done)
1418     {
1419       ssl->out_hdr[0] = (unsigned char) ssl->out_msgtype;
1420       ssl->out_hdr[1] = (unsigned char) ssl->major_ver;
1421       ssl->out_hdr[2] = (unsigned char) ssl->minor_ver;
1422       ssl->out_hdr[3] = (unsigned char) (len >> 8);
1423       ssl->out_hdr[4] = (unsigned char) (len);
1424
1425       if (ssl->transform_out != NULL)
1426         {
1427           if ((ret = ssl_encrypt_buf (ssl)) != 0)
1428             {
1429               SSL_DEBUG_RET (1, "ssl_encrypt_buf", ret);
1430               return (ret);
1431             }
1432
1433           len = ssl->out_msglen;
1434           ssl->out_hdr[3] = (unsigned char) (len >> 8);
1435           ssl->out_hdr[4] = (unsigned char) (len);
1436         }
1437
1438       ssl->out_left = 5 + ssl->out_msglen;
1439
1440       SSL_DEBUG_MSG (3, ("output record: msgtype = %d, "
1441                          "version = [%d:%d], msglen = %d",
1442                          ssl->out_hdr[0], ssl->out_hdr[1], ssl->out_hdr[2],
1443                          (ssl->out_hdr[3] << 8) | ssl->out_hdr[4]));
1444
1445       SSL_DEBUG_BUF (4, "output record sent to network",
1446                      ssl->out_hdr, 5 + ssl->out_msglen);
1447     }
1448
1449   if ((ret = ssl_flush_output (ssl)) != 0)
1450     {
1451       SSL_DEBUG_RET (1, "ssl_flush_output", ret);
1452       return (ret);
1453     }
1454
1455   SSL_DEBUG_MSG (2, ("<= write record"));
1456
1457   return (0);
1458 }
1459
1460
1461 int
1462 ssl_read_record (ssl_context * ssl)
1463 {
1464   int ret, done = 0;
1465
1466   SSL_DEBUG_MSG (2, ("=> read record"));
1467
1468   if (ssl->in_hslen != 0 && ssl->in_hslen < ssl->in_msglen)
1469     {
1470       /*
1471        * Get next Handshake message in the current record
1472        */
1473       ssl->in_msglen -= ssl->in_hslen;
1474
1475       memmove (ssl->in_msg, ssl->in_msg + ssl->in_hslen, ssl->in_msglen);
1476
1477       ssl->in_hslen = 4;
1478       ssl->in_hslen += (ssl->in_msg[2] << 8) | ssl->in_msg[3];
1479
1480       SSL_DEBUG_MSG (3, ("handshake message: msglen ="
1481                          " %d, type = %d, hslen = %d",
1482                          ssl->in_msglen, ssl->in_msg[0], ssl->in_hslen));
1483
1484       if (ssl->in_msglen < 4 || ssl->in_msg[1] != 0)
1485         {
1486           SSL_DEBUG_MSG (1, ("bad handshake length"));
1487           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1488         }
1489
1490       if (ssl->in_msglen < ssl->in_hslen)
1491         {
1492           SSL_DEBUG_MSG (1, ("bad handshake length"));
1493           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1494         }
1495
1496       if (ssl->state != SSL_HANDSHAKE_OVER)
1497         ssl->handshake->update_checksum (ssl, ssl->in_msg, ssl->in_hslen);
1498
1499       return (0);
1500     }
1501
1502   ssl->in_hslen = 0;
1503
1504   /*
1505    * Read the record header and validate it
1506    */
1507   if ((ret = ssl_fetch_input (ssl, 5)) != 0)
1508     {
1509       SSL_DEBUG_RET (1, "ssl_fetch_input", ret);
1510       return (ret);
1511     }
1512
1513   ssl->in_msgtype = ssl->in_hdr[0];
1514   ssl->in_msglen = (ssl->in_hdr[3] << 8) | ssl->in_hdr[4];
1515
1516   SSL_DEBUG_MSG (3, ("input record: msgtype = %d, "
1517                      "version = [%d:%d], msglen = %d",
1518                      ssl->in_hdr[0], ssl->in_hdr[1], ssl->in_hdr[2],
1519                      (ssl->in_hdr[3] << 8) | ssl->in_hdr[4]));
1520
1521   if (ssl->in_hdr[1] != ssl->major_ver)
1522     {
1523       SSL_DEBUG_MSG (1, ("major version mismatch"));
1524       return (POLARSSL_ERR_SSL_INVALID_RECORD);
1525     }
1526
1527   if (ssl->in_hdr[2] > ssl->max_minor_ver)
1528     {
1529       SSL_DEBUG_MSG (1, ("minor version mismatch"));
1530       return (POLARSSL_ERR_SSL_INVALID_RECORD);
1531     }
1532
1533   /* Sanity check (outer boundaries) */
1534   if (ssl->in_msglen < 1 || ssl->in_msglen > SSL_BUFFER_LEN - 13)
1535     {
1536       SSL_DEBUG_MSG (1, ("bad message length"));
1537       return (POLARSSL_ERR_SSL_INVALID_RECORD);
1538     }
1539
1540   /*
1541    * Make sure the message length is acceptable for the current transform
1542    * and protocol version.
1543    */
1544   if (ssl->transform_in == NULL)
1545     {
1546       if (ssl->in_msglen > SSL_MAX_CONTENT_LEN)
1547         {
1548           SSL_DEBUG_MSG (1, ("bad message length"));
1549           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1550         }
1551     }
1552   else
1553     {
1554       if (ssl->in_msglen < ssl->transform_in->minlen)
1555         {
1556           SSL_DEBUG_MSG (1, ("bad message length"));
1557           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1558         }
1559
1560       /*
1561        * TLS encrypted messages can have up to 256 bytes of padding
1562        */
1563       if (ssl->minor_ver >= SSL_MINOR_VERSION_1 &&
1564           ssl->in_msglen > ssl->transform_in->minlen +
1565           SSL_MAX_CONTENT_LEN + 256)
1566         {
1567           SSL_DEBUG_MSG (1, ("bad message length"));
1568           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1569         }
1570     }
1571
1572   /*
1573    * Read and optionally decrypt the message contents
1574    */
1575   if ((ret = ssl_fetch_input (ssl, 5 + ssl->in_msglen)) != 0)
1576     {
1577       SSL_DEBUG_RET (1, "ssl_fetch_input", ret);
1578       return (ret);
1579     }
1580
1581   SSL_DEBUG_BUF (4, "input record from network",
1582                  ssl->in_hdr, 5 + ssl->in_msglen);
1583
1584   if (!done && ssl->transform_in != NULL)
1585     {
1586       if ((ret = ssl_decrypt_buf (ssl)) != 0)
1587         {
1588           if (ret == POLARSSL_ERR_SSL_INVALID_MAC)
1589             {
1590               ssl_send_alert_message (ssl,
1591                                       SSL_ALERT_LEVEL_FATAL,
1592                                       SSL_ALERT_MSG_BAD_RECORD_MAC);
1593             }
1594           SSL_DEBUG_RET (1, "ssl_decrypt_buf", ret);
1595           return (ret);
1596         }
1597
1598       SSL_DEBUG_BUF (4, "input payload after decrypt",
1599                      ssl->in_msg, ssl->in_msglen);
1600
1601       if (ssl->in_msglen > SSL_MAX_CONTENT_LEN)
1602         {
1603           SSL_DEBUG_MSG (1, ("bad message length"));
1604           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1605         }
1606     }
1607
1608   if (ssl->transform_in != NULL &&
1609       ssl->session_in->compression == SSL_COMPRESS_DEFLATE)
1610     {
1611       if ((ret = ssl_decompress_buf (ssl)) != 0)
1612         {
1613           SSL_DEBUG_RET (1, "ssl_decompress_buf", ret);
1614           return (ret);
1615         }
1616
1617       ssl->in_hdr[3] = (unsigned char) (ssl->in_msglen >> 8);
1618       ssl->in_hdr[4] = (unsigned char) (ssl->in_msglen);
1619     }
1620
1621   if (ssl->in_msgtype != SSL_MSG_HANDSHAKE &&
1622       ssl->in_msgtype != SSL_MSG_ALERT &&
1623       ssl->in_msgtype != SSL_MSG_CHANGE_CIPHER_SPEC &&
1624       ssl->in_msgtype != SSL_MSG_APPLICATION_DATA)
1625     {
1626       SSL_DEBUG_MSG (1, ("unknown record type"));
1627
1628       if ((ret = ssl_send_alert_message (ssl,
1629                                          SSL_ALERT_LEVEL_FATAL,
1630                                          SSL_ALERT_MSG_UNEXPECTED_MESSAGE)) !=
1631           0)
1632         {
1633           return (ret);
1634         }
1635
1636       return (POLARSSL_ERR_SSL_INVALID_RECORD);
1637     }
1638
1639   if (ssl->in_msgtype == SSL_MSG_HANDSHAKE)
1640     {
1641       ssl->in_hslen = 4;
1642       ssl->in_hslen += (ssl->in_msg[2] << 8) | ssl->in_msg[3];
1643
1644       SSL_DEBUG_MSG (3, ("handshake message: msglen ="
1645                          " %d, type = %d, hslen = %d",
1646                          ssl->in_msglen, ssl->in_msg[0], ssl->in_hslen));
1647
1648       /*
1649        * Additional checks to validate the handshake header
1650        */
1651       if (ssl->in_msglen < 4 || ssl->in_msg[1] != 0)
1652         {
1653           SSL_DEBUG_MSG (1, ("bad handshake length"));
1654           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1655         }
1656
1657       if (ssl->in_msglen < ssl->in_hslen)
1658         {
1659           SSL_DEBUG_MSG (1, ("bad handshake length"));
1660           return (POLARSSL_ERR_SSL_INVALID_RECORD);
1661         }
1662
1663       if (ssl->state != SSL_HANDSHAKE_OVER)
1664         ssl->handshake->update_checksum (ssl, ssl->in_msg, ssl->in_hslen);
1665     }
1666
1667   if (ssl->in_msgtype == SSL_MSG_ALERT)
1668     {
1669       SSL_DEBUG_MSG (2, ("got an alert message, type: [%d:%d]",
1670                          ssl->in_msg[0], ssl->in_msg[1]));
1671
1672       /*
1673        * Ignore non-fatal alerts, except close_notify
1674        */
1675       if (ssl->in_msg[0] == SSL_ALERT_LEVEL_FATAL)
1676         {
1677           SSL_DEBUG_MSG (1, ("is a fatal alert message (msg %d)",
1678                              ssl->in_msg[1]));
1679           /**
1680            * Subtract from error code as ssl->in_msg[1] is 7-bit positive
1681            * error identifier.
1682            */
1683           return (POLARSSL_ERR_SSL_FATAL_ALERT_MESSAGE);
1684         }
1685
1686       if (ssl->in_msg[0] == SSL_ALERT_LEVEL_WARNING &&
1687           ssl->in_msg[1] == SSL_ALERT_MSG_CLOSE_NOTIFY)
1688         {
1689           SSL_DEBUG_MSG (2, ("is a close notify message"));
1690           return (POLARSSL_ERR_SSL_PEER_CLOSE_NOTIFY);
1691         }
1692     }
1693
1694   ssl->in_left = 0;
1695
1696   SSL_DEBUG_MSG (2, ("<= read record"));
1697
1698   return (0);
1699 }
1700
1701
1702 int
1703 ssl_send_fatal_handshake_failure (ssl_context * ssl)
1704 {
1705   int ret;
1706
1707   if ((ret = ssl_send_alert_message (ssl,
1708                                      SSL_ALERT_LEVEL_FATAL,
1709                                      SSL_ALERT_MSG_HANDSHAKE_FAILURE)) != 0)
1710     {
1711       return (ret);
1712     }
1713
1714   return (0);
1715 }
1716
1717
1718 int
1719 ssl_send_alert_message (ssl_context * ssl,
1720                         unsigned char level, unsigned char message)
1721 {
1722   int ret;
1723
1724   SSL_DEBUG_MSG (2, ("=> send alert message"));
1725
1726   ssl->out_msgtype = SSL_MSG_ALERT;
1727   ssl->out_msglen = 2;
1728   ssl->out_msg[0] = level;
1729   ssl->out_msg[1] = message;
1730
1731   if ((ret = ssl_write_record (ssl)) != 0)
1732     {
1733       SSL_DEBUG_RET (1, "ssl_write_record", ret);
1734       return (ret);
1735     }
1736
1737   SSL_DEBUG_MSG (2, ("<= send alert message"));
1738
1739   return (0);
1740 }
1741
1742
1743 /*
1744  * Handshake functions
1745  */
1746 int
1747 ssl_write_certificate (ssl_context * ssl)
1748 {
1749   int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
1750   size_t i, n;
1751   const x509_crt *crt;
1752   const ssl_ciphersuite_t *ciphersuite_info =
1753     ssl->transform_negotiate->ciphersuite_info;
1754
1755   SSL_DEBUG_MSG (2, ("=> write certificate"));
1756
1757   if (ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK ||
1758       ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_PSK ||
1759       ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_ECDHE_PSK)
1760     {
1761       SSL_DEBUG_MSG (2, ("<= skip write certificate"));
1762       ssl->state++;
1763       return (0);
1764     }
1765
1766   if (ssl->endpoint == SSL_IS_CLIENT)
1767     {
1768       if (ssl->client_auth == 0)
1769         {
1770           SSL_DEBUG_MSG (2, ("<= skip write certificate"));
1771           ssl->state++;
1772           return (0);
1773         }
1774
1775     }
1776   else /* SSL_IS_SERVER */
1777     {
1778       if (ssl_own_cert (ssl) == NULL)
1779         {
1780           SSL_DEBUG_MSG (1, ("got no certificate to send"));
1781           return (POLARSSL_ERR_SSL_CERTIFICATE_REQUIRED);
1782         }
1783     }
1784
1785   SSL_DEBUG_CRT (3, "own certificate", ssl_own_cert (ssl));
1786
1787   /*
1788    *     0  .  0    handshake type
1789    *     1  .  3    handshake length
1790    *     4  .  6    length of all certs
1791    *     7  .  9    length of cert. 1
1792    *    10  . n-1   peer certificate
1793    *     n  . n+2   length of cert. 2
1794    *    n+3 . ...   upper level cert, etc.
1795    */
1796   i = 7;
1797   crt = ssl_own_cert (ssl);
1798
1799   while (crt != NULL)
1800     {
1801       n = crt->raw.len;
1802       if (n > SSL_MAX_CONTENT_LEN - 3 - i)
1803         {
1804           SSL_DEBUG_MSG (1, ("certificate too large, %d > %d",
1805                              i + 3 + n, SSL_MAX_CONTENT_LEN));
1806           return (POLARSSL_ERR_SSL_CERTIFICATE_TOO_LARGE);
1807         }
1808
1809       ssl->out_msg[i] = (unsigned char) (n >> 16);
1810       ssl->out_msg[i + 1] = (unsigned char) (n >> 8);
1811       ssl->out_msg[i + 2] = (unsigned char) (n);
1812
1813       i += 3;
1814       memcpy (ssl->out_msg + i, crt->raw.p, n);
1815       i += n;
1816       crt = crt->next;
1817     }
1818
1819   ssl->out_msg[4] = (unsigned char) ((i - 7) >> 16);
1820   ssl->out_msg[5] = (unsigned char) ((i - 7) >> 8);
1821   ssl->out_msg[6] = (unsigned char) ((i - 7));
1822
1823   ssl->out_msglen = i;
1824   ssl->out_msgtype = SSL_MSG_HANDSHAKE;
1825   ssl->out_msg[0] = SSL_HS_CERTIFICATE;
1826
1827   ssl->state++;
1828
1829   if ((ret = ssl_write_record (ssl)) != 0)
1830     {
1831       SSL_DEBUG_RET (1, "ssl_write_record", ret);
1832       return (ret);
1833     }
1834
1835   SSL_DEBUG_MSG (2, ("<= write certificate"));
1836
1837   return (ret);
1838 }
1839
1840
1841 int
1842 ssl_parse_certificate (ssl_context * ssl)
1843 {
1844   int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
1845   size_t i, n;
1846   const ssl_ciphersuite_t *ciphersuite_info =
1847     ssl->transform_negotiate->ciphersuite_info;
1848
1849   SSL_DEBUG_MSG (2, ("=> parse certificate"));
1850
1851   if (ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK ||
1852       ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_PSK ||
1853       ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_ECDHE_PSK)
1854     {
1855       SSL_DEBUG_MSG (2, ("<= skip parse certificate"));
1856       ssl->state++;
1857       return (0);
1858     }
1859
1860   if (ssl->endpoint == SSL_IS_SERVER &&
1861       (ssl->authmode == SSL_VERIFY_NONE ||
1862        ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_RSA_PSK))
1863     {
1864       ssl->session_negotiate->verify_result = BADCERT_SKIP_VERIFY;
1865       SSL_DEBUG_MSG (2, ("<= skip parse certificate"));
1866       ssl->state++;
1867       return (0);
1868     }
1869
1870   if ((ret = ssl_read_record (ssl)) != 0)
1871     {
1872       SSL_DEBUG_RET (1, "ssl_read_record", ret);
1873       return (ret);
1874     }
1875
1876   ssl->state++;
1877
1878
1879   if (ssl->endpoint == SSL_IS_SERVER && ssl->minor_ver != SSL_MINOR_VERSION_0)
1880     {
1881       if (ssl->in_hslen == 7 &&
1882           ssl->in_msgtype == SSL_MSG_HANDSHAKE &&
1883           ssl->in_msg[0] == SSL_HS_CERTIFICATE &&
1884           memcmp (ssl->in_msg + 4, "\0\0\0", 3) == 0)
1885         {
1886           SSL_DEBUG_MSG (1, ("TLSv1 client has no certificate"));
1887
1888           ssl->session_negotiate->verify_result = BADCERT_MISSING;
1889           if (ssl->authmode == SSL_VERIFY_REQUIRED)
1890             return (POLARSSL_ERR_SSL_NO_CLIENT_CERTIFICATE);
1891           else
1892             return (0);
1893         }
1894     }
1895
1896   if (ssl->in_msgtype != SSL_MSG_HANDSHAKE)
1897     {
1898       SSL_DEBUG_MSG (1, ("bad certificate message"));
1899       return (POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE);
1900     }
1901
1902   if (ssl->in_msg[0] != SSL_HS_CERTIFICATE || ssl->in_hslen < 10)
1903     {
1904       SSL_DEBUG_MSG (1, ("bad certificate message"));
1905       return (POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE);
1906     }
1907
1908   /*
1909    * Same message structure as in ssl_write_certificate()
1910    */
1911   n = (ssl->in_msg[5] << 8) | ssl->in_msg[6];
1912
1913   if (ssl->in_msg[4] != 0 || ssl->in_hslen != 7 + n)
1914     {
1915       SSL_DEBUG_MSG (1, ("bad certificate message"));
1916       return (POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE);
1917     }
1918
1919   /* In case we tried to reuse a session but it failed */
1920   if (ssl->session_negotiate->peer_cert != NULL)
1921     {
1922       x509_crt_free (ssl->session_negotiate->peer_cert);
1923       polarssl_free (ssl->session_negotiate->peer_cert);
1924     }
1925
1926   if ((ssl->session_negotiate->peer_cert =
1927        (x509_crt *) polarssl_malloc (sizeof (x509_crt))) == NULL)
1928     {
1929       SSL_DEBUG_MSG (1, ("malloc(%d bytes) failed", sizeof (x509_crt)));
1930       return (POLARSSL_ERR_SSL_MALLOC_FAILED);
1931     }
1932
1933   x509_crt_init (ssl->session_negotiate->peer_cert);
1934
1935   i = 7;
1936
1937   while (i < ssl->in_hslen)
1938     {
1939       if (ssl->in_msg[i] != 0)
1940         {
1941           SSL_DEBUG_MSG (1, ("bad certificate message"));
1942           return (POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE);
1943         }
1944
1945       n = ((unsigned int) ssl->in_msg[i + 1] << 8)
1946         | (unsigned int) ssl->in_msg[i + 2];
1947       i += 3;
1948
1949       if (n < 128 || i + n > ssl->in_hslen)
1950         {
1951           SSL_DEBUG_MSG (1, ("bad certificate message"));
1952           return (POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE);
1953         }
1954
1955       ret = x509_crt_parse_der (ssl->session_negotiate->peer_cert,
1956                                 ssl->in_msg + i, n);
1957       if (ret != 0)
1958         {
1959           SSL_DEBUG_RET (1, " x509_crt_parse_der", ret);
1960           return (ret);
1961         }
1962
1963       i += n;
1964     }
1965
1966   SSL_DEBUG_CRT (3, "peer certificate", ssl->session_negotiate->peer_cert);
1967
1968   /*
1969    * On client, make sure the server cert doesn't change during renego to
1970    * avoid "triple handshake" attack: https://secure-resumption.com/
1971    */
1972   if (ssl->endpoint == SSL_IS_CLIENT &&
1973       ssl->renegotiation == SSL_RENEGOTIATION)
1974     {
1975       if (ssl->session->peer_cert == NULL)
1976         {
1977           SSL_DEBUG_MSG (1, ("new server cert during renegotiation"));
1978           return (POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE);
1979         }
1980
1981       if (ssl->session->peer_cert->raw.len !=
1982           ssl->session_negotiate->peer_cert->raw.len ||
1983           memcmp (ssl->session->peer_cert->raw.p,
1984                   ssl->session_negotiate->peer_cert->raw.p,
1985                   ssl->session->peer_cert->raw.len) != 0)
1986         {
1987           SSL_DEBUG_MSG (1, ("server cert changed during renegotiation"));
1988           return (POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE);
1989         }
1990     }
1991
1992   if (ssl->authmode != SSL_VERIFY_NONE)
1993     {
1994       if (ssl->ca_chain == NULL)
1995         {
1996           SSL_DEBUG_MSG (1, ("got no CA chain"));
1997           return (POLARSSL_ERR_SSL_CA_CHAIN_REQUIRED);
1998         }
1999
2000       /*
2001        * Main check: verify certificate
2002        */
2003       ret = x509_crt_verify (ssl->session_negotiate->peer_cert,
2004                              ssl->ca_chain, ssl->ca_crl, ssl->peer_cn,
2005                              &ssl->session_negotiate->verify_result,
2006                              ssl->f_vrfy, ssl->p_vrfy);
2007
2008       if (ret != 0)
2009         {
2010           SSL_DEBUG_RET (1, "x509_verify_cert", ret);
2011         }
2012
2013       /*
2014        * Secondary checks: always done, but change 'ret' only if it was 0
2015        */
2016       {
2017         pk_context *pk = &ssl->session_negotiate->peer_cert->pk;
2018
2019         /* If certificate uses an EC key, make sure the curve is OK */
2020         if (pk_can_do (pk, POLARSSL_PK_ECKEY) &&
2021             !ssl_curve_is_acceptable (ssl, pk_ec (*pk)->grp.id))
2022           {
2023             SSL_DEBUG_MSG (1, ("bad certificate (EC key curve)"));
2024             if (ret == 0)
2025               ret = POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE;
2026           }
2027       }
2028
2029       if (ssl_check_cert_usage (ssl->session_negotiate->peer_cert,
2030                                 ciphersuite_info, !ssl->endpoint) != 0)
2031         {
2032           SSL_DEBUG_MSG (1, ("bad certificate (usage extensions)"));
2033           if (ret == 0)
2034             ret = POLARSSL_ERR_SSL_BAD_HS_CERTIFICATE;
2035         }
2036
2037       if (ssl->authmode != SSL_VERIFY_REQUIRED)
2038         ret = 0;
2039     }
2040
2041   SSL_DEBUG_MSG (2, ("<= parse certificate"));
2042
2043   return (ret);
2044 }
2045
2046
2047 int
2048 ssl_write_change_cipher_spec (ssl_context * ssl)
2049 {
2050   int ret;
2051
2052   SSL_DEBUG_MSG (2, ("=> write change cipher spec"));
2053
2054   ssl->out_msgtype = SSL_MSG_CHANGE_CIPHER_SPEC;
2055   ssl->out_msglen = 1;
2056   ssl->out_msg[0] = 1;
2057
2058   ssl->state++;
2059
2060   if ((ret = ssl_write_record (ssl)) != 0)
2061     {
2062       SSL_DEBUG_RET (1, "ssl_write_record", ret);
2063       return (ret);
2064     }
2065
2066   SSL_DEBUG_MSG (2, ("<= write change cipher spec"));
2067
2068   return (0);
2069 }
2070
2071 int
2072 ssl_parse_change_cipher_spec (ssl_context * ssl)
2073 {
2074   int ret;
2075
2076   SSL_DEBUG_MSG (2, ("=> parse change cipher spec"));
2077
2078   if ((ret = ssl_read_record (ssl)) != 0)
2079     {
2080       SSL_DEBUG_RET (1, "ssl_read_record", ret);
2081       return (ret);
2082     }
2083
2084   if (ssl->in_msgtype != SSL_MSG_CHANGE_CIPHER_SPEC)
2085     {
2086       SSL_DEBUG_MSG (1, ("bad change cipher spec message"));
2087       return (POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE);
2088     }
2089
2090   if (ssl->in_msglen != 1 || ssl->in_msg[0] != 1)
2091     {
2092       SSL_DEBUG_MSG (1, ("bad change cipher spec message"));
2093       return (POLARSSL_ERR_SSL_BAD_HS_CHANGE_CIPHER_SPEC);
2094     }
2095
2096   ssl->state++;
2097
2098   SSL_DEBUG_MSG (2, ("<= parse change cipher spec"));
2099
2100   return (0);
2101 }
2102
2103
2104 void
2105 ssl_optimize_checksum (ssl_context * ssl,
2106                        const ssl_ciphersuite_t * ciphersuite_info)
2107 {
2108   ((void) ciphersuite_info);
2109
2110   if (ciphersuite_info->mac == POLARSSL_MD_SHA384)
2111     ssl->handshake->update_checksum = ssl_update_checksum_sha384;
2112   else if (ciphersuite_info->mac != POLARSSL_MD_SHA384)
2113     ssl->handshake->update_checksum = ssl_update_checksum_sha256;
2114   else
2115     {
2116       SSL_DEBUG_MSG (1, ("should never happen"));
2117       return;
2118     }
2119 }
2120
2121
2122 static void
2123 ssl_update_checksum_start (ssl_context * ssl,
2124                            const unsigned char *buf, size_t len)
2125 {
2126   sha256_update (&ssl->handshake->fin_sha256, buf, len);
2127   sha512_update (&ssl->handshake->fin_sha512, buf, len);
2128 }
2129
2130
2131
2132 static void
2133 ssl_update_checksum_sha256 (ssl_context * ssl,
2134                             const unsigned char *buf, size_t len)
2135 {
2136   sha256_update (&ssl->handshake->fin_sha256, buf, len);
2137 }
2138
2139
2140 static void
2141 ssl_update_checksum_sha384 (ssl_context * ssl,
2142                             const unsigned char *buf, size_t len)
2143 {
2144   sha512_update (&ssl->handshake->fin_sha512, buf, len);
2145 }
2146
2147
2148 static void
2149 ssl_calc_finished_tls_sha256 (ssl_context * ssl, unsigned char *buf, int from)
2150 {
2151   int len = 12;
2152   const char *sender;
2153   sha256_context sha256;
2154   unsigned char padbuf[32];
2155
2156   ssl_session *session = ssl->session_negotiate;
2157   if (!session)
2158     session = ssl->session;
2159
2160   SSL_DEBUG_MSG (2, ("=> calc  finished tls sha256"));
2161
2162   memcpy (&sha256, &ssl->handshake->fin_sha256, sizeof (sha256_context));
2163
2164   /*
2165    * TLSv1.2:
2166    *   hash = PRF( master, finished_label,
2167    *               Hash( handshake ) )[0.11]
2168    */
2169
2170 #if !defined(POLARSSL_SHA256_ALT)
2171   SSL_DEBUG_BUF (4, "finished sha2 state", (unsigned char *)
2172                  sha256.state, sizeof (sha256.state));
2173 #endif
2174
2175   sender = (from == SSL_IS_CLIENT) ? "client finished" : "server finished";
2176
2177   sha256_finish (&sha256, padbuf);
2178
2179   ssl->handshake->tls_prf (session->master, 48, sender, padbuf, 32, buf, len);
2180
2181   SSL_DEBUG_BUF (3, "calc finished result", buf, len);
2182
2183   sha256_free (&sha256);
2184
2185   polarssl_zeroize (padbuf, sizeof (padbuf));
2186
2187   SSL_DEBUG_MSG (2, ("<= calc  finished"));
2188 }
2189
2190
2191 static void
2192 ssl_calc_finished_tls_sha384 (ssl_context * ssl, unsigned char *buf, int from)
2193 {
2194   int len = 12;
2195   const char *sender;
2196   sha512_context sha512;
2197   unsigned char padbuf[48];
2198
2199   ssl_session *session = ssl->session_negotiate;
2200   if (!session)
2201     session = ssl->session;
2202
2203   SSL_DEBUG_MSG (2, ("=> calc  finished tls sha384"));
2204
2205   memcpy (&sha512, &ssl->handshake->fin_sha512, sizeof (sha512_context));
2206
2207   /*
2208    * TLSv1.2:
2209    *   hash = PRF( master, finished_label,
2210    *               Hash( handshake ) )[0.11]
2211    */
2212
2213 #if !defined(POLARSSL_SHA512_ALT)
2214   SSL_DEBUG_BUF (4, "finished sha512 state", (unsigned char *)
2215                  sha512.state, sizeof (sha512.state));
2216 #endif
2217
2218   sender = (from == SSL_IS_CLIENT) ? "client finished" : "server finished";
2219
2220   sha512_finish (&sha512, padbuf);
2221
2222   ssl->handshake->tls_prf (session->master, 48, sender, padbuf, 48, buf, len);
2223
2224   SSL_DEBUG_BUF (3, "calc finished result", buf, len);
2225
2226   sha512_free (&sha512);
2227
2228   polarssl_zeroize (padbuf, sizeof (padbuf));
2229
2230   SSL_DEBUG_MSG (2, ("<= calc  finished"));
2231 }
2232
2233
2234 void
2235 ssl_handshake_wrapup (ssl_context * ssl)
2236 {
2237   int resume = ssl->handshake->resume;
2238
2239   SSL_DEBUG_MSG (3, ("=> handshake wrapup"));
2240
2241   /*
2242    * Free our handshake params
2243    */
2244   ssl_handshake_free (ssl->handshake);
2245   polarssl_free (ssl->handshake);
2246   ssl->handshake = NULL;
2247
2248   if (ssl->renegotiation == SSL_RENEGOTIATION)
2249     {
2250       ssl->renegotiation = SSL_RENEGOTIATION_DONE;
2251       ssl->renego_records_seen = 0;
2252     }
2253
2254   /*
2255    * Switch in our now active transform context
2256    */
2257   if (ssl->transform)
2258     {
2259       ssl_transform_free (ssl->transform);
2260       polarssl_free (ssl->transform);
2261     }
2262   ssl->transform = ssl->transform_negotiate;
2263   ssl->transform_negotiate = NULL;
2264
2265   if (ssl->session)
2266     {
2267       ssl_session_free (ssl->session);
2268       polarssl_free (ssl->session);
2269     }
2270   ssl->session = ssl->session_negotiate;
2271   ssl->session_negotiate = NULL;
2272
2273   /*
2274    * Add cache entry
2275    */
2276   if (ssl->f_set_cache != NULL && ssl->session->length != 0 && resume == 0)
2277     {
2278       if (ssl->f_set_cache (ssl->p_set_cache, ssl->session) != 0)
2279         SSL_DEBUG_MSG (1, ("cache did not store session"));
2280     }
2281
2282   ssl->state++;
2283
2284   SSL_DEBUG_MSG (3, ("<= handshake wrapup"));
2285 }
2286
2287
2288 int
2289 ssl_write_finished (ssl_context * ssl)
2290 {
2291   int ret, hash_len;
2292
2293   SSL_DEBUG_MSG (2, ("=> write finished"));
2294
2295   /*
2296    * Set the out_msg pointer to the correct location based on IV length
2297    */
2298   if (ssl->minor_ver >= SSL_MINOR_VERSION_2)
2299     {
2300       ssl->out_msg = ssl->out_iv + ssl->transform_negotiate->ivlen -
2301         ssl->transform_negotiate->fixed_ivlen;
2302     }
2303   else
2304     ssl->out_msg = ssl->out_iv;
2305
2306   ssl->handshake->calc_finished (ssl, ssl->out_msg + 4, ssl->endpoint);
2307
2308   // TODO TLS/1.2 Hash length is determined by cipher suite (Page 63)
2309   hash_len = (ssl->minor_ver == SSL_MINOR_VERSION_0) ? 36 : 12;
2310
2311   ssl->verify_data_len = hash_len;
2312   memcpy (ssl->own_verify_data, ssl->out_msg + 4, hash_len);
2313
2314   ssl->out_msglen = 4 + hash_len;
2315   ssl->out_msgtype = SSL_MSG_HANDSHAKE;
2316   ssl->out_msg[0] = SSL_HS_FINISHED;
2317
2318   /*
2319    * In case of session resuming, invert the client and server
2320    * ChangeCipherSpec messages order.
2321    */
2322   if (ssl->handshake->resume != 0)
2323     {
2324       if (ssl->endpoint == SSL_IS_CLIENT)
2325         ssl->state = SSL_HANDSHAKE_WRAPUP;
2326       else
2327         ssl->state = SSL_CLIENT_CHANGE_CIPHER_SPEC;
2328     }
2329   else
2330     ssl->state++;
2331
2332   /*
2333    * Switch to our negotiated transform and session parameters for outbound
2334    * data.
2335    */
2336   SSL_DEBUG_MSG (3, ("switching to new transform spec for outbound data"));
2337   ssl->transform_out = ssl->transform_negotiate;
2338   ssl->session_out = ssl->session_negotiate;
2339   memset (ssl->out_ctr, 0, 8);
2340
2341   if ((ret = ssl_write_record (ssl)) != 0)
2342     {
2343       SSL_DEBUG_RET (1, "ssl_write_record", ret);
2344       return (ret);
2345     }
2346
2347   SSL_DEBUG_MSG (2, ("<= write finished"));
2348
2349   return (0);
2350 }
2351
2352
2353 int
2354 ssl_parse_finished (ssl_context * ssl)
2355 {
2356   int ret;
2357   unsigned int hash_len;
2358   unsigned char buf[36];
2359
2360   SSL_DEBUG_MSG (2, ("=> parse finished"));
2361
2362   ssl->handshake->calc_finished (ssl, buf, ssl->endpoint ^ 1);
2363
2364   /*
2365    * Switch to our negotiated transform and session parameters for inbound
2366    * data.
2367    */
2368   SSL_DEBUG_MSG (3, ("switching to new transform spec for inbound data"));
2369   ssl->transform_in = ssl->transform_negotiate;
2370   ssl->session_in = ssl->session_negotiate;
2371   memset (ssl->in_ctr, 0, 8);
2372
2373   /*
2374    * Set the in_msg pointer to the correct location based on IV length
2375    */
2376   if (ssl->minor_ver >= SSL_MINOR_VERSION_2)
2377     {
2378       ssl->in_msg = ssl->in_iv + ssl->transform_negotiate->ivlen -
2379         ssl->transform_negotiate->fixed_ivlen;
2380     }
2381   else
2382     ssl->in_msg = ssl->in_iv;
2383
2384   if ((ret = ssl_read_record (ssl)) != 0)
2385     {
2386       SSL_DEBUG_RET (1, "ssl_read_record", ret);
2387       return (ret);
2388     }
2389
2390   if (ssl->in_msgtype != SSL_MSG_HANDSHAKE)
2391     {
2392       SSL_DEBUG_MSG (1, ("bad finished message"));
2393       return (POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE);
2394     }
2395
2396   // TODO TLS/1.2 Hash length is determined by cipher suite (Page 63)
2397   hash_len = (ssl->minor_ver == SSL_MINOR_VERSION_0) ? 36 : 12;
2398
2399   if (ssl->in_msg[0] != SSL_HS_FINISHED || ssl->in_hslen != 4 + hash_len)
2400     {
2401       SSL_DEBUG_MSG (1, ("bad finished message"));
2402       return (POLARSSL_ERR_SSL_BAD_HS_FINISHED);
2403     }
2404
2405   if (safer_memcmp (ssl->in_msg + 4, buf, hash_len) != 0)
2406     {
2407       SSL_DEBUG_MSG (1, ("bad finished message"));
2408       return (POLARSSL_ERR_SSL_BAD_HS_FINISHED);
2409     }
2410
2411   ssl->verify_data_len = hash_len;
2412   memcpy (ssl->peer_verify_data, buf, hash_len);
2413
2414   if (ssl->handshake->resume != 0)
2415     {
2416       if (ssl->endpoint == SSL_IS_CLIENT)
2417         ssl->state = SSL_CLIENT_CHANGE_CIPHER_SPEC;
2418
2419       if (ssl->endpoint == SSL_IS_SERVER)
2420         ssl->state = SSL_HANDSHAKE_WRAPUP;
2421     }
2422   else
2423     ssl->state++;
2424
2425   SSL_DEBUG_MSG (2, ("<= parse finished"));
2426
2427   return (0);
2428 }
2429
2430
2431 static void
2432 ssl_handshake_params_init (ssl_handshake_params * handshake)
2433 {
2434   memset (handshake, 0, sizeof (ssl_handshake_params));
2435
2436   sha256_init (&handshake->fin_sha256);
2437   sha256_starts (&handshake->fin_sha256, 0);
2438   sha512_init (&handshake->fin_sha512);
2439   sha512_starts (&handshake->fin_sha512, 1);
2440
2441   handshake->update_checksum = ssl_update_checksum_start;
2442   handshake->sig_alg = SSL_HASH_SHA1;
2443
2444   dhm_init (&handshake->dhm_ctx);
2445   ecdh_init (&handshake->ecdh_ctx);
2446 }
2447
2448
2449 static void
2450 ssl_transform_init (ssl_transform * transform)
2451 {
2452   memset (transform, 0, sizeof (ssl_transform));
2453
2454   cipher_init (&transform->cipher_ctx_enc);
2455   cipher_init (&transform->cipher_ctx_dec);
2456
2457   md_init (&transform->md_ctx_enc);
2458   md_init (&transform->md_ctx_dec);
2459 }
2460
2461
2462 void
2463 ssl_session_init (ssl_session * session)
2464 {
2465   memset (session, 0, sizeof (ssl_session));
2466 }
2467
2468
2469 static int
2470 ssl_handshake_init (ssl_context * ssl)
2471 {
2472   /* Clear old handshake information if present */
2473   if (ssl->transform_negotiate)
2474     ssl_transform_free (ssl->transform_negotiate);
2475   if (ssl->session_negotiate)
2476     ssl_session_free (ssl->session_negotiate);
2477   if (ssl->handshake)
2478     ssl_handshake_free (ssl->handshake);
2479
2480   /*
2481    * Either the pointers are now NULL or cleared properly and can be freed.
2482    * Now allocate missing structures.
2483    */
2484   if (ssl->transform_negotiate == NULL)
2485     {
2486       ssl->transform_negotiate =
2487         (ssl_transform *) polarssl_malloc (sizeof (ssl_transform));
2488     }
2489
2490   if (ssl->session_negotiate == NULL)
2491     {
2492       ssl->session_negotiate =
2493         (ssl_session *) polarssl_malloc (sizeof (ssl_session));
2494     }
2495
2496   if (ssl->handshake == NULL)
2497     {
2498       ssl->handshake = (ssl_handshake_params *)
2499         polarssl_malloc (sizeof (ssl_handshake_params));
2500     }
2501
2502   /* All pointers should exist and can be directly freed without issue */
2503   if (ssl->handshake == NULL ||
2504       ssl->transform_negotiate == NULL || ssl->session_negotiate == NULL)
2505     {
2506       SSL_DEBUG_MSG (1, ("malloc() of ssl sub-contexts failed"));
2507
2508       polarssl_free (ssl->handshake);
2509       polarssl_free (ssl->transform_negotiate);
2510       polarssl_free (ssl->session_negotiate);
2511
2512       ssl->handshake = NULL;
2513       ssl->transform_negotiate = NULL;
2514       ssl->session_negotiate = NULL;
2515
2516       return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2517     }
2518
2519   /* Initialize structures */
2520   ssl_session_init (ssl->session_negotiate);
2521   ssl_transform_init (ssl->transform_negotiate);
2522   ssl_handshake_params_init (ssl->handshake);
2523
2524   ssl->handshake->key_cert = ssl->key_cert;
2525
2526   return (0);
2527 }
2528
2529 /*
2530  * Initialize an SSL context
2531  */
2532 int
2533 ssl_init (ssl_context * ssl)
2534 {
2535   int ret;
2536   int len = SSL_BUFFER_LEN;
2537
2538   memset (ssl, 0, sizeof (ssl_context));
2539
2540   /*
2541    * Sane defaults
2542    */
2543   ssl->min_major_ver = SSL_MIN_MAJOR_VERSION;
2544   ssl->min_minor_ver = SSL_MIN_MINOR_VERSION;
2545   ssl->max_major_ver = SSL_MAX_MAJOR_VERSION;
2546   ssl->max_minor_ver = SSL_MAX_MINOR_VERSION;
2547
2548   ssl_set_ciphersuites (ssl, ssl_list_ciphersuites ());
2549
2550   ssl->renego_max_records = SSL_RENEGO_MAX_RECORDS_DEFAULT;
2551
2552   if ((ret = mpi_read_string (&ssl->dhm_P, 16,
2553                               POLARSSL_DHM_RFC5114_MODP_1024_P)) != 0 ||
2554       (ret = mpi_read_string (&ssl->dhm_G, 16,
2555                               POLARSSL_DHM_RFC5114_MODP_1024_G)) != 0)
2556     {
2557       SSL_DEBUG_RET (1, "mpi_read_string", ret);
2558       return (ret);
2559     }
2560
2561   /*
2562    * Prepare base structures
2563    */
2564   ssl->in_ctr = (unsigned char *) polarssl_malloc (len);
2565   ssl->in_hdr = ssl->in_ctr + 8;
2566   ssl->in_iv = ssl->in_ctr + 13;
2567   ssl->in_msg = ssl->in_ctr + 13;
2568
2569   if (ssl->in_ctr == NULL)
2570     {
2571       SSL_DEBUG_MSG (1, ("malloc(%d bytes) failed", len));
2572       return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2573     }
2574
2575   ssl->out_ctr = (unsigned char *) polarssl_malloc (len);
2576   ssl->out_hdr = ssl->out_ctr + 8;
2577   ssl->out_iv = ssl->out_ctr + 13;
2578   ssl->out_msg = ssl->out_ctr + 13;
2579
2580   if (ssl->out_ctr == NULL)
2581     {
2582       SSL_DEBUG_MSG (1, ("malloc(%d bytes) failed", len));
2583       polarssl_free (ssl->in_ctr);
2584       ssl->in_ctr = NULL;
2585       return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2586     }
2587
2588   memset (ssl->in_ctr, 0, SSL_BUFFER_LEN);
2589   memset (ssl->out_ctr, 0, SSL_BUFFER_LEN);
2590
2591   ssl->ticket_lifetime = SSL_DEFAULT_TICKET_LIFETIME;
2592
2593   ssl->curve_list = ecp_grp_id_list ();
2594
2595   if ((ret = ssl_handshake_init (ssl)) != 0)
2596     return (ret);
2597
2598   return (0);
2599 }
2600
2601 /*
2602  * Reset an initialized and used SSL context for re-use while retaining
2603  * all application-set variables, function pointers and data.
2604  */
2605 int
2606 ssl_session_reset (ssl_context * ssl)
2607 {
2608   int ret;
2609
2610   ssl->state = SSL_HELLO_REQUEST;
2611   ssl->renegotiation = SSL_INITIAL_HANDSHAKE;
2612   ssl->secure_renegotiation = SSL_LEGACY_RENEGOTIATION;
2613
2614   ssl->verify_data_len = 0;
2615   memset (ssl->own_verify_data, 0, 36);
2616   memset (ssl->peer_verify_data, 0, 36);
2617
2618   ssl->in_offt = NULL;
2619
2620   ssl->in_msg = ssl->in_ctr + 13;
2621   ssl->in_msgtype = 0;
2622   ssl->in_msglen = 0;
2623   ssl->in_left = 0;
2624
2625   ssl->in_hslen = 0;
2626   ssl->nb_zero = 0;
2627   ssl->record_read = 0;
2628
2629   ssl->out_msg = ssl->out_ctr + 13;
2630   ssl->out_msgtype = 0;
2631   ssl->out_msglen = 0;
2632   ssl->out_left = 0;
2633
2634   ssl->transform_in = NULL;
2635   ssl->transform_out = NULL;
2636
2637   ssl->renego_records_seen = 0;
2638
2639   memset (ssl->out_ctr, 0, SSL_BUFFER_LEN);
2640   memset (ssl->in_ctr, 0, SSL_BUFFER_LEN);
2641
2642   if (ssl->transform)
2643     {
2644       ssl_transform_free (ssl->transform);
2645       polarssl_free (ssl->transform);
2646       ssl->transform = NULL;
2647     }
2648
2649   if (ssl->session)
2650     {
2651       ssl_session_free (ssl->session);
2652       polarssl_free (ssl->session);
2653       ssl->session = NULL;
2654     }
2655
2656   ssl->alpn_chosen = NULL;
2657
2658   if ((ret = ssl_handshake_init (ssl)) != 0)
2659     return (ret);
2660
2661   return (0);
2662 }
2663
2664
2665 static void
2666 ssl_ticket_keys_free (ssl_ticket_keys * tkeys)
2667 {
2668   aes_free (&tkeys->enc);
2669   aes_free (&tkeys->dec);
2670
2671   polarssl_zeroize (tkeys, sizeof (ssl_ticket_keys));
2672 }
2673
2674
2675 /*
2676  * Allocate and initialize ticket keys
2677  */
2678 static int
2679 ssl_ticket_keys_init (ssl_context * ssl)
2680 {
2681   int ret;
2682   ssl_ticket_keys *tkeys;
2683   unsigned char buf[16];
2684
2685   if (ssl->ticket_keys != NULL)
2686     return (0);
2687
2688   tkeys = (ssl_ticket_keys *) polarssl_malloc (sizeof (ssl_ticket_keys));
2689   if (tkeys == NULL)
2690     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2691
2692   aes_init (&tkeys->enc);
2693   aes_init (&tkeys->dec);
2694
2695   if ((ret = ssl->f_rng (ssl->p_rng, tkeys->key_name, 16)) != 0)
2696     {
2697       ssl_ticket_keys_free (tkeys);
2698       polarssl_free (tkeys);
2699       return (ret);
2700     }
2701
2702   if ((ret = ssl->f_rng (ssl->p_rng, buf, 16)) != 0 ||
2703       (ret = aes_setkey_enc (&tkeys->enc, buf, 128)) != 0 ||
2704       (ret = aes_setkey_dec (&tkeys->dec, buf, 128)) != 0)
2705     {
2706       ssl_ticket_keys_free (tkeys);
2707       polarssl_free (tkeys);
2708       return (ret);
2709     }
2710
2711   if ((ret = ssl->f_rng (ssl->p_rng, tkeys->mac_key, 16)) != 0)
2712     {
2713       ssl_ticket_keys_free (tkeys);
2714       polarssl_free (tkeys);
2715       return (ret);
2716     }
2717
2718   ssl->ticket_keys = tkeys;
2719
2720   return (0);
2721 }
2722
2723
2724 /*
2725  * SSL set accessors
2726  */
2727 void
2728 ssl_set_endpoint (ssl_context * ssl, int endpoint)
2729 {
2730   ssl->endpoint = endpoint;
2731
2732   if (endpoint == SSL_IS_CLIENT)
2733     ssl->session_tickets = SSL_SESSION_TICKETS_ENABLED;
2734 }
2735
2736
2737 void
2738 ssl_set_authmode (ssl_context * ssl, int authmode)
2739 {
2740   ssl->authmode = authmode;
2741 }
2742
2743 #if defined(POLARSSL_X509_CRT_PARSE_C)
2744 void
2745 ssl_set_verify (ssl_context * ssl,
2746                 int (*f_vrfy) (void *, x509_crt *, int, int *), void *p_vrfy)
2747 {
2748   ssl->f_vrfy = f_vrfy;
2749   ssl->p_vrfy = p_vrfy;
2750 }
2751 #endif /* POLARSSL_X509_CRT_PARSE_C */
2752
2753 void
2754 ssl_set_rng (ssl_context * ssl,
2755              int (*f_rng) (void *, unsigned char *, size_t), void *p_rng)
2756 {
2757   ssl->f_rng = f_rng;
2758   ssl->p_rng = p_rng;
2759 }
2760
2761 void
2762 ssl_set_dbg (ssl_context * ssl,
2763              void (*f_dbg) (void *, int, const char *), void *p_dbg)
2764 {
2765   ssl->f_dbg = f_dbg;
2766   ssl->p_dbg = p_dbg;
2767 }
2768
2769 void
2770 ssl_set_bio (ssl_context * ssl,
2771              int (*f_recv) (void *, unsigned char *, size_t), void *p_recv,
2772              int (*f_send) (void *, const unsigned char *, size_t),
2773              void *p_send)
2774 {
2775   ssl->f_recv = f_recv;
2776   ssl->f_send = f_send;
2777   ssl->p_recv = p_recv;
2778   ssl->p_send = p_send;
2779 }
2780
2781 void
2782 ssl_set_session_cache (ssl_context * ssl,
2783                        int (*f_get_cache) (void *, ssl_session *),
2784                        void *p_get_cache, int (*f_set_cache) (void *,
2785                                                               const
2786                                                               ssl_session *),
2787                        void *p_set_cache)
2788 {
2789   ssl->f_get_cache = f_get_cache;
2790   ssl->p_get_cache = p_get_cache;
2791   ssl->f_set_cache = f_set_cache;
2792   ssl->p_set_cache = p_set_cache;
2793 }
2794
2795 int
2796 ssl_set_session (ssl_context * ssl, const ssl_session * session)
2797 {
2798   int ret;
2799
2800   if (ssl == NULL ||
2801       session == NULL ||
2802       ssl->session_negotiate == NULL || ssl->endpoint != SSL_IS_CLIENT)
2803     {
2804       return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
2805     }
2806
2807   if ((ret = ssl_session_copy (ssl->session_negotiate, session)) != 0)
2808     return (ret);
2809
2810   ssl->handshake->resume = 1;
2811
2812   return (0);
2813 }
2814
2815 void
2816 ssl_set_ciphersuites (ssl_context * ssl, const int *ciphersuites)
2817 {
2818   ssl->ciphersuite_list[SSL_MINOR_VERSION_0] = ciphersuites;
2819   ssl->ciphersuite_list[SSL_MINOR_VERSION_1] = ciphersuites;
2820   ssl->ciphersuite_list[SSL_MINOR_VERSION_2] = ciphersuites;
2821   ssl->ciphersuite_list[SSL_MINOR_VERSION_3] = ciphersuites;
2822 }
2823
2824 void
2825 ssl_set_ciphersuites_for_version (ssl_context * ssl,
2826                                   const int *ciphersuites,
2827                                   int major, int minor)
2828 {
2829   if (major != SSL_MAJOR_VERSION_3)
2830     return;
2831
2832   if (minor < SSL_MINOR_VERSION_0 || minor > SSL_MINOR_VERSION_3)
2833     return;
2834
2835   ssl->ciphersuite_list[minor] = ciphersuites;
2836 }
2837
2838
2839 /* Add a new (empty) key_cert entry an return a pointer to it */
2840 static ssl_key_cert *
2841 ssl_add_key_cert (ssl_context * ssl)
2842 {
2843   ssl_key_cert *key_cert, *last;
2844
2845   key_cert = (ssl_key_cert *) polarssl_malloc (sizeof (ssl_key_cert));
2846   if (key_cert == NULL)
2847     return (NULL);
2848
2849   memset (key_cert, 0, sizeof (ssl_key_cert));
2850
2851   /* Append the new key_cert to the (possibly empty) current list */
2852   if (ssl->key_cert == NULL)
2853     {
2854       ssl->key_cert = key_cert;
2855       if (ssl->handshake != NULL)
2856         ssl->handshake->key_cert = key_cert;
2857     }
2858   else
2859     {
2860       last = ssl->key_cert;
2861       while (last->next != NULL)
2862         last = last->next;
2863       last->next = key_cert;
2864     }
2865
2866   return (key_cert);
2867 }
2868
2869
2870 void
2871 ssl_set_ca_chain (ssl_context * ssl, x509_crt * ca_chain,
2872                   x509_crl * ca_crl, const char *peer_cn)
2873 {
2874   ssl->ca_chain = ca_chain;
2875   ssl->ca_crl = ca_crl;
2876   ssl->peer_cn = peer_cn;
2877 }
2878
2879
2880 int
2881 ssl_set_own_cert (ssl_context * ssl, x509_crt * own_cert, pk_context * pk_key)
2882 {
2883   ssl_key_cert *key_cert = ssl_add_key_cert (ssl);
2884
2885   if (key_cert == NULL)
2886     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2887
2888   key_cert->cert = own_cert;
2889   key_cert->key = pk_key;
2890
2891   return (0);
2892 }
2893
2894
2895 int
2896 ssl_set_own_cert_rsa (ssl_context * ssl, x509_crt * own_cert,
2897                       rsa_context * rsa_key)
2898 {
2899   int ret;
2900   ssl_key_cert *key_cert = ssl_add_key_cert (ssl);
2901
2902   if (key_cert == NULL)
2903     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2904
2905   key_cert->key = (pk_context *) polarssl_malloc (sizeof (pk_context));
2906   if (key_cert->key == NULL)
2907     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2908
2909   pk_init (key_cert->key);
2910
2911   ret = pk_init_ctx (key_cert->key, pk_info_from_type (POLARSSL_PK_RSA));
2912   if (ret != 0)
2913     return (ret);
2914
2915   if ((ret = rsa_copy (pk_rsa (*key_cert->key), rsa_key)) != 0)
2916     return (ret);
2917
2918   key_cert->cert = own_cert;
2919   key_cert->key_own_alloc = 1;
2920
2921   return (0);
2922 }
2923
2924
2925 int
2926 ssl_set_own_cert_alt (ssl_context * ssl, x509_crt * own_cert,
2927                       void *rsa_key,
2928                       rsa_decrypt_func rsa_decrypt,
2929                       rsa_sign_func rsa_sign, rsa_key_len_func rsa_key_len)
2930 {
2931   int ret;
2932   ssl_key_cert *key_cert = ssl_add_key_cert (ssl);
2933
2934   if (key_cert == NULL)
2935     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2936
2937   key_cert->key = (pk_context *) polarssl_malloc (sizeof (pk_context));
2938   if (key_cert->key == NULL)
2939     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2940
2941   pk_init (key_cert->key);
2942
2943   if ((ret = pk_init_ctx_rsa_alt (key_cert->key, rsa_key,
2944                                   rsa_decrypt, rsa_sign, rsa_key_len)) != 0)
2945     return (ret);
2946
2947   key_cert->cert = own_cert;
2948   key_cert->key_own_alloc = 1;
2949
2950   return (0);
2951 }
2952
2953
2954 int
2955 ssl_set_psk (ssl_context * ssl, const unsigned char *psk, size_t psk_len,
2956              const unsigned char *psk_identity, size_t psk_identity_len)
2957 {
2958   if (psk == NULL || psk_identity == NULL)
2959     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
2960
2961   if (psk_len > POLARSSL_PSK_MAX_LEN)
2962     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
2963
2964   if (ssl->psk != NULL)
2965     {
2966       polarssl_free (ssl->psk);
2967       polarssl_free (ssl->psk_identity);
2968     }
2969
2970   ssl->psk_len = psk_len;
2971   ssl->psk_identity_len = psk_identity_len;
2972
2973   ssl->psk = (unsigned char *) polarssl_malloc (ssl->psk_len);
2974   ssl->psk_identity = (unsigned char *)
2975     polarssl_malloc (ssl->psk_identity_len);
2976
2977   if (ssl->psk == NULL || ssl->psk_identity == NULL)
2978     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
2979
2980   memcpy (ssl->psk, psk, ssl->psk_len);
2981   memcpy (ssl->psk_identity, psk_identity, ssl->psk_identity_len);
2982
2983   return (0);
2984 }
2985
2986 void
2987 ssl_set_psk_cb (ssl_context * ssl,
2988                 int (*f_psk) (void *, ssl_context *, const unsigned char *,
2989                               size_t), void *p_psk)
2990 {
2991   ssl->f_psk = f_psk;
2992   ssl->p_psk = p_psk;
2993 }
2994
2995
2996 int
2997 ssl_set_dh_param (ssl_context * ssl, const char *dhm_P, const char *dhm_G)
2998 {
2999   int ret;
3000
3001   if ((ret = mpi_read_string (&ssl->dhm_P, 16, dhm_P)) != 0)
3002     {
3003       SSL_DEBUG_RET (1, "mpi_read_string", ret);
3004       return (ret);
3005     }
3006
3007   if ((ret = mpi_read_string (&ssl->dhm_G, 16, dhm_G)) != 0)
3008     {
3009       SSL_DEBUG_RET (1, "mpi_read_string", ret);
3010       return (ret);
3011     }
3012
3013   return (0);
3014 }
3015
3016 int
3017 ssl_set_dh_param_ctx (ssl_context * ssl, dhm_context * dhm_ctx)
3018 {
3019   int ret;
3020
3021   if ((ret = mpi_copy (&ssl->dhm_P, &dhm_ctx->P)) != 0)
3022     {
3023       SSL_DEBUG_RET (1, "mpi_copy", ret);
3024       return (ret);
3025     }
3026
3027   if ((ret = mpi_copy (&ssl->dhm_G, &dhm_ctx->G)) != 0)
3028     {
3029       SSL_DEBUG_RET (1, "mpi_copy", ret);
3030       return (ret);
3031     }
3032
3033   return (0);
3034 }
3035
3036
3037 /*
3038  * Set the allowed elliptic curves
3039  */
3040 void
3041 ssl_set_curves (ssl_context * ssl, const ecp_group_id * curve_list)
3042 {
3043   ssl->curve_list = curve_list;
3044 }
3045
3046
3047 int
3048 ssl_set_hostname (ssl_context * ssl, const char *hostname)
3049 {
3050   if (hostname == NULL)
3051     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3052
3053   ssl->hostname_len = strlen (hostname);
3054
3055   if (ssl->hostname_len + 1 == 0)
3056     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3057
3058   ssl->hostname = (unsigned char *) polarssl_malloc (ssl->hostname_len + 1);
3059
3060   if (ssl->hostname == NULL)
3061     return (POLARSSL_ERR_SSL_MALLOC_FAILED);
3062
3063   memcpy (ssl->hostname, (const unsigned char *) hostname, ssl->hostname_len);
3064
3065   ssl->hostname[ssl->hostname_len] = '\0';
3066
3067   return (0);
3068 }
3069
3070
3071 void
3072 ssl_set_sni (ssl_context * ssl,
3073              int (*f_sni) (void *, ssl_context *,
3074                            const unsigned char *, size_t), void *p_sni)
3075 {
3076   ssl->f_sni = f_sni;
3077   ssl->p_sni = p_sni;
3078 }
3079
3080
3081 int
3082 ssl_set_alpn_protocols (ssl_context * ssl, const char **protos)
3083 {
3084   size_t cur_len, tot_len;
3085   const char **p;
3086
3087   /*
3088    * "Empty strings MUST NOT be included and byte strings MUST NOT be
3089    * truncated". Check lengths now rather than later.
3090    */
3091   tot_len = 0;
3092   for (p = protos; *p != NULL; p++)
3093     {
3094       cur_len = strlen (*p);
3095       tot_len += cur_len;
3096
3097       if (cur_len == 0 || cur_len > 255 || tot_len > 65535)
3098         return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3099     }
3100
3101   ssl->alpn_list = protos;
3102
3103   return (0);
3104 }
3105
3106 const char *
3107 ssl_get_alpn_protocol (const ssl_context * ssl)
3108 {
3109   return (ssl->alpn_chosen);
3110 }
3111
3112
3113 void
3114 ssl_set_max_version (ssl_context * ssl, int major, int minor)
3115 {
3116   if (major >= SSL_MIN_MAJOR_VERSION && major <= SSL_MAX_MAJOR_VERSION &&
3117       minor >= SSL_MIN_MINOR_VERSION && minor <= SSL_MAX_MINOR_VERSION)
3118     {
3119       ssl->max_major_ver = major;
3120       ssl->max_minor_ver = minor;
3121     }
3122 }
3123
3124
3125 void
3126 ssl_set_min_version (ssl_context * ssl, int major, int minor)
3127 {
3128   if (major >= SSL_MIN_MAJOR_VERSION && major <= SSL_MAX_MAJOR_VERSION &&
3129       minor >= SSL_MIN_MINOR_VERSION && minor <= SSL_MAX_MINOR_VERSION)
3130     {
3131       ssl->min_major_ver = major;
3132       ssl->min_minor_ver = minor;
3133     }
3134 }
3135
3136
3137 int
3138 ssl_set_max_frag_len (ssl_context * ssl, unsigned char mfl_code)
3139 {
3140   if (mfl_code >= SSL_MAX_FRAG_LEN_INVALID ||
3141       mfl_code_to_length[mfl_code] > SSL_MAX_CONTENT_LEN)
3142     {
3143       return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3144     }
3145
3146   ssl->mfl_code = mfl_code;
3147
3148   return (0);
3149 }
3150
3151
3152 int
3153 ssl_set_truncated_hmac (ssl_context * ssl, int truncate)
3154 {
3155   if (ssl->endpoint != SSL_IS_CLIENT)
3156     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3157
3158   ssl->trunc_hmac = truncate;
3159
3160   return 0;
3161 }
3162
3163
3164 void
3165 ssl_set_renegotiation (ssl_context * ssl, int renegotiation)
3166 {
3167   ssl->disable_renegotiation = renegotiation;
3168 }
3169
3170 void
3171 ssl_legacy_renegotiation (ssl_context * ssl, int allow_legacy)
3172 {
3173   ssl->allow_legacy_renegotiation = allow_legacy;
3174 }
3175
3176 void
3177 ssl_set_renegotiation_enforced (ssl_context * ssl, int max_records)
3178 {
3179   ssl->renego_max_records = max_records;
3180 }
3181
3182
3183 int
3184 ssl_set_session_tickets (ssl_context * ssl, int use_tickets)
3185 {
3186   ssl->session_tickets = use_tickets;
3187
3188   if (ssl->endpoint == SSL_IS_CLIENT)
3189     return (0);
3190
3191   if (ssl->f_rng == NULL)
3192     return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3193
3194   return (ssl_ticket_keys_init (ssl));
3195 }
3196
3197 void
3198 ssl_set_session_ticket_lifetime (ssl_context * ssl, int lifetime)
3199 {
3200   ssl->ticket_lifetime = lifetime;
3201 }
3202
3203
3204 /*
3205  * SSL get accessors
3206  */
3207 size_t
3208 ssl_get_bytes_avail (const ssl_context * ssl)
3209 {
3210   return (ssl->in_offt == NULL ? 0 : ssl->in_msglen);
3211 }
3212
3213 int
3214 ssl_get_verify_result (const ssl_context * ssl)
3215 {
3216   return (ssl->session->verify_result);
3217 }
3218
3219 const char *
3220 ssl_get_ciphersuite (const ssl_context * ssl)
3221 {
3222   if (ssl == NULL || ssl->session == NULL)
3223     return (NULL);
3224
3225   return ssl_get_ciphersuite_name (ssl->session->ciphersuite);
3226 }
3227
3228 const char *
3229 ssl_get_version (const ssl_context * ssl)
3230 {
3231   switch (ssl->minor_ver)
3232     {
3233     case SSL_MINOR_VERSION_0:
3234       return ("SSLv3.0");
3235
3236     case SSL_MINOR_VERSION_1:
3237       return ("TLSv1.0");
3238
3239     case SSL_MINOR_VERSION_2:
3240       return ("TLSv1.1");
3241
3242     case SSL_MINOR_VERSION_3:
3243       return ("TLSv1.2");
3244
3245     default:
3246       break;
3247     }
3248   return ("unknown");
3249 }
3250
3251
3252 const x509_crt *
3253 ssl_get_peer_cert (const ssl_context * ssl)
3254 {
3255   if (ssl == NULL || ssl->session == NULL)
3256     return (NULL);
3257
3258   return (ssl->session->peer_cert);
3259 }
3260
3261
3262 int
3263 ssl_get_session (const ssl_context * ssl, ssl_session * dst)
3264 {
3265   if (ssl == NULL ||
3266       dst == NULL || ssl->session == NULL || ssl->endpoint != SSL_IS_CLIENT)
3267     {
3268       return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3269     }
3270
3271   return (ssl_session_copy (dst, ssl->session));
3272 }
3273
3274
3275 /*
3276  * Perform a single step of the SSL handshake
3277  */
3278 int
3279 ssl_handshake_step (ssl_context * ssl)
3280 {
3281   int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
3282
3283   if (ssl->endpoint == SSL_IS_CLIENT)
3284     ret = ssl_handshake_client_step (ssl);
3285
3286   if (ssl->endpoint == SSL_IS_SERVER)
3287     ret = ssl_handshake_server_step (ssl);
3288
3289   return (ret);
3290 }
3291
3292
3293 /*
3294  * Perform the SSL handshake
3295  */
3296 int
3297 ssl_handshake (ssl_context * ssl)
3298 {
3299   int ret = 0;
3300
3301   SSL_DEBUG_MSG (2, ("=> handshake"));
3302
3303   while (ssl->state != SSL_HANDSHAKE_OVER)
3304     {
3305       ret = ssl_handshake_step (ssl);
3306
3307       if (ret != 0)
3308         break;
3309     }
3310
3311   SSL_DEBUG_MSG (2, ("<= handshake"));
3312
3313   return (ret);
3314 }
3315
3316
3317 /*
3318  * Write HelloRequest to request renegotiation on server
3319  */
3320 static int
3321 ssl_write_hello_request (ssl_context * ssl)
3322 {
3323   int ret;
3324
3325   SSL_DEBUG_MSG (2, ("=> write hello request"));
3326
3327   ssl->out_msglen = 4;
3328   ssl->out_msgtype = SSL_MSG_HANDSHAKE;
3329   ssl->out_msg[0] = SSL_HS_HELLO_REQUEST;
3330
3331   if ((ret = ssl_write_record (ssl)) != 0)
3332     {
3333       SSL_DEBUG_RET (1, "ssl_write_record", ret);
3334       return (ret);
3335     }
3336
3337   ssl->renegotiation = SSL_RENEGOTIATION_PENDING;
3338
3339   SSL_DEBUG_MSG (2, ("<= write hello request"));
3340
3341   return (0);
3342 }
3343
3344
3345 /*
3346  * Actually renegotiate current connection, triggered by either:
3347  * - calling ssl_renegotiate() on client,
3348  * - receiving a HelloRequest on client during ssl_read(),
3349  * - receiving any handshake message on server during ssl_read() after the
3350  *   initial handshake is completed
3351  * If the handshake doesn't complete due to waiting for I/O, it will continue
3352  * during the next calls to ssl_renegotiate() or ssl_read() respectively.
3353  */
3354 static int
3355 ssl_start_renegotiation (ssl_context * ssl)
3356 {
3357   int ret;
3358
3359   SSL_DEBUG_MSG (2, ("=> renegotiate"));
3360
3361   if ((ret = ssl_handshake_init (ssl)) != 0)
3362     return (ret);
3363
3364   ssl->state = SSL_HELLO_REQUEST;
3365   ssl->renegotiation = SSL_RENEGOTIATION;
3366
3367   if ((ret = ssl_handshake (ssl)) != 0)
3368     {
3369       SSL_DEBUG_RET (1, "ssl_handshake", ret);
3370       return (ret);
3371     }
3372
3373   SSL_DEBUG_MSG (2, ("<= renegotiate"));
3374
3375   return (0);
3376 }
3377
3378
3379 /*
3380  * Renegotiate current connection on client,
3381  * or request renegotiation on server
3382  */
3383 int
3384 ssl_renegotiate (ssl_context * ssl)
3385 {
3386   int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
3387
3388   /* On server, just send the request */
3389   if (ssl->endpoint == SSL_IS_SERVER)
3390     {
3391       if (ssl->state != SSL_HANDSHAKE_OVER)
3392         return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3393
3394       return (ssl_write_hello_request (ssl));
3395     }
3396
3397   /*
3398    * On client, either start the renegotiation process or,
3399    * if already in progress, continue the handshake
3400    */
3401   if (ssl->renegotiation != SSL_RENEGOTIATION)
3402     {
3403       if (ssl->state != SSL_HANDSHAKE_OVER)
3404         return (POLARSSL_ERR_SSL_BAD_INPUT_DATA);
3405
3406       if ((ret = ssl_start_renegotiation (ssl)) != 0)
3407         {
3408           SSL_DEBUG_RET (1, "ssl_start_renegotiation", ret);
3409           return (ret);
3410         }
3411     }
3412   else
3413     {
3414       if ((ret = ssl_handshake (ssl)) != 0)
3415         {
3416           SSL_DEBUG_RET (1, "ssl_handshake", ret);
3417           return (ret);
3418         }
3419     }
3420
3421   return (ret);
3422 }
3423
3424
3425 /*
3426  * Receive application data decrypted from the SSL layer
3427  */
3428 int
3429 ssl_read (ssl_context * ssl, unsigned char *buf, size_t len)
3430 {
3431   int ret;
3432   size_t n;
3433
3434   SSL_DEBUG_MSG (2, ("=> read"));
3435
3436   if (ssl->state != SSL_HANDSHAKE_OVER)
3437     {
3438       if ((ret = ssl_handshake (ssl)) != 0)
3439         {
3440           SSL_DEBUG_RET (1, "ssl_handshake", ret);
3441           return (ret);
3442         }
3443     }
3444
3445   if (ssl->in_offt == NULL)
3446     {
3447       if ((ret = ssl_read_record (ssl)) != 0)
3448         {
3449           if (ret == POLARSSL_ERR_SSL_CONN_EOF)
3450             return (0);
3451
3452           SSL_DEBUG_RET (1, "ssl_read_record", ret);
3453           return (ret);
3454         }
3455
3456       if (ssl->in_msglen == 0 && ssl->in_msgtype == SSL_MSG_APPLICATION_DATA)
3457         {
3458           /*
3459            * OpenSSL sends empty messages to randomize the IV
3460            */
3461           if ((ret = ssl_read_record (ssl)) != 0)
3462             {
3463               if (ret == POLARSSL_ERR_SSL_CONN_EOF)
3464                 return (0);
3465
3466               SSL_DEBUG_RET (1, "ssl_read_record", ret);
3467               return (ret);
3468             }
3469         }
3470
3471       if (ssl->in_msgtype == SSL_MSG_HANDSHAKE)
3472         {
3473           SSL_DEBUG_MSG (1, ("received handshake message"));
3474
3475           if (ssl->endpoint == SSL_IS_CLIENT &&
3476               (ssl->in_msg[0] != SSL_HS_HELLO_REQUEST || ssl->in_hslen != 4))
3477             {
3478               SSL_DEBUG_MSG (1, ("handshake received (not HelloRequest)"));
3479               return (POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE);
3480             }
3481
3482           if (ssl->disable_renegotiation == SSL_RENEGOTIATION_DISABLED ||
3483               (ssl->secure_renegotiation == SSL_LEGACY_RENEGOTIATION &&
3484                ssl->allow_legacy_renegotiation ==
3485                SSL_LEGACY_NO_RENEGOTIATION))
3486             {
3487               SSL_DEBUG_MSG (3, ("ignoring renegotiation, sending alert"));
3488
3489               if (ssl->minor_ver >= SSL_MINOR_VERSION_1)
3490                 {
3491                   if ((ret = ssl_send_alert_message (ssl,
3492                                                      SSL_ALERT_LEVEL_WARNING,
3493                                                      SSL_ALERT_MSG_NO_RENEGOTIATION))
3494                       != 0)
3495                     {
3496                       return (ret);
3497                     }
3498                 }
3499               else
3500                 {
3501                   SSL_DEBUG_MSG (1, ("should never happen"));
3502                   return (POLARSSL_ERR_SSL_INTERNAL_ERROR);
3503                 }
3504             }
3505           else
3506             {
3507               if ((ret = ssl_start_renegotiation (ssl)) != 0)
3508                 {
3509                   SSL_DEBUG_RET (1, "ssl_start_renegotiation", ret);
3510                   return (ret);
3511                 }
3512
3513               return (POLARSSL_ERR_NET_WANT_READ);
3514             }
3515         }
3516       else if (ssl->renegotiation == SSL_RENEGOTIATION_PENDING)
3517         {
3518           ssl->renego_records_seen++;
3519
3520           if (ssl->renego_max_records >= 0 &&
3521               ssl->renego_records_seen > ssl->renego_max_records)
3522             {
3523               SSL_DEBUG_MSG (1, ("renegotiation requested, "
3524                                  "but not honored by client"));
3525               return (POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE);
3526             }
3527         }
3528       else if (ssl->in_msgtype != SSL_MSG_APPLICATION_DATA)
3529         {
3530           SSL_DEBUG_MSG (1, ("bad application data message"));
3531           return (POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE);
3532         }
3533
3534       ssl->in_offt = ssl->in_msg;
3535     }
3536
3537   n = (len < ssl->in_msglen) ? len : ssl->in_msglen;
3538
3539   memcpy (buf, ssl->in_offt, n);
3540   ssl->in_msglen -= n;
3541
3542   if (ssl->in_msglen == 0)
3543     /* all bytes consumed  */
3544     ssl->in_offt = NULL;
3545   else
3546     /* more data available */
3547     ssl->in_offt += n;
3548
3549   SSL_DEBUG_MSG (2, ("<= read"));
3550
3551   return ((int) n);
3552 }
3553
3554
3555 /*
3556  * Send application data to be encrypted by the SSL layer
3557  */
3558 int
3559 ssl_write (ssl_context * ssl, const unsigned char *buf, size_t len)
3560 {
3561   int ret;
3562   size_t n;
3563   unsigned int max_len = SSL_MAX_CONTENT_LEN;
3564
3565   SSL_DEBUG_MSG (2, ("=> write"));
3566
3567   if (ssl->state != SSL_HANDSHAKE_OVER)
3568     {
3569       if ((ret = ssl_handshake (ssl)) != 0)
3570         {
3571           SSL_DEBUG_RET (1, "ssl_handshake", ret);
3572           return (ret);
3573         }
3574     }
3575
3576   /*
3577    * Assume mfl_code is correct since it was checked when set
3578    */
3579   max_len = mfl_code_to_length[ssl->mfl_code];
3580
3581   /*
3582    * Check if a smaller max length was negotiated
3583    */
3584   if (ssl->session_out != NULL &&
3585       mfl_code_to_length[ssl->session_out->mfl_code] < max_len)
3586     {
3587       max_len = mfl_code_to_length[ssl->session_out->mfl_code];
3588     }
3589
3590   n = (len < max_len) ? len : max_len;
3591
3592   if (ssl->out_left != 0)
3593     {
3594       if ((ret = ssl_flush_output (ssl)) != 0)
3595         {
3596           SSL_DEBUG_RET (1, "ssl_flush_output", ret);
3597           return (ret);
3598         }
3599     }
3600   else
3601     {
3602       ssl->out_msglen = n;
3603       ssl->out_msgtype = SSL_MSG_APPLICATION_DATA;
3604       memcpy (ssl->out_msg, buf, n);
3605
3606       if ((ret = ssl_write_record (ssl)) != 0)
3607         {
3608           SSL_DEBUG_RET (1, "ssl_write_record", ret);
3609           return (ret);
3610         }
3611     }
3612
3613   SSL_DEBUG_MSG (2, ("<= write"));
3614
3615   return ((int) n);
3616 }
3617
3618
3619 /*
3620  * Notify the peer that the connection is being closed
3621  */
3622 int
3623 ssl_close_notify (ssl_context * ssl)
3624 {
3625   int ret;
3626
3627   SSL_DEBUG_MSG (2, ("=> write close notify"));
3628
3629   if ((ret = ssl_flush_output (ssl)) != 0)
3630     {
3631       SSL_DEBUG_RET (1, "ssl_flush_output", ret);
3632       return (ret);
3633     }
3634
3635   if (ssl->state == SSL_HANDSHAKE_OVER)
3636     {
3637       if ((ret = ssl_send_alert_message (ssl,
3638                                          SSL_ALERT_LEVEL_WARNING,
3639                                          SSL_ALERT_MSG_CLOSE_NOTIFY)) != 0)
3640         {
3641           return (ret);
3642         }
3643     }
3644
3645   SSL_DEBUG_MSG (2, ("<= write close notify"));
3646
3647   return (ret);
3648 }
3649
3650
3651 void
3652 ssl_transform_free (ssl_transform * transform)
3653 {
3654   if (transform == NULL)
3655     return;
3656
3657   deflateEnd (&transform->ctx_deflate);
3658   inflateEnd (&transform->ctx_inflate);
3659
3660   cipher_free (&transform->cipher_ctx_enc);
3661   cipher_free (&transform->cipher_ctx_dec);
3662
3663   md_free (&transform->md_ctx_enc);
3664   md_free (&transform->md_ctx_dec);
3665
3666   polarssl_zeroize (transform, sizeof (ssl_transform));
3667 }
3668
3669
3670 static void
3671 ssl_key_cert_free (ssl_key_cert * key_cert)
3672 {
3673   ssl_key_cert *cur = key_cert, *next;
3674
3675   while (cur != NULL)
3676     {
3677       next = cur->next;
3678
3679       if (cur->key_own_alloc)
3680         {
3681           pk_free (cur->key);
3682           polarssl_free (cur->key);
3683         }
3684       polarssl_free (cur);
3685
3686       cur = next;
3687     }
3688 }
3689
3690
3691 void
3692 ssl_handshake_free (ssl_handshake_params * handshake)
3693 {
3694   if (handshake == NULL)
3695     return;
3696
3697   dhm_free (&handshake->dhm_ctx);
3698   ecdh_free (&handshake->ecdh_ctx);
3699
3700   /* explicit void pointer cast for buggy MS compiler */
3701   polarssl_free ((void *) handshake->curves);
3702
3703   /*
3704    * Free only the linked list wrapper, not the keys themselves
3705    * since the belong to the SNI callback
3706    */
3707   if (handshake->sni_key_cert != NULL)
3708     {
3709       ssl_key_cert *cur = handshake->sni_key_cert, *next;
3710
3711       while (cur != NULL)
3712         {
3713           next = cur->next;
3714           polarssl_free (cur);
3715           cur = next;
3716         }
3717     }
3718
3719   polarssl_zeroize (handshake, sizeof (ssl_handshake_params));
3720 }
3721
3722 void
3723 ssl_session_free (ssl_session * session)
3724 {
3725   if (session == NULL)
3726     return;
3727
3728   if (session->peer_cert != NULL)
3729     {
3730       x509_crt_free (session->peer_cert);
3731       polarssl_free (session->peer_cert);
3732     }
3733
3734   polarssl_free (session->ticket);
3735   polarssl_zeroize (session, sizeof (ssl_session));
3736 }
3737
3738
3739 /*
3740  * Free an SSL context
3741  */
3742 void
3743 ssl_free (ssl_context * ssl)
3744 {
3745   if (ssl == NULL)
3746     return;
3747
3748   SSL_DEBUG_MSG (2, ("=> free"));
3749
3750   if (ssl->out_ctr != NULL)
3751     {
3752       polarssl_zeroize (ssl->out_ctr, SSL_BUFFER_LEN);
3753       polarssl_free (ssl->out_ctr);
3754     }
3755
3756   if (ssl->in_ctr != NULL)
3757     {
3758       polarssl_zeroize (ssl->in_ctr, SSL_BUFFER_LEN);
3759       polarssl_free (ssl->in_ctr);
3760     }
3761
3762   if (ssl->compress_buf != NULL)
3763     {
3764       polarssl_zeroize (ssl->compress_buf, SSL_BUFFER_LEN);
3765       polarssl_free (ssl->compress_buf);
3766     }
3767
3768   mpi_free (&ssl->dhm_P);
3769   mpi_free (&ssl->dhm_G);
3770
3771   if (ssl->transform)
3772     {
3773       ssl_transform_free (ssl->transform);
3774       polarssl_free (ssl->transform);
3775     }
3776
3777   if (ssl->handshake)
3778     {