Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions projects/rccl/src/collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
NVTX3_FUNC_WITH_PARAMS(AllGather, NcclNvtxParamsAllGather,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcount * ncclTypeSize(datatype), datatype));

// RCCL update slice steps for AllGather if single node
const bool isGfx950 = IsArchMatch(comm->archName, "gfx950");
int chunkSteps = (isGfx950 && comm->rcclUseOneSlice)? 1 : ALLGATHER_CHUNKSTEPS;
int sliceSteps = comm->rcclUseOneSlice
? (isGfx950 ? 1 : ALLGATHER_SLICESTEPS_SINGLE_NODE)
: ALLGATHER_SLICESTEPS;
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS, nullptr };

chunkSteps, sliceSteps, nullptr };
int nRanks, rank;
int in_place = 0;
const void* srcBuf;
Expand Down Expand Up @@ -388,10 +392,16 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
NVTX3_FUNC_WITH_PARAMS(ReduceScatter, NcclNvtxParamsReduceScatter,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, recvcount * ncclTypeSize(datatype), op, datatype));
// RCCL update slice steps for ReduceScatter if single node
const bool isGfx950 = IsArchMatch(comm->archName, "gfx950");
int chunkSteps = (isGfx950 && comm->rcclUseOneSlice)? 1 : REDUCESCATTER_CHUNKSTEPS;
int sliceSteps = comm->rcclUseOneSlice
? (isGfx950 ? 1 : REDUCESCATTER_SLICESTEPS_SINGLE_NODE)
: REDUCESCATTER_SLICESTEPS;

struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS, nullptr };
chunkSteps, sliceSteps, nullptr };

int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
Expand All @@ -409,7 +419,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
}

// Reset value forcing direct reduce scatter algorithm
comm->enableDirectReduceScatter = 0;
comm->enableDirectReduceScatter = 0;

if (rcclUseReduceScatterDirect(comm, msgSize)) {
INFO(NCCL_INIT, "RCCL DIRECT REDUCE-SCATTER recvcount=%zu msgSize=%zu rank=%d nRanks=%d nNodes=%d comm=%p stream=%p sendbuff=%p recvbuff=%p",
Expand Down
9 changes: 8 additions & 1 deletion projects/rccl/src/device/all_gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,20 @@ namespace {
}
}

#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#if defined(__gfx942__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS_SINGLE_NODE, ALLGATHER_SLICESTEPS_SINGLE_NODE>, false>(tid, nthreads, work); \
} else{ \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work); \
}
#elif defined(__gfx950__)
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
runRing<T, RedOp, ProtoSimple<1,1>, false>(tid, nthreads, work); \
} else{ \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work); \
}
#else
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work);
Expand Down
11 changes: 10 additions & 1 deletion projects/rccl/src/device/reduce_scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ namespace {
}
}

#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#if defined(__gfx942__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS_SINGLE_NODE, REDUCESCATTER_SLICESTEPS_SINGLE_NODE>; \
Expand All @@ -186,6 +186,15 @@ namespace {
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
}
#elif defined(__gfx950__)
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
using Proto = ProtoSimple<1,1>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
} else{ \
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
}
#else
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
Expand Down
2 changes: 1 addition & 1 deletion projects/rccl/src/enqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2176,7 +2176,7 @@ static ncclResult_t topoGetAlgoInfo(
nc /= comm->warpSpeedChannelMultiplier;
// Temporary check as we reduce CU usage for all collectives
// TODO: Remove this condition after optimizing all collectives
if(IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && comm->nRanks == 8 && info->func != ncclFuncAllReduce && ncclParamMaxNchannels() < 0) {
if(IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && comm->nRanks == 8 && info->func != ncclFuncAllReduce && info->func != ncclFuncAllGather && info->func != ncclFuncReduceScatter && ncclParamMaxNchannels() < 0) {
nc *= 2;
}
}
Expand Down
36 changes: 26 additions & 10 deletions projects/rccl/src/rccl_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", -1);
RCCL_PARAM(WarpSpeedCuCount, "WARP_SPEED_CU_COUNT", 0);
RCCL_PARAM(WarpSpeedAutoMode, "WARP_SPEED_AUTO", 1);
RCCL_PARAM(WarpSpeedForceEnable, "WARP_SPEED_FORCE_ENABLE", 0);
RCCL_PARAM(WarpSpeedAGThreshold, "WARP_SPEED_AG_THRESHOLD", 134217728); // 128 MB for AllGather
RCCL_PARAM(WarpSpeedRSThreshold, "WARP_SPEED_RS_THRESHOLD", 2147483648); // 2 GB for ReduceScatter
RCCL_PARAM(WarpSpeedARThreshold, "WARP_SPEED_AR_THRESHOLD", 67108864); // 64 MB for AllReduce
#endif
#define RCCL_WARP_SPEED_MIN_BYTES (1ULL << 26) // 64 MB


void rcclRestrictMaxChannels(struct ncclComm* comm, int& nc ) {

Expand Down Expand Up @@ -124,7 +125,7 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s
if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncAllGather) && sizePerRank <= 88448) {
// Change LL protocol threshold
info->protocol = NCCL_PROTO_LL;
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 1048576) {
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 131072) {
// Change LL protocol threshold
info->protocol = NCCL_PROTO_LL;
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 352128) {
Expand Down Expand Up @@ -607,6 +608,25 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe
cuCount = (support == 0)? nChannels : nChannels / warpsPerBlock + ((nChannels % warpsPerBlock) != 0 ? 1 : 0); // each CU can handle warpsPerBlock
}

