ftu/blst/multi_scalar.c

415 lines
17 KiB
C
Raw Normal View History

2022-09-09 06:47:49 +00:00
/*
* Copyright Supranational LLC
* Licensed under the Apache License, Version 2.0, see LICENSE for details.
* SPDX-License-Identifier: Apache-2.0
*/
#include "fields.h"
#include "point.h"
/*
* Infinite point among inputs would be devastating. Shall we change it?
*/
#define POINTS_TO_AFFINE_IMPL(prefix, ptype, bits, field) \
static void ptype##s_to_affine(ptype##_affine dst[], \
const ptype *const points[], size_t npoints) \
{ \
size_t i; \
vec##bits *acc, ZZ, ZZZ; \
const ptype *point = NULL; \
const size_t stride = sizeof(ptype)==sizeof(POINTonE1) ? 1536 : 768; \
\
while (npoints) { \
const ptype *p, *const *walkback; \
size_t delta = stride<npoints ? stride : npoints; \
\
point = *points ? *points++ : point+1; \
acc = (vec##bits *)dst; \
vec_copy(acc++, point->Z, sizeof(vec##bits)); \
for (i = 1; i < delta; i++, acc++) \
point = *points ? *points++ : point+1, \
mul_##field(acc[0], acc[-1], point->Z); \
\
--acc; reciprocal_##field(acc[0], acc[0]); \
\
walkback = points-1, p = point, --delta, dst += delta; \
for (i = 0; i < delta; i++, acc--, dst--) { \
mul_##field(acc[-1], acc[-1], acc[0]); /* 1/Z */\
sqr_##field(ZZ, acc[-1]); /* 1/Z^2 */\
mul_##field(ZZZ, ZZ, acc[-1]); /* 1/Z^3 */\
mul_##field(acc[-1], p->Z, acc[0]); \
mul_##field(dst->X, p->X, ZZ); /* X = X'/Z^2 */\
mul_##field(dst->Y, p->Y, ZZZ); /* Y = Y'/Z^3 */\
p = (p == *walkback) ? *--walkback : p-1; \
} \
sqr_##field(ZZ, acc[0]); /* 1/Z^2 */\
mul_##field(ZZZ, ZZ, acc[0]); /* 1/Z^3 */\
mul_##field(dst->X, p->X, ZZ); /* X = X'/Z^2 */\
mul_##field(dst->Y, p->Y, ZZZ); /* Y = Y'/Z^3 */\
++delta, dst += delta, npoints -= delta; \
} \
} \
\
void prefix##s_to_affine(ptype##_affine dst[], const ptype *const points[], \
size_t npoints) \
{ ptype##s_to_affine(dst, points, npoints); }
/*
 * Instantiate the batch conversion for POINTonE1 (field fp) and POINTonE2
 * (field fp2); the public names are blst_p1s_to_affine and
 * blst_p2s_to_affine.
 */
POINTS_TO_AFFINE_IMPL(blst_p1, POINTonE1, 384, fp)
POINTS_TO_AFFINE_IMPL(blst_p2, POINTonE2, 384x, fp2)
/*
* This is two-step multi-scalar multiplication procedure. First, given
* a set of points you pre-compute a table for chosen windowing factor
* [expressed in bits with value between 2 and 14], and then you pass
* this table to the actual multiplication procedure along with scalars.
* Idea is that the pre-computed table will be reused multiple times. In
* which case multiplication runs faster than below Pippenger algorithm
* implementation for up to ~16K points for wbits=8, naturally at the
* expense of multi-megabyte table. One can trade even more memory for
* performance, but each wbits increment doubles the memory requirement,
* so at some point it gets prohibitively large... For reference, without
* reusing the table it's faster than Pippenger algorithm for up to ~32
* points [with wbits=5]...
*/
/*
 * Capacity, counted in points, assumed for a caller-supplied scratch area
 * of the wbits multiplication: 8192 when |ptype| has the size of
 * POINTonE1, 4096 for anything else.
 */
#define SCRATCH_SZ(ptype) \
    ((sizeof(ptype) == sizeof(POINTonE1)) ? 8192 : 4096)
#define PRECOMPUTE_WBITS_IMPL(prefix, ptype, bits, field, one) \
static void ptype##_precompute_row_wbits(ptype row[], size_t wbits, \
const ptype##_affine *point) \
{ \
size_t i, j, n = (size_t)1 << (wbits-1); \
/* row[-1] is implicit infinity */\
vec_copy(&row[0], point, sizeof(*point)); /* row[0]=p*1 */\
vec_copy(&row[0].Z, one, sizeof(row[0].Z)); \
ptype##_double(&row[1], &row[0]); /* row[1]=p*(1+1) */\
for (i = 2, j = 1; i < n; i += 2, j++) \
ptype##_add_affine(&row[i], &row[i-1], point), /* row[2]=p*(2+1) */\
ptype##_double(&row[i+1], &row[j]); /* row[3]=p*(2+2) */\
} /* row[4] ... */\
\
static void ptype##s_to_affine_row_wbits(ptype##_affine dst[], ptype src[], \
size_t wbits, size_t npoints) \
{ \
size_t total = npoints << (wbits-1); \
size_t nwin = (size_t)1 << (wbits-1); \
size_t i, j; \
vec##bits *acc, ZZ, ZZZ; \
\
src += total; \
acc = (vec##bits *)src; \
vec_copy(acc++, one, sizeof(vec##bits)); \
for (i = 0; i < npoints; i++) \
for (j = nwin; --src, --j; acc++) \
mul_##field(acc[0], acc[-1], src->Z); \
\
--acc; reciprocal_##field(acc[0], acc[0]); \
\
for (i = 0; i < npoints; i++) { \
vec_copy(dst++, src++, sizeof(ptype##_affine)); \
for (j = 1; j < nwin; j++, acc--, src++, dst++) { \
mul_##field(acc[-1], acc[-1], acc[0]); /* 1/Z */\
sqr_##field(ZZ, acc[-1]); /* 1/Z^2 */\
mul_##field(ZZZ, ZZ, acc[-1]); /* 1/Z^3 */\
mul_##field(acc[-1], src->Z, acc[0]); \
mul_##field(dst->X, src->X, ZZ); /* X = X'/Z^2 */\
mul_##field(dst->Y, src->Y, ZZZ); /* Y = Y'/Z^3 */\
} \
} \
} \
\
/* flat |points[n]| can be placed at the end of |table[n<<(wbits-1)]| */\
static void ptype##s_precompute_wbits(ptype##_affine table[], size_t wbits, \
const ptype##_affine *const points[], \
size_t npoints) \
{ \
size_t total = npoints << (wbits-1); \
size_t nwin = (size_t)1 << (wbits-1); \
size_t nmin = wbits>9 ? (size_t)1: (size_t)1 << (9-wbits); \
size_t i, top = 0; \
ptype *rows, *row; \
const ptype##_affine *point = NULL; \
size_t stride = ((512*1024)/sizeof(ptype##_affine)) >> wbits; \
if (stride == 0) stride = 1; \
\
while (npoints >= nmin) { \
size_t limit = total - npoints; \
\
if (top + (stride << wbits) > limit) { \
stride = (limit - top) >> wbits; \
if (stride == 0) break; \
} \
rows = row = (ptype *)(&table[top]); \
for (i = 0; i < stride; i++, row += nwin) \
point = *points ? *points++ : point+1, \
ptype##_precompute_row_wbits(row, wbits, point); \
ptype##s_to_affine_row_wbits(&table[top], rows, wbits, stride); \
top += stride << (wbits-1); \
npoints -= stride; \
} \
rows = row = alloca(2*sizeof(ptype##_affine) * npoints * nwin); \
for (i = 0; i < npoints; i++, row += nwin) \
point = *points ? *points++ : point+1, \
ptype##_precompute_row_wbits(row, wbits, point); \
ptype##s_to_affine_row_wbits(&table[top], rows, wbits, npoints); \
} \
\
size_t prefix##s_mult_wbits_precompute_sizeof(size_t wbits, size_t npoints) \
{ return (sizeof(ptype##_affine)*npoints) << (wbits-1); } \
void prefix##s_mult_wbits_precompute(ptype##_affine table[], size_t wbits, \
const ptype##_affine *const points[], \
size_t npoints) \
{ ptype##s_precompute_wbits(table, wbits, points, npoints); }
#define POINTS_MULT_WBITS_IMPL(prefix, ptype, bits, field, one) \
static void ptype##_gather_booth_wbits(ptype *p, const ptype##_affine row[], \
size_t wbits, limb_t booth_idx) \
{ \
bool_t booth_sign = (booth_idx >> wbits) & 1; \
bool_t idx_is_zero; \
static const ptype##_affine infinity = { 0 }; \
\
booth_idx &= ((limb_t)1 << wbits) - 1; \
idx_is_zero = is_zero(booth_idx); \
booth_idx -= 1 ^ idx_is_zero; \
vec_select(p, &infinity, &row[booth_idx], sizeof(row[0]), idx_is_zero); \
ptype##_cneg(p, booth_sign); \
} \
\
static void ptype##s_mult_wbits(ptype *ret, const ptype##_affine table[], \
size_t wbits, size_t npoints, \
const byte *const scalars[], size_t nbits, \
ptype scratch[]) \
{ \
limb_t wmask, wval; \
size_t i, j, z, nbytes, window, nwin = (size_t)1 << (wbits-1); \
const byte *scalar, *const *scalar_s = scalars; \
const ptype##_affine *row = table; \
\
size_t scratch_sz = SCRATCH_SZ(ptype); \
if (scratch == NULL) { \
scratch_sz /= 4; /* limit to 288K */ \
scratch_sz = scratch_sz < npoints ? scratch_sz : npoints; \
scratch = alloca(sizeof(ptype) * scratch_sz); \
} \
\
nbytes = (nbits + 7)/8; /* convert |nbits| to bytes */ \
scalar = *scalar_s++; \
\
/* top excess bits modulo target window size */ \
window = nbits % wbits; /* yes, it may be zero */ \
wmask = ((limb_t)1 << (window + 1)) - 1; \
\
nbits -= window; \
z = is_zero(nbits); \
wval = (get_wval_limb(scalar, nbits - (z^1), wbits + (z^1)) << z) & wmask; \
wval = booth_encode(wval, wbits); \
ptype##_gather_booth_wbits(&scratch[0], row, wbits, wval); \
row += nwin; \
\
i = 1; vec_zero(ret, sizeof(*ret)); \
while (nbits > 0) { \
for (j = i; i < npoints; i++, j++, row += nwin) { \
if (j == scratch_sz) \
ptype##s_accumulate(ret, scratch, j), j = 0; \
scalar = *scalar_s ? *scalar_s++ : scalar+nbytes; \
wval = get_wval_limb(scalar, nbits - 1, window + 1) & wmask; \
wval = booth_encode(wval, wbits); \
ptype##_gather_booth_wbits(&scratch[j], row, wbits, wval); \
} \
ptype##s_accumulate(ret, scratch, j); \
\
for (j = 0; j < wbits; j++) \
ptype##_double(ret, ret); \
\
window = wbits; \
wmask = ((limb_t)1 << (window + 1)) - 1; \
nbits -= window; \
i = 0; row = table; scalar_s = scalars; \
} \
\
for (j = i; i < npoints; i++, j++, row += nwin) { \
if (j == scratch_sz) \
ptype##s_accumulate(ret, scratch, j), j = 0; \
scalar = *scalar_s ? *scalar_s++ : scalar+nbytes; \
wval = (get_wval_limb(scalar, 0, wbits) << 1) & wmask; \
wval = booth_encode(wval, wbits); \
ptype##_gather_booth_wbits(&scratch[j], row, wbits, wval); \
} \
ptype##s_accumulate(ret, scratch, j); \
} \
\
size_t prefix##s_mult_wbits_scratch_sizeof(size_t npoints) \
{ \
const size_t scratch_sz = SCRATCH_SZ(ptype); \
return sizeof(ptype) * (npoints < scratch_sz ? npoints : scratch_sz); \
} \
void prefix##s_mult_wbits(ptype *ret, const ptype##_affine table[], \
size_t wbits, size_t npoints, \
const byte *const scalars[], size_t nbits, \
ptype scratch[]) \
{ ptype##s_mult_wbits(ret, table, wbits, npoints, scalars, nbits, scratch); }
/*
 * Instantiate table pre-computation and table-driven multiplication for
 * POINTonE1 (field fp) and POINTonE2 (field fp2). The last argument is
 * the value copied into Z of the first entry of every pre-computed row.
 */
PRECOMPUTE_WBITS_IMPL(blst_p1, POINTonE1, 384, fp, BLS12_381_Rx.p)
POINTS_MULT_WBITS_IMPL(blst_p1, POINTonE1, 384, fp, BLS12_381_Rx.p)
PRECOMPUTE_WBITS_IMPL(blst_p2, POINTonE2, 384x, fp2, BLS12_381_Rx.p2)
POINTS_MULT_WBITS_IMPL(blst_p2, POINTonE2, 384x, fp2, BLS12_381_Rx.p2)
/*
* Pippenger algorithm implementation, fastest option for larger amount
* of points...
*/
/*
 * Choose the Pippenger window width, in bits, for |npoints| inputs.
 *
 * With L = floor(log2(npoints)), taken as 0 for npoints of 0 or 1:
 *   L > 12  ->  L - 3
 *   L > 4   ->  L - 2
 *   L > 0   ->  2
 *   L == 0  ->  1
 */
static size_t pippenger_window_size(size_t npoints)
{
    size_t log2n = 0;

    /* count how many times |npoints| can be halved before reaching 1 */
    while (npoints > 1) {
        npoints >>= 1;
        log2n++;
    }

    if (log2n > 12)
        return log2n - 3;
    if (log2n > 4)
        return log2n - 2;
    if (log2n != 0)
        return 2;
    return 1;
}
/*
 * Declare the file-private bucket type ptype##xyzz: four vec##bits
 * members, in the order X, Y, ZZZ, ZZ.
 */
#define DECLARE_PRIVATE_POINTXYZZ(ptype, bits) \
typedef struct { \
    vec##bits X; \
    vec##bits Y; \
    vec##bits ZZZ; \
    vec##bits ZZ; \
} ptype##xyzz;
/*
 * POINTS_MULT_PIPPENGER_IMPL(prefix, ptype) emits the bucket-based
 * multi-scalar multiplication for one point type:
 *
 *   ptype##_integrate_buckets()  - collapse the buckets into one point;
 *   ptype##_bucket()             - add one input point into its bucket;
 *   ptype##_prefetch()           - prefetch the bucket of the next scalar;
 *   ptype##s_tile_pippenger()    - process one window of every scalar;
 *   ptype##s_mult_pippenger()    - full multiplication, window by window;
 *
 * plus the externally visible prefix##s_mult_pippenger_scratch_sizeof(),
 * prefix##s_tile_pippenger() and prefix##s_mult_pippenger().
 *
 * |points| and |scalars| are walked slot by slot while the slots are
 * non-NULL; after a NULL slot the next item is taken from consecutive
 * memory (|point+1|, |scalar+nbytes|).
 *
 * ptype##s_mult_pippenger() zeroes |buckets| once up front, and
 * ptype##_integrate_buckets() clears every bucket as it consumes it. The
 * public tile entry point does not zero its scratch itself.
 *
 * A tile over zero points yields the all-zero point, and a tile over a
 * single point reads exactly one scalar and one point.
 */
#define POINTS_MULT_PIPPENGER_IMPL(prefix, ptype) \
/* Sum (k+1)*buckets[k] for k = 0 .. 2^wbits-1 into |out|, clearing buckets. */\
static void ptype##_integrate_buckets(ptype *out, ptype##xyzz buckets[], \
                                      size_t wbits) \
{ \
    ptype##xyzz ret[1], acc[1]; \
    size_t n = (size_t)1 << wbits; \
\
    /* Calculate sum of x[i-1]*i for i=1 through 1<<|wbits|. */\
    vec_copy(acc, &buckets[--n], sizeof(acc)); \
    vec_copy(ret, &buckets[n], sizeof(ret)); \
    vec_zero(&buckets[n], sizeof(buckets[n])); \
    /* |acc| is the sum of the buckets seen so far, |ret| the sum of all */\
    /* |acc| values, so bucket k ends up counted k+1 times                */\
    while (n--) { \
        ptype##xyzz_dadd(acc, acc, &buckets[n]); \
        ptype##xyzz_dadd(ret, ret, acc); \
        vec_zero(&buckets[n], sizeof(buckets[n])); \
    } \
    ptype##xyzz_to_Jacobian(out, ret); \
} \
\
/* Add |p| into the bucket named by |booth_idx|; index 0 names no bucket. */\
/* Bit |wbits| is the sign and is handed on to xyzz_dadd_affine, which    */\
/* presumably subtracts instead of adding when it is set.                 */\
static void ptype##_bucket(ptype##xyzz buckets[], limb_t booth_idx, \
                           size_t wbits, const ptype##_affine *p) \
{ \
    bool_t booth_sign = (booth_idx >> wbits) & 1; \
\
    /* mask built in limb_t, so the shift is not limited to int width */\
    booth_idx &= ((limb_t)1 << wbits) - 1; \
    if (booth_idx--) \
        ptype##xyzz_dadd_affine(&buckets[booth_idx], &buckets[booth_idx], \
                                p, booth_sign); \
} \
\
/* Prefetch the bucket that |booth_idx| names, if it names one. */\
static void ptype##_prefetch(const ptype##xyzz buckets[], limb_t booth_idx, \
                             size_t wbits) \
{ \
    /* same limb_t-wide mask as in ptype##_bucket() */\
    booth_idx &= ((limb_t)1 << wbits) - 1; \
    if (booth_idx--) \
        vec_prefetch(&buckets[booth_idx], sizeof(buckets[booth_idx])); \
} \
\
/* Process the window of |wbits| bits starting at |bit0| of every scalar: */\
/* each point goes into the bucket selected by its window value, then the */\
/* buckets are integrated into |ret|. |cbits| is the width handed to      */\
/* booth_encode(); 2^(cbits-1) buckets are integrated.                    */\
static void ptype##s_tile_pippenger(ptype *ret, \
                                    const ptype##_affine *const points[], \
                                    size_t npoints, \
                                    const byte *const scalars[], size_t nbits, \
                                    ptype##xyzz buckets[], \
                                    size_t bit0, size_t wbits, size_t cbits) \
{ \
    limb_t wmask, wval, wnxt; \
    size_t i, z, nbytes; \
    const byte *scalar; \
    const ptype##_affine *point; \
\
    /* nothing to add up: the result is the all-zero point and neither */\
    /* |points| nor |scalars| is touched                               */\
    if (npoints == 0) { \
        vec_zero(ret, sizeof(*ret)); \
        return; \
    } \
\
    scalar = *scalars++; \
    point = *points++; \
\
    nbytes = (nbits + 7)/8; /* convert |nbits| to bytes */ \
    wmask = ((limb_t)1 << (wbits+1)) - 1; \
    /* above bit 0 one extra bit below the window is fetched; at bit 0 */\
    /* the value is shifted left by one instead                        */\
    z = is_zero(bit0); \
    bit0 -= z^1; wbits += z^1; \
    wval = (get_wval_limb(scalar, bit0, wbits) << z) & wmask; \
    wval = booth_encode(wval, cbits); \
\
    /* a lone point has no successor to look ahead to: stop here rather */\
    /* than read a second scalar and a second point                     */\
    if (npoints == 1) { \
        ptype##_bucket(buckets, wval, cbits, point); \
        ptype##_integrate_buckets(ret, buckets, cbits - 1); \
        return; \
    } \
\
    /* look one scalar ahead so that its bucket can be prefetched */\
    scalar = *scalars ? *scalars++ : scalar+nbytes; \
    wnxt = (get_wval_limb(scalar, bit0, wbits) << z) & wmask; \
    wnxt = booth_encode(wnxt, cbits); \
    npoints--; /* account for prefetch */ \
\
    ptype##_bucket(buckets, wval, cbits, point); \
    for (i = 1; i < npoints; i++) { \
        wval = wnxt; \
        scalar = *scalars ? *scalars++ : scalar+nbytes; \
        wnxt = (get_wval_limb(scalar, bit0, wbits) << z) & wmask; \
        wnxt = booth_encode(wnxt, cbits); \
        ptype##_prefetch(buckets, wnxt, cbits); \
        point = *points ? *points++ : point+1; \
        ptype##_bucket(buckets, wval, cbits, point); \
    } \
    /* the last point pairs with the value that was fetched ahead */\
    point = *points ? *points++ : point+1; \
    ptype##_bucket(buckets, wnxt, cbits, point); \
    ptype##_integrate_buckets(ret, buckets, cbits - 1); \
} \
\
/* Sum of scalar[i]*point[i], result in |ret|. |window| of 0 selects */\
/* pippenger_window_size(npoints).                                   */\
static void ptype##s_mult_pippenger(ptype *ret, \
                                    const ptype##_affine *const points[], \
                                    size_t npoints, \
                                    const byte *const scalars[], size_t nbits, \
                                    ptype##xyzz buckets[], size_t window) \
{ \
    size_t i, wbits, cbits, bit0 = nbits; \
    ptype tile[1]; \
\
    window = window ? window : pippenger_window_size(npoints); \
    vec_zero(buckets, sizeof(buckets[0]) << (window-1)); \
    vec_zero(ret, sizeof(*ret)); \
\
    /* top excess bits modulo target window size */ \
    wbits = nbits % window; /* yes, it may be zero */ \
    cbits = wbits + 1; \
    /* every window above bit 0, topmost first */\
    while (bit0 -= wbits) { \
        ptype##s_tile_pippenger(tile, points, npoints, scalars, nbits, \
                                buckets, bit0, wbits, cbits); \
        ptype##_dadd(ret, ret, tile, NULL); \
        /* shift the running sum up by one full window */\
        for (i = 0; i < window; i++) \
            ptype##_double(ret, ret); \
        cbits = wbits = window; \
    } \
    /* the window at bit 0 */\
    ptype##s_tile_pippenger(tile, points, npoints, scalars, nbits, \
                            buckets, 0, wbits, cbits); \
    ptype##_dadd(ret, ret, tile, NULL); \
} \
\
/* bucket scratch size in bytes for the automatically chosen window */\
size_t prefix##s_mult_pippenger_scratch_sizeof(size_t npoints) \
{ return sizeof(ptype##xyzz) << (pippenger_window_size(npoints)-1); } \
/* Externally visible single-window entry point. A window that would */\
/* reach past |nbits| is cut down to the bits that remain.           */\
void prefix##s_tile_pippenger(ptype *ret, \
                              const ptype##_affine *const points[], \
                              size_t npoints, \
                              const byte *const scalars[], size_t nbits, \
                              ptype##xyzz scratch[], \
                              size_t bit0, size_t window) \
{ \
    size_t wbits, cbits; \
\
    if (bit0 + window > nbits) wbits = nbits - bit0, cbits = wbits + 1; \
    else wbits = cbits = window; \
    ptype##s_tile_pippenger(ret, points, npoints, scalars, nbits, scratch, \
                            bit0, wbits, cbits); \
} \
/* externally visible full multiplication with automatic window choice */\
void prefix##s_mult_pippenger(ptype *ret, \
                              const ptype##_affine *const points[], \
                              size_t npoints, \
                              const byte *const scalars[], size_t nbits, \
                              ptype##xyzz scratch[]) \
{ ptype##s_mult_pippenger(ret, points, npoints, scalars, nbits, scratch, 0); }
/*
 * Instantiate the Pippenger routines for POINTonE1 and then POINTonE2.
 * For each type: the private XYZZ bucket type is declared first, then
 * the XYZZ helper routines are generated by the POINTXYZZ_* macros
 * (defined outside this excerpt), and finally the routines that use them.
 */
DECLARE_PRIVATE_POINTXYZZ(POINTonE1, 384)
POINTXYZZ_TO_JACOBIAN_IMPL(POINTonE1, 384, fp)
POINTXYZZ_DADD_IMPL(POINTonE1, 384, fp)
POINTXYZZ_DADD_AFFINE_IMPL(POINTonE1, 384, fp, BLS12_381_Rx.p)
POINTS_MULT_PIPPENGER_IMPL(blst_p1, POINTonE1)
DECLARE_PRIVATE_POINTXYZZ(POINTonE2, 384x)
POINTXYZZ_TO_JACOBIAN_IMPL(POINTonE2, 384x, fp2)
POINTXYZZ_DADD_IMPL(POINTonE2, 384x, fp2)
POINTXYZZ_DADD_AFFINE_IMPL(POINTonE2, 384x, fp2, BLS12_381_Rx.p2)
POINTS_MULT_PIPPENGER_IMPL(blst_p2, POINTonE2)