Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions src/EmbeddingSpMDMAutovec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,168 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
scale *= weight;
bias *= weight;
}
#if HAVE_SVE
constexpr size_t kNumElemsPerIter = 8;

float32x4_t bias_v = vdupq_n_f32(bias);
float32x4_t scale_v = vdupq_n_f32(scale);

float* bufPtr = reinterpret_cast<float*>(buf);

size_t output_columns_mod = block_size % kNumElemsPerIter;

svbool_t lastPredC = svwhilelt_b32_u64(0, output_columns_mod);
svbool_t lastPredD = svwhilelt_b32_u64(4, output_columns_mod);

if (input_bit_rate == 4) {
svuint64_t multiplier = svdup_n_u64((1ULL << 28) + 1);
svbool_t firstTwoPred = svwhilelt_b64_u64(0, 2);

constexpr size_t kNumBytesPerIter = 4;
constexpr size_t kNumElemsPerByte = 2;
size_t input_columns_mod = (output_columns_mod + 1) / kNumElemsPerByte;

svbool_t lastPredA = svwhilelt_b64_u64(0, input_columns_mod);
svbool_t lastPredB = svwhilelt_b64_u64(2, input_columns_mod);

for (size_t iters = block_size / kNumElemsPerIter;
__builtin_expect(iters > 0, 1);
--iters) {
svuint32_t in_v_0 =
svreinterpret_u32_u64(svld1ub_u64(firstTwoPred, input_row));
svuint32_t in_v_1 =
svreinterpret_u32_u64(svld1ub_u64(firstTwoPred, input_row + 2));

input_row += 4;

float32x4_t buf_v_0 = vld1q_f32(bufPtr);
float32x4_t buf_v_1 = vld1q_f32(bufPtr + 4);

in_v_0 =
svreinterpret_u32_u64(svreinterpret_u64_u32(in_v_0) * multiplier);
in_v_1 =
svreinterpret_u32_u64(svreinterpret_u64_u32(in_v_1) * multiplier);

buf_v_0 += bias_v;
buf_v_1 += bias_v;

in_v_0 &= 15;
in_v_1 &= 15;

float32x4_t in_v_0_f = vcvtq_f32_u32(svget_neonq(in_v_0));
float32x4_t in_v_1_f = vcvtq_f32_u32(svget_neonq(in_v_1));

buf_v_0 = vfmaq_f32(buf_v_0, scale_v, in_v_0_f);
buf_v_1 = vfmaq_f32(buf_v_1, scale_v, in_v_1_f);

vst1q_f32(bufPtr, buf_v_0);
vst1q_f32(bufPtr + 4, buf_v_1);

bufPtr += 8;
}

if (input_columns_mod != 0) {
svuint32_t in_v_0 =
svreinterpret_u32_u64(svld1ub_u64(lastPredA, input_row));
svuint32_t in_v_1 =
svreinterpret_u32_u64(svld1ub_u64(lastPredB, input_row + 2));

input_row += input_columns_mod;

float32x4_t buf_v_0 = svget_neonq(svld1_f32(lastPredC, bufPtr));
float32x4_t buf_v_1 = svget_neonq(svld1_f32(lastPredD, bufPtr + 4));

in_v_0 =
svreinterpret_u32_u64(svreinterpret_u64_u32(in_v_0) * multiplier);
in_v_1 =
svreinterpret_u32_u64(svreinterpret_u64_u32(in_v_1) * multiplier);

buf_v_0 += bias_v;
buf_v_1 += bias_v;

in_v_0 &= 15;
in_v_1 &= 15;

float32x4_t in_v_0_f = vcvtq_f32_u32(svget_neonq(in_v_0));
float32x4_t in_v_1_f = vcvtq_f32_u32(svget_neonq(in_v_1));

buf_v_0 = vfmaq_f32(buf_v_0, scale_v, in_v_0_f);
buf_v_1 = vfmaq_f32(buf_v_1, scale_v, in_v_1_f);

svst1_f32(lastPredC, bufPtr, svset_neonq(svundef_f32(), buf_v_0));
svst1_f32(lastPredD, bufPtr + 4, svset_neonq(svundef_f32(), buf_v_1));
}

} else if (input_bit_rate == 2) {
svuint32_t shift = svindex_u32(0, 2); // {0, 2, 4, 6};
constexpr size_t kNumBytesPerIter = 2;
constexpr size_t kNumElemsPerByte = 4;

size_t input_columns_mod = (output_columns_mod + 3) / kNumElemsPerByte;

for (size_t iters = block_size / kNumElemsPerIter;
__builtin_expect(iters > 0, 1);
--iters) {
svuint32_t in_v_0 = svreinterpret_u32_u8(svdup_n_u8(input_row[0]));
svuint32_t in_v_1 = svreinterpret_u32_u8(svdup_n_u8(input_row[1]));

input_row += 2;

float32x4_t buf_v_0 = vld1q_f32(bufPtr);
float32x4_t buf_v_1 = vld1q_f32(bufPtr + 4);

in_v_0 = in_v_0 >> shift;
in_v_1 = in_v_1 >> shift;

buf_v_0 += bias_v;
buf_v_1 += bias_v;

in_v_0 &= 3;
in_v_1 &= 3;

float32x4_t in_v_0_f = vcvtq_f32_u32(svget_neonq(in_v_0));
float32x4_t in_v_1_f = vcvtq_f32_u32(svget_neonq(in_v_1));

buf_v_0 = vfmaq_f32(buf_v_0, scale_v, in_v_0_f);
buf_v_1 = vfmaq_f32(buf_v_1, scale_v, in_v_1_f);

vst1q_f32(bufPtr, buf_v_0);
vst1q_f32(bufPtr + 4, buf_v_1);

bufPtr += 8;
}

if (input_columns_mod != 0) {
svuint32_t in_v_0 = svreinterpret_u32_u8(svdup_n_u8(input_row[0]));
svuint32_t in_v_1;
if (input_columns_mod == 2)
in_v_1 = svreinterpret_u32_u8(svdup_n_u8(input_row[1]));

input_row += input_columns_mod;

float32x4_t buf_v_0 = svget_neonq(svld1_f32(lastPredC, bufPtr));
float32x4_t buf_v_1 = svget_neonq(svld1_f32(lastPredD, bufPtr + 4));

in_v_0 = in_v_0 >> shift;
in_v_1 = in_v_1 >> shift;

buf_v_0 += bias_v;
buf_v_1 += bias_v;

in_v_0 &= 3;
in_v_1 &= 3;

float32x4_t in_v_0_f = vcvtq_f32_u32(svget_neonq(in_v_0));
float32x4_t in_v_1_f = vcvtq_f32_u32(svget_neonq(in_v_1));

buf_v_0 = vfmaq_f32(buf_v_0, scale_v, in_v_0_f);
buf_v_1 = vfmaq_f32(buf_v_1, scale_v, in_v_1_f);

svst1_f32(lastPredC, bufPtr, svset_neonq(svundef_f32(), buf_v_0));
svst1_f32(lastPredD, bufPtr + 4, svset_neonq(svundef_f32(), buf_v_1));
}
}
#else
if (input_bit_rate == 4) {
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
Expand Down Expand Up @@ -518,6 +679,7 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
}
}
#endif

const uint8_t* prefetch_addr = input + input_stride * prefetch_idx;
for (int64_t offset = 0; offset < input_stride;
Expand Down
Loading