// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <arm_neon.h>
#include <stdint.h>
#include <stddef.h>
#include <cmath>

// 128 bits = 16 bytes -> fits 8 fp16/bf16 or 4 fp32 elements.
static int vector_length_in_bytes = 16;
// When widening fp16/bf16 -> fp32, 4 elements fit in one 128-bit register.
// Using 8 would require two 128-bit registers, so limit to 4.
static int full_precision_elements_in_fixed_vector = 4;


static inline float32x4_t cvt_bf16_to_fp32(const uint16x4_t input) {
    // Zero-extend 16-bit to 32-bit and shift left by 16 bits.
    // BF16 has the same sign/exponent bits as FP32, just missing the lower mantissa bits.
    uint32x4_t result_32 = vshll_n_u16(input, 16);
    return vreinterpretq_f32_u32(result_32);
}
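
// Scalar reference for the lane-wise widening above (illustrative helper, not
// part of the original file): e.g. bf16 0x3F80 << 16 = 0x3F800000 = 1.0f.
static inline float bf16_bits_to_fp32_scalar(uint16_t bits) {
    uint32_t u = (uint32_t)bits << 16;  // bf16 occupies the top 16 bits of fp32
    float f;
    __builtin_memcpy(&f, &u, sizeof(f));  // bit-cast without aliasing UB
    return f;
}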

static inline float32x4_t cvt_fp16_to_fp32(float16x4_t input) {
    // Converts 4 FP16 values to 4 FP32 values
    return vcvt_f32_f16(input);
}


// Converting fp32 to bf16 must not simply truncate the low 16 bits; the value
// is first rounded to nearest even so the truncation does not bias results.
// Converts 4 float32 -> 4 bfloat16 with round-to-nearest-even (RNE) and NaN handling.
static inline uint16x4_t cvt_fp32_to_bf16(float32x4_t src) {
    // Reinterpret float32 bits as uint32
    uint32x4_t u32 = vreinterpretq_u32_f32(src);

    const uint32x4_t ones = vdupq_n_u32(0x1);
    const uint32x4_t vec_bias = vdupq_n_u32(0x7FFF);  // one below the rounding midpoint (0x8000) of the dropped 16 bits
    const uint16x4_t nan_bf16 = vdup_n_u16(0xFFFF);

    // RNE: lsb = (input >> 16) & 1
    uint32x4_t lsb = vandq_u32(vshrq_n_u32(u32, 16), ones);

    // rounding_bias = 0x7FFF + lsb, where lsb is 0 or 1
    uint32x4_t bias = vaddq_u32(vec_bias, lsb);

    // input += rounding_bias
    u32 = vaddq_u32(u32, bias);

    // >> 16 to get bfloat16:
    // vshrq_n_u32 keeps 32-bit lanes after the shift,
    // vshrn_n_u32 narrows to 16-bit lanes after the shift
    uint16x4_t bf16 = vshrn_n_u32(u32, 16);

    // vmvnq_u32 is bitwise NOT.
    // NaN mask: ~(src == src) -> all ones if NaN, all zeros for ordinary numbers
    uint32x4_t isnan = vmvnq_u32(vceqq_f32(src, src));

    // Narrow the 32-bit per-lane mask to 16 bits so it lines up with the bf16
    // lanes (taking only the low half of the register would drop lanes 2 and 3)
    uint16x4_t mask = vmovn_u32(isnan);
    return vbsl_u16(mask, nan_bf16, bf16);
}
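
// Scalar reference for the RNE trick above (illustrative helper, not part of
// the original file). Example: bits 0x3F808000 give lsb = 0, so adding 0x7FFF
// truncates to 0x3F80 (tie rounds down to even); bits 0x3F818000 give lsb = 1,
// so adding 0x8000 carries and yields 0x3F82 (tie rounds up to even).
static inline uint16_t fp32_bits_to_bf16_scalar(uint32_t u) {
    if ((u & 0x7F800000u) == 0x7F800000u && (u & 0x007FFFFFu) != 0u) {
        return 0xFFFFu;  // NaN maps to the same quiet-NaN pattern used above
    }
    uint32_t lsb = (u >> 16) & 1u;  // lowest bit that survives the truncation
    u += 0x7FFFu + lsb;             // round to nearest, ties to even
    return (uint16_t)(u >> 16);
}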


// fp32 and fp16 are both IEEE 754 formats, and the hardware conversion already
// rounds: vcvt_f16_f32 rounds to nearest rather than arbitrarily truncating
// the low bits, so no manual rounding step is needed here.
static inline float16x4_t cvt_fp32_to_fp16(float32x4_t input) {
    // Converts 4 FP32 values to 4 FP16 values with rounding
    return vcvt_f16_f32(input);
}
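
// Usage sketch (hypothetical helper, not in the original file): the reduce
// kernels declared below presumably follow this widen -> accumulate -> narrow
// pattern, so that accumulation happens in fp32 and rounding happens once.
static inline float16x4_t add_fp16_via_fp32_sketch(float16x4_t a, float16x4_t b) {
    float32x4_t wide_a = cvt_fp16_to_fp32(a);
    float32x4_t wide_b = cvt_fp16_to_fp32(b);
    return cvt_fp32_to_fp16(vaddq_f32(wide_a, wide_b));
}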


// The reduce functions below use a vectorized algorithm; the number of bytes
// processed per iteration depends on the vector length. A 128-bit vector covers
// 16 bytes, sticking to NEON's fixed 128-bit registers. An illustrative sketch
// of the loop structure appears at the end of this file.

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);

void parallel_memcpy(void* to, void* from, size_t n_bytes);

#define VLOAD_U8(X) vld1q_u8((uint8_t*)(X))
#define VLOAD_U16(X) vld1_u16((uint16_t*)(X))
#define VLOAD_F16(X) vld1_f16((float16_t*)(X))
#define VLOAD_F32(X) vld1q_f32((float32_t*)(X))

#define VSTORE_U8(A, B) vst1q_u8((uint8_t*)(A), B)
#define VSTORE_U16(A, B) vst1_u16((uint16_t*)(A), B)
#define VSTORE_F16(A, B) vst1_f16((float16_t*)(A), B)  // fp16 supported from armv8.2-a+fp16
#define VSTORE_F32(A, B) vst1q_f32((float32_t*)(A), B)
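
// Minimal single-threaded sketch of the copy loop behind the parallel_memcpy
// declaration above (assumption: the real implementation also splits the range
// across threads; only the vectorized inner copy is shown here).
static inline void memcpy_neon_sketch(void* to, void* from, size_t n_bytes) {
    const size_t step = (size_t)vector_length_in_bytes;  // 16 bytes per iteration
    size_t i = 0;
    for (; i + step <= n_bytes; i += step) {
        VSTORE_U8((char*)to + i, VLOAD_U8((char*)from + i));
    }
    for (; i < n_bytes; ++i) {  // scalar tail for the trailing bytes
        ((char*)to)[i] = ((char*)from)[i];
    }
}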

#define VADD_F32(A, B) vaddq_f32(A, B)
// On fixed 128-bit NEON this expands identically to VADD_F32; the _2VL name
// mirrors a two-vector-length variant used on wider-vector targets.
#define VADD_F32_2VL(A, B) vaddq_f32(A, B)

#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X)
#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X)
#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X)
#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X)
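
// Illustrative sketch of the loop structure the reduce declarations imply
// (assumptions: num_buffers is a stand-in for however the real code learns the
// buffer count, and tail handling is simplified; not part of the original file).
static inline void reduce_fp32_buffers_sketch(int start_elements,
                                              int num_elements,
                                              int num_buffers,
                                              char* to_buffer,
                                              char** buffers) {
    const int step = full_precision_elements_in_fixed_vector;  // 4 fp32 lanes
    const int end = start_elements + num_elements;
    int i = start_elements;
    for (; i + step <= end; i += step) {
        float32x4_t acc = VLOAD_F32((float*)buffers[0] + i);
        for (int b = 1; b < num_buffers; ++b) {
            acc = VADD_F32(acc, VLOAD_F32((float*)buffers[b] + i));
        }
        VSTORE_F32((float*)to_buffer + i, acc);
    }
    for (; i < end; ++i) {  // scalar tail for the remaining elements
        float acc = ((float*)buffers[0])[i];
        for (int b = 1; b < num_buffers; ++b) { acc += ((float*)buffers[b])[i]; }
        ((float*)to_buffer)[i] = acc;
    }
}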