Fix endianness issues in X25519 code

This commit is contained in:
Sergey Lyubka 2024-05-15 21:51:07 +01:00
parent 3ae1a0f82a
commit c6ff8ab6dc
3 changed files with 48 additions and 12 deletions

View File

@ -14458,6 +14458,7 @@ void mg_uecc_point_mult(mg_uecc_word_t *result, const mg_uecc_word_t *point,
*/
const uint8_t X25519_BASE_POINT[X25519_BYTES] = {9};
#define X25519_WBITS 32
@ -14465,7 +14466,6 @@ const uint8_t X25519_BASE_POINT[X25519_BYTES] = {9};
typedef uint32_t limb_t;
typedef uint64_t dlimb_t;
typedef int64_t sdlimb_t;
#define LIMB(x) (uint32_t)(x##ull), (uint32_t) ((x##ull) >> 32)
#define NLIMBS (256 / X25519_WBITS)
typedef limb_t fe[NLIMBS];
@ -14634,11 +14634,15 @@ static void ladder_part2(fe xs[5], const fe x1) {
static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
const uint8_t *x1, int clamp) {
int i;
fe x1_limbs;
limb_t swap = 0;
limb_t *x2 = xs[0], *x3 = xs[2], *z3 = xs[3];
memset(xs, 0, 4 * sizeof(fe));
x2[0] = z3[0] = 1;
memcpy(x3, x1, sizeof(fe));
for (i = 0; i < NLIMBS; i++) {
x3[i] = x1_limbs[i] =
MG_U32(x1[i * 4 + 3], x1[i * 4 + 2], x1[i * 4 + 1], x1[i * 4]);
}
for (i = 255; i >= 0; i--) {
uint8_t bytei = scalar[i / 8];
@ -14656,7 +14660,7 @@ static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
swap = doswap;
ladder_part1(xs);
ladder_part2(xs, (const limb_t *) x1);
ladder_part2(xs, (const limb_t *) x1_limbs);
}
condswap(x2, x3, swap);
}
@ -14664,7 +14668,7 @@ static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
int mg_tls_x25519(uint8_t out[X25519_BYTES], const uint8_t scalar[X25519_BYTES],
const uint8_t x1[X25519_BYTES], int clamp) {
int i, ret;
fe xs[5];
fe xs[5], out_limbs;
limb_t *x2, *z2, *z3, *prev;
static const struct {
uint8_t a, c, n;
@ -14691,9 +14695,16 @@ int mg_tls_x25519(uint8_t out[X25519_BYTES], const uint8_t scalar[X25519_BYTES],
// Here prev = z3
// x2 /= z2
mul((limb_t *) out, x2, z3, NLIMBS);
ret = (int) canon((limb_t *) out);
mul(out_limbs, x2, z3, NLIMBS);
ret = (int) canon(out_limbs);
if (!clamp) ret = 0;
for (i = 0; i < NLIMBS; i++) {
uint32_t n = out_limbs[i];
out[i * 4] = (uint8_t) (n & 0xff);
out[i * 4 + 1] = (uint8_t) ((n >> 8) & 0xff);
out[i * 4 + 2] = (uint8_t) ((n >> 16) & 0xff);
out[i * 4 + 3] = (uint8_t) ((n >> 24) & 0xff);
}
return ret;
}

View File

@ -5,6 +5,7 @@
* License: MIT License
*/
#include "tls_x25519.h"
#include "util.h"
const uint8_t X25519_BASE_POINT[X25519_BYTES] = {9};
@ -13,7 +14,6 @@ const uint8_t X25519_BASE_POINT[X25519_BYTES] = {9};
typedef uint32_t limb_t;
typedef uint64_t dlimb_t;
typedef int64_t sdlimb_t;
#define LIMB(x) (uint32_t)(x##ull), (uint32_t) ((x##ull) >> 32)
#define NLIMBS (256 / X25519_WBITS)
typedef limb_t fe[NLIMBS];
@ -182,11 +182,15 @@ static void ladder_part2(fe xs[5], const fe x1) {
static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
const uint8_t *x1, int clamp) {
int i;
fe x1_limbs;
limb_t swap = 0;
limb_t *x2 = xs[0], *x3 = xs[2], *z3 = xs[3];
memset(xs, 0, 4 * sizeof(fe));
x2[0] = z3[0] = 1;
memcpy(x3, x1, sizeof(fe));
for (i = 0; i < NLIMBS; i++) {
x3[i] = x1_limbs[i] =
MG_U32(x1[i * 4 + 3], x1[i * 4 + 2], x1[i * 4 + 1], x1[i * 4]);
}
for (i = 255; i >= 0; i--) {
uint8_t bytei = scalar[i / 8];
@ -204,7 +208,7 @@ static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
swap = doswap;
ladder_part1(xs);
ladder_part2(xs, (const limb_t *) x1);
ladder_part2(xs, (const limb_t *) x1_limbs);
}
condswap(x2, x3, swap);
}
@ -212,7 +216,7 @@ static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
int mg_tls_x25519(uint8_t out[X25519_BYTES], const uint8_t scalar[X25519_BYTES],
const uint8_t x1[X25519_BYTES], int clamp) {
int i, ret;
fe xs[5];
fe xs[5], out_limbs;
limb_t *x2, *z2, *z3, *prev;
static const struct {
uint8_t a, c, n;
@ -239,8 +243,15 @@ int mg_tls_x25519(uint8_t out[X25519_BYTES], const uint8_t scalar[X25519_BYTES],
// Here prev = z3
// x2 /= z2
mul((limb_t *) out, x2, z3, NLIMBS);
ret = (int) canon((limb_t *) out);
mul(out_limbs, x2, z3, NLIMBS);
ret = (int) canon(out_limbs);
if (!clamp) ret = 0;
for (i = 0; i < NLIMBS; i++) {
uint32_t n = out_limbs[i];
out[i * 4] = (uint8_t) (n & 0xff);
out[i * 4 + 1] = (uint8_t) ((n >> 8) & 0xff);
out[i * 4 + 2] = (uint8_t) ((n >> 16) & 0xff);
out[i * 4 + 3] = (uint8_t) ((n >> 24) & 0xff);
}
return ret;
}

View File

@ -3349,11 +3349,25 @@ static void test_split(void) {
ASSERT(mg_strcmp(b, mg_str("")) == 0);
}
static void test_crypto(void) {
uint8_t key[X25519_BYTES];
uint8_t buf[X25519_BYTES];
char tmp[100];
size_t i;
for (i = 0; i < sizeof(key); i++) key[i] = (uint8_t) i;
for (i = 0; i < sizeof(buf); i++) buf[i] = 0;
mg_tls_x25519(buf, key, X25519_BASE_POINT, 1);
mg_snprintf(tmp, sizeof(tmp), "%M", mg_print_hex, sizeof(buf), buf);
MG_INFO(("%s", tmp));
ASSERT(mg_strcmp(mg_str("8f40c5adb6"), mg_str_n(tmp, 10)) == 0);
}
int main(void) {
const char *debug_level = getenv("V");
if (debug_level == NULL) debug_level = "3";
mg_log_set(atoi(debug_level));
test_crypto();
test_split();
test_json();
test_queue();