Commit a54558e

Enable shm_comm support for arm
1 parent d7bb6f3 commit a54558e

3 files changed: +120 -4 lines changed


csrc/cpu/comm/arm64/shm.h

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <arm_neon.h>
#include <stdint.h>
#include <stddef.h>
#include <cmath>

// 128 bits = 16 bytes -> fits 8 fp16/bf16 or 4 fp32 elements.
static int vector_length_in_bytes = 16;
// When widening fp16/bf16 -> fp32, 4 elements fit in one 128-bit register.
// Using 8 would require two 128-bit registers, so limit to 4.
static int full_precision_elements_in_fixed_vector = 4;

static inline float32x4_t cvt_bf16_to_fp32(const uint16x4_t input) {
    // Zero-extend 16-bit to 32-bit and shift left by 16 bits.
    // BF16 has the same exponent/sign bits as FP32, just missing the lower mantissa bits.
    uint32x4_t result_32 = vshll_n_u16(input, 16);
    return vreinterpretq_f32_u32(result_32);
}

static inline float32x4_t cvt_fp16_to_fp32(float16x4_t input) {
    // Converts 4 FP16 values to 4 FP32 values.
    return vcvt_f32_f16(input);
}

// Converts 4 float32 -> 4 bfloat16 with round-to-nearest-even (RNE) and NaN handling.
// When narrowing fp32 to bf16, the dropped low bits are not simply truncated;
// they are rounded to nearest even first.
static inline uint16x4_t cvt_fp32_to_bf16(float32x4_t src) {
    // Reinterpret float32 bits as uint32
    uint32x4_t u32 = vreinterpretq_u32_f32(src);

    const uint32x4_t ones = vdupq_n_u32(0x1);
    const uint32x4_t vec_bias = vdupq_n_u32(0x7FFF);  // one less than half of the dropped bits range
    const uint16x4_t nan_bf16 = vdup_n_u16(0xFFFF);

    // RNE: lsb = (input >> 16) & 1
    uint32x4_t lsb = vandq_u32(vshrq_n_u32(u32, 16), ones);

    // rounding_bias = 0x7FFF + lsb, lsb can be 0 or 1.
    uint32x4_t bias = vaddq_u32(vec_bias, lsb);

    // input += rounding_bias
    u32 = vaddq_u32(u32, bias);

    // >> 16 to get bfloat16
    // vshrq_n_u32 - keeps 32-bit width after shift
    // vshrn_n_u32 - narrows to 16-bit width after shift
    uint16x4_t bf16 = vshrn_n_u32(u32, 16);

    // vmvnq_u32 is bitwise NOT
    // NaN mask: ~(src == src) -> all-ones if NaN, all-zeros for normal numbers
    uint32x4_t isnan = vmvnq_u32(vceqq_f32(src, src));

    // Narrow each 32-bit mask lane to 16 bits so it lines up with the bf16 lanes,
    // then select nan_bf16 for NaN lanes.
    uint16x4_t mask = vmovn_u32(isnan);
    return vbsl_u16(mask, nan_bf16, bf16);
}

// fp32 and fp16 are both IEEE formats.
// The fp32 -> fp16 conversion is handled by vcvt_f16_f32, which rounds to nearest
// rather than arbitrarily truncating the low bits, so no manual rounding is needed.
static inline float16x4_t cvt_fp32_to_fp16(float32x4_t input) {
    // Converts 4 FP32 values to 4 FP16 values with rounding.
    return vcvt_f16_f32(input);
}

// The reduce functions below use a vectorized algorithm; the number of bytes processed per
// iteration depends on the vector length. A 128-bit vector covers 16 bytes; we stick to NEON 128-bit vectors.

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);

void parallel_memcpy(void* to, void* from, size_t n_bytes);

#define VLOAD_U8(X) vld1q_u8((uint8_t*)(X))
#define VLOAD_U16(X) vld1_u16((uint16_t*)(X))
#define VLOAD_F16(X) vld1_f16((float16_t*)(X))
#define VLOAD_F32(X) vld1q_f32((float32_t*)(X))

#define VSTORE_U8(A, B) vst1q_u8((uint8_t*)(A), B)
#define VSTORE_U16(A, B) vst1_u16((uint16_t*)(A), B)
#define VSTORE_F16(A, B) vst1_f16((float16_t*)(A), B)  // fp16 supported from armv8.2-a+fp16
#define VSTORE_F32(A, B) vst1q_f32((float32_t*)(A), B)

#define VADD_F32(A, B) vaddq_f32(A, B)
#define VADD_F32_2VL(A, B) vaddq_f32(A, B)

#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X)
#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X)
#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X)
#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X)
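
For reference, a scalar sketch of the same fp32 -> bf16 round-to-nearest-even trick that cvt_fp32_to_bf16 vectorizes; this is an illustration only, and the helper name fp32_to_bf16_scalar is hypothetical, not code from this commit.

#include <cstdint>
#include <cstring>

// Scalar illustration of the RNE bias trick: add 0x7FFF plus the lowest surviving bit,
// then truncate. Ties round to the even bf16 value; NaN maps to 0xFFFF like the NEON path.
static inline uint16_t fp32_to_bf16_scalar(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));  // reinterpret the float bits
    if (f != f) return 0xFFFF;       // NaN
    uint32_t lsb = (u >> 16) & 1;    // lowest bit kept after truncation
    u += 0x7FFF + lsb;               // rounding bias
    return (uint16_t)(u >> 16);      // keep sign, exponent and top 7 mantissa bits
}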
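
These macro names mirror the x86_64 and riscv64 headers so the shared reduce loops in shm.cpp can stay architecture-agnostic. As a rough sketch of how they compose, one widening bf16 add step over two buffers might look like the following; the function name reduce_bf16_step and the two-buffer shape are illustrative assumptions, not code from this commit.

// Hypothetical single step of a bf16 reduction over two buffers, built from the macros above
// (assumes AArch64 with this header included; illustration only).
static inline void reduce_bf16_step(char* dst, const char* a, const char* b) {
    float32x4_t va = CVT_BF16_TO_FP32(VLOAD_U16(a));  // widen 4 bf16 lanes to fp32
    float32x4_t vb = CVT_BF16_TO_FP32(VLOAD_U16(b));
    float32x4_t sum = VADD_F32(va, vb);               // accumulate in fp32
    VSTORE_U16(dst, CVT_FP32_TO_BF16(sum));           // narrow back to bf16 with RNE rounding
}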

csrc/cpu/comm/shm.cpp

Lines changed: 12 additions & 0 deletions
@@ -14,6 +14,9 @@
 #if defined(__riscv)
 #define TARGET_RISCV 1
 #include "riscv64/shm.h"
+#elif defined(__aarch64__)
+#define TARGET_ARM 1
+#include "arm64/shm.h"
 #else
 #include "x86_64/shm.h"
 #endif
@@ -154,6 +157,9 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
 #if TARGET_RISCV
     size_t vl = __riscv_vsetvl_e16m1(num_elements);
     vector_length_in_bytes = vl * element_size;
+#elif TARGET_ARM
+    const int vl = full_precision_elements_in_fixed_vector;
+    vector_length_in_bytes = vl * element_size;
 #else
     const int vl = vector_length_in_bytes / element_size;
 #endif
@@ -214,6 +220,9 @@ void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer,
 #if TARGET_RISCV
     size_t vl = __riscv_vsetvl_e16m1(num_elements);
     vector_length_in_bytes = vl * element_size;
+#elif TARGET_ARM
+    const int vl = full_precision_elements_in_fixed_vector;
+    vector_length_in_bytes = vl * element_size;
 #else
     const int vl = vector_length_in_bytes / element_size;
 #endif
@@ -274,6 +283,9 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
 #if TARGET_RISCV
     size_t vl = __riscv_vsetvl_e32m1(num_elements);
     vector_length_in_bytes = vl * element_size;
+#elif TARGET_ARM
+    const int vl = full_precision_elements_in_fixed_vector;
+    vector_length_in_bytes = vl * element_size;
 #else
     const int vl = vector_length_in_bytes / element_size;
 #endif

tests/unit/comm/test_dist.py

Lines changed: 6 additions & 4 deletions
@@ -120,9 +120,10 @@ class TestDistAllReduce(DistributedTest):
     world_size = [1]
 
     def test(self):
-        x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
+        num_elements = 128
+        x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
         sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
-        result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
+        result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks
         dist.all_reduce(x)
         assert torch.all(x == result)

@@ -138,9 +139,10 @@ class TestDistInferenceAllReduce(DistributedTest):
     world_size = [1]
 
     def test(self, dtype):
-        x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
+        num_elements = 128
+        x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
         sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
-        result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
+        result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks
         result = result.to(dtype)
         x = x.to(dtype)
         dist.inference_all_reduce(x)
