Fix 64-bit arithmetic in sha2 variants.

SHA-256 has a 8 byte length field, and was being calculated
correctly but clumsily. SHA-384 and SHA-512 have 16 byte length
fields, which were not being calculated properly anyways for
inputs over half a GB (because 32-N bit shifts were used where
64-N bit shifts were required). Simplify by limiting the length
of the input we support to 2^64 bits for any of these, which
ought to be enough for anyone.

- Rename some macros for consistency and reorder.

Change-Id: I492d0d574177239519ee38c9cbbc2a9beab6967c
Reviewed-on: https://pdfium-review.googlesource.com/c/pdfium/+/61933
Reviewed-by: Lei Zhang <thestig@chromium.org>
Commit-Queue: Tom Sepez <tsepez@chromium.org>
diff --git a/core/fdrm/fx_crypt.h b/core/fdrm/fx_crypt.h
index e6c153f..cb3247a 100644
--- a/core/fdrm/fx_crypt.h
+++ b/core/fdrm/fx_crypt.h
@@ -42,7 +42,7 @@
 };
 
 struct CRYPT_sha2_context {
-  uint64_t total[2];
+  uint64_t total_bytes;
   uint64_t state[8];
   uint8_t buffer[128];
 };
diff --git a/core/fdrm/fx_crypt_sha.cpp b/core/fdrm/fx_crypt_sha.cpp
index d3bb9c4..0371685 100644
--- a/core/fdrm/fx_crypt_sha.cpp
+++ b/core/fdrm/fx_crypt_sha.cpp
@@ -6,8 +6,6 @@
 
 #include "core/fdrm/fx_crypt.h"
 
-#define rol(x, y) (((x) << (y)) | (((unsigned int)x) >> (32 - y)))
-
 #define SHA_GET_UINT32(n, b, i)                                         \
   {                                                                     \
     (n) = ((uint32_t)(b)[(i)] << 24) | ((uint32_t)(b)[(i) + 1] << 16) | \
@@ -20,6 +18,24 @@
     (b)[(i) + 2] = (uint8_t)((n) >> 8);  \
     (b)[(i) + 3] = (uint8_t)((n));       \
   }
+#define SHA_GET_UINT64(n, b, i)                                             \
+  {                                                                         \
+    (n) = ((uint64_t)(b)[(i)] << 56) | ((uint64_t)(b)[(i) + 1] << 48) |     \
+          ((uint64_t)(b)[(i) + 2] << 40) | ((uint64_t)(b)[(i) + 3] << 32) | \
+          ((uint64_t)(b)[(i) + 4] << 24) | ((uint64_t)(b)[(i) + 5] << 16) | \
+          ((uint64_t)(b)[(i) + 6] << 8) | ((uint64_t)(b)[(i) + 7]);         \
+  }
+#define SHA_PUT_UINT64(n, b, i)          \
+  {                                      \
+    (b)[(i)] = (uint8_t)((n) >> 56);     \
+    (b)[(i) + 1] = (uint8_t)((n) >> 48); \
+    (b)[(i) + 2] = (uint8_t)((n) >> 40); \
+    (b)[(i) + 3] = (uint8_t)((n) >> 32); \
+    (b)[(i) + 4] = (uint8_t)((n) >> 24); \
+    (b)[(i) + 5] = (uint8_t)((n) >> 16); \
+    (b)[(i) + 6] = (uint8_t)((n) >> 8);  \
+    (b)[(i) + 7] = (uint8_t)((n));       \
+  }
 
 #define SHA384_F0(x, y, z) ((x & y) | (z & (x | y)))
 #define SHA384_F1(x, y, z) (z ^ (x & (y ^ z)))
@@ -42,25 +58,7 @@
 #define SHA384_R(t) \
   (W[t] = SHA384_S1(W[t - 2]) + W[t - 7] + SHA384_S0(W[t - 15]) + W[t - 16])
 