// Decide whether WarpSpeed may be enabled for this collective task based on
// per-collective message-size thresholds (RCCL_WARP_SPEED_{AR,AG,RS}_THRESHOLD).
//
// The thresholds are tuned only for the single-node, 8-rank (full
// subscription) topology; any other topology returns true so WarpSpeed
// selection is not restricted there.
//
// comm   - communicator; only nRanks/nNodes are read to detect the tuned topology
// info   - collective task; only info->func is inspected
// nBytes - total message size in bytes for this collective
//
// Returns true when nBytes meets the threshold for the collective type (or
// the topology is untuned); returns false — with an NCCL_TUNING log line —
// when the message is below its threshold, or the collective is not one of
// AllReduce/AllGather/ReduceScatter.
bool rcclIsAboveWarpSpeedThreshold (struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes){
  // Thresholds are currently set for single node, 8 ranks (full subscription)
  if(comm->nRanks != 8 || comm->nNodes != 1){
    return true;
  }
  // Single-node, full-subscription thresholds for AllReduce, AllGather and ReduceScatter
  if(info->func == ncclFuncAllReduce && nBytes >= rcclParamWarpSpeedARThreshold()) {
    return true;
  }
  else if(info->func == ncclFuncAllGather && nBytes >= rcclParamWarpSpeedAGThreshold()) {
    return true;
  }
  else if(info->func == ncclFuncReduceScatter && nBytes >= rcclParamWarpSpeedRSThreshold()) {
    return true;
  }
  // Fixed grammar in the log message ("as it below" -> "as it is below").
  INFO(NCCL_TUNING, "RCCL WarpSpeed not enabled for %s at %zu bytes as it is below the warpSpeed threshold", ncclFuncToString(info->func), nBytes);
  return false;
}

// WarpSpeed auto mode is only available on single-node gfx950 systems and
// only while the WARP_SPEED_AUTO knob is non-zero. Guard clauses preserve
// the original short-circuit evaluation order (arch check first, then node
// count, then the tunable).
bool rcclCanUseWarpSpeedAuto(struct ncclComm* comm, int nNodes) {
  if (!IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) {
    return false;
  }
  if (nNodes != 1) {
    return false;
  }
  return rcclParamWarpSpeedAutoMode() != 0;
}
Expand All @@ -626,23 +646,19 @@ void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size
if(!unrollFactorSet) comm->unroll = NCCL_UNROLL_2;
info->useWarpSpeed = true;
} else if(rcclCanUseWarpSpeedAuto(comm, comm->nNodes)) { // Auto performance mode
size_t minBytes = 0;
// No early return based on the algorithm at the start of the function
// to allow unroll factor to be reverted to default.
// This can be changed once per-task unroll factor setting is implemented.
if(info->algorithm != NCCL_ALGO_RING) {
return; // If Ring is not selected, assume it is suboptimal and return
}
if(info->func == ncclFuncAllReduce) {
if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather || info->func == ncclFuncReduceScatter) {
// allReduce now benefits from unroll factor of 2 in all modes due to changing its slicing strategy
// TODO: Remove unroll update when all collectives are optimized
if(!unrollFactorSet) comm->unroll = NCCL_UNROLL_2;
minBytes = RCCL_WARP_SPEED_MIN_BYTES;
}
// temporarily disabling WarpSpeed for AllGather and ReduceScatter in auto mode
// if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES;
// else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed
if(nBytes >= minBytes && minBytes > 0) {
if(rcclIsAboveWarpSpeedThreshold(comm, info, nBytes))
{
info->nWarps = 4;
info->useWarpSpeed = true;
}
Expand Down
Loading