Fix missing 64bit carry handling in AES-NI CTR mode
authorJussi Kivilinna <jussi.kivilinna@mbnet.fi>
Fri, 23 Nov 2012 17:22:30 +0000 (19:22 +0200)
committerWerner Koch <wk@gnupg.org>
Mon, 26 Nov 2012 08:21:44 +0000 (09:21 +0100)
* cipher/rijndael.c [USE_AESNI] (do_aesni_ctr, do_aesni_ctr_4): Add
carry handling to 64-bit addition.
(selftest_ctr_128): New function for testing IV handling in bulk CTR
function.
(selftest): Add call to selftest_ctr_128.
--

Carry handling checks if lower 64-bit part of SSE register was overflowed and
if it was, increment upper parts since that point. Also add selftests to verify
correct operation.

Signed-off-by: Jussi Kivilinna <jussi.kivilinna@mbnet.fi>
cipher/rijndael.c

index 34a0f8c..860dcf8 100644 (file)
@@ -1011,16 +1011,33 @@ do_aesni_ctr (const RIJNDAEL_context *ctx,
   static unsigned char be_mask[16] __attribute__ ((aligned (16))) =
     { 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 };
 
-  asm volatile ("movdqa %[ctr], %%xmm0\n\t"     /* xmm0, xmm2 := CTR   */
+  asm volatile ("movdqa (%[ctr]), %%xmm0\n\t"   /* xmm0, xmm2 := CTR   */
                 "movaps %%xmm0, %%xmm2\n\t"
                 "mov    $1, %%esi\n\t"          /* xmm2++ (big-endian) */
                 "movd   %%esi, %%xmm1\n\t"
+
+                "movl   12(%[ctr]), %%esi\n\t"  /* load lower parts of CTR */
+                "bswapl %%esi\n\t"
+                "movl   8(%[ctr]), %%edi\n\t"
+                "bswapl %%edi\n\t"
+
                 "pshufb %[mask], %%xmm2\n\t"
                 "paddq  %%xmm1, %%xmm2\n\t"
+
+                "addl   $1, %%esi\n\t"
+                "adcl   $0, %%edi\n\t"          /* detect 64bit overflow */
+                "jnc    .Lno_carry%=\n\t"
+
+                /* swap upper and lower halfs */
+                "pshufd $0x4e, %%xmm1, %%xmm1\n\t"
+                "paddq   %%xmm1, %%xmm2\n\t"   /* add carry to upper 64bits */
+
+                ".Lno_carry%=:\n\t"
+
                 "pshufb %[mask], %%xmm2\n\t"
-                "movdqa %%xmm2, %[ctr]\n"       /* Update CTR.         */
+                "movdqa %%xmm2, (%[ctr])\n"     /* Update CTR.         */
 
-                "movdqa (%[key]), %%xmm1\n\t"    /* xmm1 := key[0]    */
+                "movdqa (%[key]), %%xmm1\n\t"   /* xmm1 := key[0]    */
                 "pxor   %%xmm1, %%xmm0\n\t"     /* xmm0 ^= key[0]    */
                 "movdqa 0x10(%[key]), %%xmm1\n\t"
                 aesenc_xmm1_xmm0
@@ -1060,12 +1077,13 @@ do_aesni_ctr (const RIJNDAEL_context *ctx,
                 "pxor %%xmm1, %%xmm0\n\t"        /* EncCTR ^= input  */
                 "movdqu %%xmm0, %[dst]"          /* Store EncCTR.    */
 
-                : [ctr] "+m" (*ctr), [dst] "=m" (*b)
+                : [dst] "=m" (*b)
                 : [src] "m" (*a),
+                  [ctr] "r" (ctr),
                   [key] "r" (ctx->keyschenc),
                   [rounds] "g" (ctx->rounds),
                   [mask] "m" (*be_mask)
-                : "%esi", "cc", "memory");
+                : "%esi", "%edi", "cc", "memory");
 #undef aesenc_xmm1_xmm0
 #undef aesenclast_xmm1_xmm0
 }
@@ -1098,10 +1116,16 @@ do_aesni_ctr_4 (const RIJNDAEL_context *ctx,
       xmm5  temp
    */
 
-  asm volatile ("movdqa %[ctr], %%xmm0\n\t"     /* xmm0, xmm2 := CTR   */
+  asm volatile ("movdqa (%[ctr]), %%xmm0\n\t"   /* xmm0, xmm2 := CTR   */
                 "movaps %%xmm0, %%xmm2\n\t"
                 "mov    $1, %%esi\n\t"          /* xmm1 := 1 */
                 "movd   %%esi, %%xmm1\n\t"
+
+                "movl   12(%[ctr]), %%esi\n\t"  /* load lower parts of CTR */
+                "bswapl %%esi\n\t"
+                "movl   8(%[ctr]), %%edi\n\t"
+                "bswapl %%edi\n\t"
+
                 "pshufb %[mask], %%xmm2\n\t"    /* xmm2 := le(xmm2) */
                 "paddq  %%xmm1, %%xmm2\n\t"     /* xmm2++           */
                 "movaps %%xmm2, %%xmm3\n\t"     /* xmm3 := xmm2     */
@@ -1110,11 +1134,39 @@ do_aesni_ctr_4 (const RIJNDAEL_context *ctx,
                 "paddq  %%xmm1, %%xmm4\n\t"     /* xmm4++           */
                 "movaps %%xmm4, %%xmm5\n\t"     /* xmm5 := xmm4     */
                 "paddq  %%xmm1, %%xmm5\n\t"     /* xmm5++           */
+
+                /* swap upper and lower halfs */
+                "pshufd $0x4e, %%xmm1, %%xmm1\n\t"
+
+                "addl   $1, %%esi\n\t"
+                "adcl   $0, %%edi\n\t"          /* detect 64bit overflow */
+                "jc     .Lcarry_xmm2%=\n\t"
+                "addl   $1, %%esi\n\t"
+                "adcl   $0, %%edi\n\t"          /* detect 64bit overflow */
+                "jc     .Lcarry_xmm3%=\n\t"
+                "addl   $1, %%esi\n\t"
+                "adcl   $0, %%edi\n\t"          /* detect 64bit overflow */
+                "jc     .Lcarry_xmm4%=\n\t"
+                "addl   $1, %%esi\n\t"
+                "adcl   $0, %%edi\n\t"          /* detect 64bit overflow */
+                "jc     .Lcarry_xmm5%=\n\t"
+                "jmp    .Lno_carry%=\n\t"
+
+                ".Lcarry_xmm2%=:\n\t"
+                "paddq   %%xmm1, %%xmm2\n\t"
+                ".Lcarry_xmm3%=:\n\t"
+                "paddq   %%xmm1, %%xmm3\n\t"
+                ".Lcarry_xmm4%=:\n\t"
+                "paddq   %%xmm1, %%xmm4\n\t"
+                ".Lcarry_xmm5%=:\n\t"
+                "paddq   %%xmm1, %%xmm5\n\t"
+
+                ".Lno_carry%=:\n\t"
                 "pshufb %[mask], %%xmm2\n\t"    /* xmm2 := be(xmm2) */
                 "pshufb %[mask], %%xmm3\n\t"    /* xmm3 := be(xmm3) */
                 "pshufb %[mask], %%xmm4\n\t"    /* xmm4 := be(xmm4) */
                 "pshufb %[mask], %%xmm5\n\t"    /* xmm5 := be(xmm5) */
-                "movdqa %%xmm5, %[ctr]\n"       /* Update CTR.      */
+                "movdqa %%xmm5, (%[ctr])\n"     /* Update CTR.      */
 
                 "movdqa (%[key]), %%xmm1\n\t"    /* xmm1 := key[0]    */
                 "pxor   %%xmm1, %%xmm0\n\t"     /* xmm0 ^= key[0]    */
@@ -1198,28 +1250,30 @@ do_aesni_ctr_4 (const RIJNDAEL_context *ctx,
                 aesenclast_xmm1_xmm3
                 aesenclast_xmm1_xmm4
 
-                "movdqu %[src], %%xmm1\n\t"      /* Get block 1.      */
+                "movdqu (%[src]), %%xmm1\n\t"    /* Get block 1.      */
                 "pxor %%xmm1, %%xmm0\n\t"        /* EncCTR-1 ^= input */
-                "movdqu %%xmm0, %[dst]\n\t"      /* Store block 1     */
+                "movdqu %%xmm0, (%[dst])\n\t"    /* Store block 1     */
 
-                "movdqu (16)%[src], %%xmm1\n\t"  /* Get block 2.      */
+                "movdqu 16(%[src]), %%xmm1\n\t"  /* Get block 2.      */
                 "pxor %%xmm1, %%xmm2\n\t"        /* EncCTR-2 ^= input */
-                "movdqu %%xmm2, (16)%[dst]\n\t"  /* Store block 2.    */
+                "movdqu %%xmm2, 16(%[dst])\n\t"  /* Store block 2.    */
 
-                "movdqu (32)%[src], %%xmm1\n\t"  /* Get block 3.      */
+                "movdqu 32(%[src]), %%xmm1\n\t"  /* Get block 3.      */
                 "pxor %%xmm1, %%xmm3\n\t"        /* EncCTR-3 ^= input */
-                "movdqu %%xmm3, (32)%[dst]\n\t"  /* Store block 3.    */
+                "movdqu %%xmm3, 32(%[dst])\n\t"  /* Store block 3.    */
 
-                "movdqu (48)%[src], %%xmm1\n\t"  /* Get block 4.      */
+                "movdqu 48(%[src]), %%xmm1\n\t"  /* Get block 4.      */
                 "pxor %%xmm1, %%xmm4\n\t"        /* EncCTR-4 ^= input */
-                "movdqu %%xmm4, (48)%[dst]"      /* Store block 4.   */
+                "movdqu %%xmm4, 48(%[dst])"      /* Store block 4.   */
 
-                : [ctr] "+m" (*ctr), [dst] "=m" (*b)
-                : [src] "m" (*a),
+                :
+                : [ctr] "r" (ctr),
+                  [src] "r" (a),
+                  [dst] "r" (b),
                   [key] "r" (ctx->keyschenc),
                   [rounds] "g" (ctx->rounds),
                   [mask] "m" (*be_mask)
-                : "%esi", "cc", "memory");
+                : "%esi", "%edi", "cc", "memory");
 #undef aesenc_xmm1_xmm0
 #undef aesenc_xmm1_xmm2
 #undef aesenc_xmm1_xmm3
@@ -1970,6 +2024,102 @@ selftest_basic_256 (void)
   return NULL;
 }
 
+
+/* Run the self-tests for AES-CTR-128, tests IV increment of bulk CTR
+   encryption.  Returns NULL on success. */
+static const char*
+selftest_ctr_128 (void)
+{
+  RIJNDAEL_context ctx ATTR_ALIGNED_16;
+  unsigned char plaintext[7*16] ATTR_ALIGNED_16;
+  unsigned char ciphertext[7*16] ATTR_ALIGNED_16;
+  unsigned char plaintext2[7*16] ATTR_ALIGNED_16;
+  unsigned char iv[16] ATTR_ALIGNED_16;
+  unsigned char iv2[16] ATTR_ALIGNED_16;
+  int i, j, diff;
+
+  static const unsigned char key[16] ATTR_ALIGNED_16 = {
+      0x06,0x9A,0x00,0x7F,0xC7,0x6A,0x45,0x9F,
+      0x98,0xBA,0xF9,0x17,0xFE,0xDF,0x95,0x21
+    };
+  static char error_str[128];
+
+  rijndael_setkey (&ctx, key, sizeof (key));
+
+  /* Test single block code path */
+  memset(iv, 0xff, sizeof(iv));
+  for (i = 0; i < 16; i++)
+    plaintext[i] = i;
+
+  /* CTR manually.  */
+  rijndael_encrypt (&ctx, ciphertext, iv);
+  for (i = 0; i < 16; i++)
+    ciphertext[i] ^= plaintext[i];
+  for (i = 16; i > 0; i--)
+    {
+      iv[i-1]++;
+      if (iv[i-1])
+        break;
+    }
+
+  memset(iv2, 0xff, sizeof(iv2));
+  _gcry_aes_ctr_enc (&ctx, iv2, plaintext2, ciphertext, 1);
+
+  if (memcmp(plaintext2, plaintext, 16))
+    return "AES-128-CTR test failed (plaintext mismatch)";
+
+  if (memcmp(iv2, iv, 16))
+    return "AES-128-CTR test failed (IV mismatch)";
+
+  /* Test parallelized code paths */
+  for (diff = 0; diff < 7; diff++) {
+    memset(iv, 0xff, sizeof(iv));
+    iv[15] -= diff;
+
+    for (i = 0; i < sizeof(plaintext); i++)
+      plaintext[i] = i;
+
+    /* Create CTR ciphertext manually.  */
+    for (i = 0; i < sizeof(plaintext); i+=16)
+      {
+        rijndael_encrypt (&ctx, &ciphertext[i], iv);
+        for (j = 0; j < 16; j++)
+          ciphertext[i+j] ^= plaintext[i+j];
+        for (j = 16; j > 0; j--)
+          {
+            iv[j-1]++;
+            if (iv[j-1])
+              break;
+          }
+      }
+
+    /* Decrypt using bulk CTR and compare result.  */
+    memset(iv2, 0xff, sizeof(iv2));
+    iv2[15] -= diff;
+
+    _gcry_aes_ctr_enc (&ctx, iv2, plaintext2, ciphertext,
+                       sizeof(ciphertext) / BLOCKSIZE);
+
+    if (memcmp(plaintext2, plaintext, sizeof(plaintext)))
+      {
+        snprintf(error_str, sizeof(error_str),
+                 "AES-128-CTR test failed (plaintext mismatch, diff: %d)",
+                 diff);
+        return error_str;
+      }
+    if (memcmp(iv2, iv, sizeof(iv)))
+      {
+        snprintf(error_str, sizeof(error_str),
+                 "AES-128-CTR test failed (IV mismatch, diff: %d)",
+                 diff);
+        return error_str;
+      }
+  }
+
+  return NULL;
+}
+
+
 /* Run all the self-tests and return NULL on success.  This function
    is used for the on-the-fly self-tests. */
 static const char *
@@ -1982,6 +2132,9 @@ selftest (void)
        || (r = selftest_basic_256 ()) )
     return r;
 
+  if ( (r = selftest_ctr_128 ()) )
+    return r;
+
   return r;
 }