-#define GET_FX_64WORD(n, b, i)                                              \
-  {                                                                         \
-    (n) = ((uint64_t)(b)[(i)] << 56) | ((uint64_t)(b)[(i) + 1] << 48) |     \
-          ((uint64_t)(b)[(i) + 2] << 40) | ((uint64_t)(b)[(i) + 3] << 32) | \
-          ((uint64_t)(b)[(i) + 4] << 24) | ((uint64_t)(b)[(i) + 5] << 16) | \
-          ((uint64_t)(b)[(i) + 6] << 8) | ((uint64_t)(b)[(i) + 7]);         \
-  }
-#define PUT_UINT64(n, b, i)              \
-  {                                      \
-    (b)[(i)] = (uint8_t)((n) >> 56);     \
-    (b)[(i) + 1] = (uint8_t)((n) >> 48); \
-    (b)[(i) + 2] = (uint8_t)((n) >> 40); \
-    (b)[(i) + 3] = (uint8_t)((n) >> 32); \
-    (b)[(i) + 4] = (uint8_t)((n) >> 24); \
-    (b)[(i) + 5] = (uint8_t)((n) >> 16); \
-    (b)[(i) + 6] = (uint8_t)((n) >> 8);  \
-    (b)[(i) + 7] = (uint8_t)((n));       \
-  }
-
+#define rol(x, y) (((x) << (y)) | (((unsigned int)x) >> (32 - y)))
 #define SHR(x, n) ((x & 0xFFFFFFFF) >> n)
 #define ROTR(x, n) (SHR(x, n) | (x << (32 - n)))
 #define S0(x) (ROTR(x, 7) ^ ROTR(x, 18) ^ SHR(x, 3))
@@ -295,22 +293,22 @@
   uint64_t temp1, temp2;
   uint64_t A, B, C, D, E, F, G, H;
   uint64_t W[80];
-  GET_FX_64WORD(W[0], data, 0);
-  GET_FX_64WORD(W[1], data, 8);
-  GET_FX_64WORD(W[2], data, 16);
-  GET_FX_64WORD(W[3], data, 24);
-  GET_FX_64WORD(W[4], data, 32);
-  GET_FX_64WORD(W[5], data, 40);
-  GET_FX_64WORD(W[6], data, 48);
-  GET_FX_64WORD(W[7], data, 56);
-  GET_FX_64WORD(W[8], data, 64);
-  GET_FX_64WORD(W[9], data, 72);
-  GET_FX_64WORD(W[10], data, 80);
-  GET_FX_64WORD(W[11], data, 88);
-  GET_FX_64WORD(W[12], data, 96);
-  GET_FX_64WORD(W[13], data, 104);
-  GET_FX_64WORD(W[14], data, 112);
-  GET_FX_64WORD(W[15], data, 120);
+  SHA_GET_UINT64(W[0], data, 0);
+  SHA_GET_UINT64(W[1], data, 8);
+  SHA_GET_UINT64(W[2], data, 16);
+  SHA_GET_UINT64(W[3], data, 24);
+  SHA_GET_UINT64(W[4], data, 32);
+  SHA_GET_UINT64(W[5], data, 40);
+  SHA_GET_UINT64(W[6], data, 48);
+  SHA_GET_UINT64(W[7], data, 56);
+  SHA_GET_UINT64(W[8], data, 64);
+  SHA_GET_UINT64(W[9], data, 72);
+  SHA_GET_UINT64(W[10], data, 80);
+  SHA_GET_UINT64(W[11], data, 88);
+  SHA_GET_UINT64(W[12], data, 96);
+  SHA_GET_UINT64(W[13], data, 104);
+  SHA_GET_UINT64(W[14], data, 112);
+  SHA_GET_UINT64(W[15], data, 120);
   A = ctx->state[0];
   B = ctx->state[1];
   C = ctx->state[2];
@@ -422,6 +420,7 @@
     digest[i * 4 + 3] = (context->h[i]) & 0xFF;
   }
 }
+
 void CRYPT_SHA1Generate(const uint8_t* data,
                         uint32_t size,
                         uint8_t digest[20]) {
@@ -430,9 +429,9 @@
   CRYPT_SHA1Update(&s, data, size);
   CRYPT_SHA1Finish(&s, digest);
 }
+
 void CRYPT_SHA256Start(CRYPT_sha2_context* context) {
-  context->total[0] = 0;
-  context->total[1] = 0;
+  context->total_bytes = 0;
   context->state[0] = 0x6A09E667;
   context->state[1] = 0xBB67AE85;
   context->state[2] = 0x3C6EF372;
@@ -441,6 +440,7 @@
   context->state[5] = 0x9B05688C;
   context->state[6] = 0x1F83D9AB;
   context->state[7] = 0x5BE0CD19;
+  memset(context->buffer, 0, sizeof(context->buffer));
 }
 
 void CRYPT_SHA256Update(CRYPT_sha2_context* context,
@@ -449,13 +449,9 @@
   if (!size)
     return;
 
-  uint32_t left = context->total[0] & 0x3F;
+  uint32_t left = context->total_bytes & 0x3F;
   uint32_t fill = 64 - left;
-  context->total[0] += size;
-  context->total[0] &= 0xFFFFFFFF;
-  if (context->total[0] < size)
-    context->total[1]++;
-
+  context->total_bytes += size;
   if (left && size >= fill) {
     memcpy(context->buffer + left, data, fill);
     sha256_process(context, context->buffer);
@@ -474,11 +470,9 @@
 
 void CRYPT_SHA256Finish(CRYPT_sha2_context* context, uint8_t digest[32]) {
   uint8_t msglen[8];
-  uint32_t high = (context->total[0] >> 29) | (context->total[1] << 3);
-  uint32_t low = (context->total[0] << 3);
-  SHA_PUT_UINT32(high, msglen, 0);
-  SHA_PUT_UINT32(low, msglen, 4);
-  uint32_t last = context->total[0] & 0x3F;
+  uint64_t total_bits = 8 * context->total_bytes;  // Prior to padding.
+  SHA_PUT_UINT64(total_bits, msglen, 0);
+  uint32_t last = context->total_bytes & 0x3F;
   uint32_t padn = (last < 56) ? (56 - last) : (120 - last);
   CRYPT_SHA256Update(context, sha256_padding, padn);
   CRYPT_SHA256Update(context, msglen, 8);
@@ -502,7 +496,7 @@
 }
 
 void CRYPT_SHA384Start(CRYPT_sha2_context* context) {
-  memset(context, 0, sizeof(CRYPT_sha2_context));
+  context->total_bytes = 0;
   context->state[0] = 0xcbbb9d5dc1059ed8ULL;
   context->state[1] = 0x629a292a367cd507ULL;
   context->state[2] = 0x9159015a3070dd17ULL;
@@ -511,6 +505,7 @@
   context->state[5] = 0x8eb44a8768581511ULL;
   context->state[6] = 0xdb0c2e0d64f98fa7ULL;
   context->state[7] = 0x47b5481dbefa4fa4ULL;
+  memset(context->buffer, 0, sizeof(context->buffer));
 }
 
 void CRYPT_SHA384Update(CRYPT_sha2_context* context,
@@ -519,12 +514,9 @@
   if (!size)
     return;
 
-  uint32_t left = static_cast<uint32_t>(context->total[0]) & 0x7F;
+  uint32_t left = context->total_bytes & 0x7F;
   uint32_t fill = 128 - left;
-  context->total[0] += size;
-  if (context->total[0] < size)
-    context->total[1]++;
-
+  context->total_bytes += size;
   if (left && size >= fill) {
     memcpy(context->buffer + left, data, fill);
     sha384_process(context, context->buffer);
@@ -542,24 +534,20 @@
 }
 
 void CRYPT_SHA384Finish(CRYPT_sha2_context* context, uint8_t digest[48]) {
-  uint32_t last, padn;
   uint8_t msglen[16];
-  memset(msglen, 0, 16);
-  uint64_t high, low;
-  high = (context->total[0] >> 29) | (context->total[1] << 3);
-  low = (context->total[0] << 3);
-  PUT_UINT64(high, msglen, 0);
-  PUT_UINT64(low, msglen, 8);
-  last = (uint32_t)context->total[0] & 0x7F;
-  padn = (last < 112) ? (112 - last) : (240 - last);
+  uint64_t total_bits = 8 * context->total_bytes;  // Prior to padding.
+  SHA_PUT_UINT64(0ULL, msglen, 0);
+  SHA_PUT_UINT64(total_bits, msglen, 8);
+  uint32_t last = context->total_bytes & 0x7F;
+  uint32_t padn = (last < 112) ? (112 - last) : (240 - last);
   CRYPT_SHA384Update(context, sha384_padding, padn);
   CRYPT_SHA384Update(context, msglen, 16);
-  PUT_UINT64(context->state[0], digest, 0);
-  PUT_UINT64(context->state[1], digest, 8);
-  PUT_UINT64(context->state[2], digest, 16);
-  PUT_UINT64(context->state[3], digest, 24);
-  PUT_UINT64(context->state[4], digest, 32);
-  PUT_UINT64(context->state[5], digest, 40);
+  SHA_PUT_UINT64(context->state[0], digest, 0);
+  SHA_PUT_UINT64(context->state[1], digest, 8);
+  SHA_PUT_UINT64(context->state[2], digest, 16);
+  SHA_PUT_UINT64(context->state[3], digest, 24);
+  SHA_PUT_UINT64(context->state[4], digest, 32);
+  SHA_PUT_UINT64(context->state[5], digest, 40);
 }
 
 void CRYPT_SHA384Generate(const uint8_t* data,
@@ -572,7 +560,7 @@
 }
 
 void CRYPT_SHA512Start(CRYPT_sha2_context* context) {
-  memset(context, 0, sizeof(CRYPT_sha2_context));
+  context->total_bytes = 0;
   context->state[0] = 0x6a09e667f3bcc908ULL;
   context->state[1] = 0xbb67ae8584caa73bULL;
   context->state[2] = 0x3c6ef372fe94f82bULL;
@@ -581,6 +569,7 @@
   context->state[5] = 0x9b05688c2b3e6c1fULL;
   context->state[6] = 0x1f83d9abfb41bd6bULL;
   context->state[7] = 0x5be0cd19137e2179ULL;
+  memset(context->buffer, 0, sizeof(context->buffer));
 }
 
 void CRYPT_SHA512Update(CRYPT_sha2_context* context,
@@ -590,26 +579,22 @@
 }
 
 void CRYPT_SHA512Finish(CRYPT_sha2_context* context, uint8_t digest[64]) {
-  uint32_t last, padn;
   uint8_t msglen[16];
-  memset(msglen, 0, 16);
-  uint64_t high, low;
-  high = (context->total[0] >> 29) | (context->total[1] << 3);
-  low = (context->total[0] << 3);
-  PUT_UINT64(high, msglen, 0);
-  PUT_UINT64(low, msglen, 8);
-  last = (uint32_t)context->total[0] & 0x7F;
-  padn = (last < 112) ? (112 - last) : (240 - last);
+  uint64_t total_bits = 8 * context->total_bytes;
+  SHA_PUT_UINT64(0ULL, msglen, 0);
+  SHA_PUT_UINT64(total_bits, msglen, 8);
+  uint32_t last = context->total_bytes & 0x7F;
+  uint32_t padn = (last < 112) ? (112 - last) : (240 - last);
   CRYPT_SHA512Update(context, sha384_padding, padn);
   CRYPT_SHA512Update(context, msglen, 16);
-  PUT_UINT64(context->state[0], digest, 0);
-  PUT_UINT64(context->state[1], digest, 8);
-  PUT_UINT64(context->state[2], digest, 16);
-  PUT_UINT64(context->state[3], digest, 24);
-  PUT_UINT64(context->state[4], digest, 32);
-  PUT_UINT64(context->state[5], digest, 40);
-  PUT_UINT64(context->state[6], digest, 48);
-  PUT_UINT64(context->state[7], digest, 56);
+  SHA_PUT_UINT64(context->state[0], digest, 0);
+  SHA_PUT_UINT64(context->state[1], digest, 8);
+  SHA_PUT_UINT64(context->state[2], digest, 16);
+  SHA_PUT_UINT64(context->state[3], digest, 24);
+  SHA_PUT_UINT64(context->state[4], digest, 32);
+  SHA_PUT_UINT64(context->state[5], digest, 40);
+  SHA_PUT_UINT64(context->state[6], digest, 48);
+  SHA_PUT_UINT64(context->state[7], digest, 56);
 }
 
 void CRYPT_SHA512Generate(const uint8_t* data,