
Commit 1bde971

[Reranker Feature] Support Loop with multi-commands
Discussed-in: Merge-Request 25729546, URL: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25729546 GitOrigin-RevId: 7dd154bd8cfa99854201b9222d043f4cc0509975 ORIGINAL_AUTHOR=MNNSyncBot <hi@zhaode.wang>
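For context, the change generalizes the Metal Loop executor: instead of handling only single-command loops, the creator now builds a MetalLoop wrapper that instantiates one sub-execution per RegionCommand (gather, batch matmul, or binary broadcast), each remembering its own command index, and reports itself invalid when any command is unsupported so the creator can return nullptr. Below is a minimal, self-contained C++ sketch of that dispatch pattern; the names Command, SubExecution, GatherExecution, MatMulExecution, and LoopRunner are hypothetical simplifications for illustration, not the real MNN classes (those appear in the diff further down).

// Simplified sketch of the "one sub-execution per loop command" pattern.
// All type names here are hypothetical; see MetalLoop.mm for the real classes.
#include <cstdio>
#include <memory>
#include <vector>

enum class CommandType { Gather, MatMul, Binary, Unsupported };

struct Command {
    CommandType type;
};

// Each sub-execution handles exactly one command, identified by its index.
struct SubExecution {
    explicit SubExecution(int index) : cmdIndex(index) {}
    virtual ~SubExecution() = default;
    virtual void encode() const = 0;
    int cmdIndex;
};

struct GatherExecution : SubExecution {
    using SubExecution::SubExecution;
    void encode() const override { std::printf("gather, command %d\n", cmdIndex); }
};

struct MatMulExecution : SubExecution {
    using SubExecution::SubExecution;
    void encode() const override { std::printf("matmul, command %d\n", cmdIndex); }
};

// The wrapper builds one sub-execution per command; if any command type is
// unsupported it clears the list, so the caller can fall back to another path.
struct LoopRunner {
    explicit LoopRunner(const std::vector<Command>& commands) {
        for (int i = 0; i < (int)commands.size(); ++i) {
            switch (commands[i].type) {
                case CommandType::Gather:
                    mExecutions.push_back(std::make_shared<GatherExecution>(i));
                    break;
                case CommandType::MatMul:
                    mExecutions.push_back(std::make_shared<MatMulExecution>(i));
                    break;
                default:
                    mExecutions.clear();  // unsupported command: whole loop invalid
                    return;
            }
        }
    }
    bool isValid() const { return !mExecutions.empty(); }
    void encode() const {
        // Forward to every sub-execution, in command order.
        for (const auto& exe : mExecutions) {
            exe->encode();
        }
    }
    std::vector<std::shared_ptr<SubExecution>> mExecutions;
};

int main() {
    LoopRunner runner({{CommandType::Gather}, {CommandType::MatMul}});
    if (runner.isValid()) {
        runner.encode();
    }
    return 0;
}

In the real code the existing single-command fast paths in MetalLoopCreator are kept, and the new MetalLoop class is only used as the general case, as the diff below shows.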
1 parent 7ef8d15 commit 1bde971

File tree

1 file changed: +163 -89 lines changed


source/backend/metal/MetalLoop.mm

Lines changed: 163 additions & 89 deletions
@@ -162,9 +162,10 @@ static void _setTensorStack(std::vector<Tensor*>& result, const std::vector<Tens
     id<MTLComputePipelineState> mPipeline;
     std::vector<Tensor*> mTensors;
     bool mHasBias = false;
+    int mCmdIndex = 0;
 
 public:
-    MetalBatchMatMul(const LoopParam* loop, Backend *bn) : MetalExecution(bn) {
+    MetalBatchMatMul(const LoopParam* loop, Backend *bn, int index = 0) : MetalExecution(bn) {
         mLoop = loop;
         auto mtbn = static_cast<MetalBackend *>(bn);
         auto context = (__bridge MNNMetalContext *)mtbn->context();
@@ -180,7 +181,7 @@ static void _setTensorStack(std::vector<Tensor*>& result, const std::vector<Tens
             std::string([T UTF8String]),
             "matmulunit"
         };
-        auto cmd = loop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = loop->commands()->GetAs<RegionCommand>(index);
         mHasBias = cmd->indexes()->size() > 3;
         if (mHasBias) {
             keys.emplace_back("BIAS");
@@ -200,17 +201,19 @@ static void _setTensorStack(std::vector<Tensor*>& result, const std::vector<Tens
             }
             pipeline = mtbn->makeComputePipelineWithSourceOption(gMatMulUnitTemplate, "loop_matmul", compileOptions);
             mtbn->runtime()->insertPipeline(keys, pipeline);
+
         }
         if (nil == pipeline) {
             MNN_ERROR("Create batch matmul pipeline error\n");
         }
         mPipeline = pipeline;
         mTensors.resize(mLoop->tensorNumber());
+        mCmdIndex = index;
     }
     virtual ~MetalBatchMatMul() = default;
     virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
         _setTensorStack(mTensors, inputs, outputs, mLoop);
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto AStride = cmd->view()->GetAs<View>(1)->stride()->data();
         auto BStride = cmd->view()->GetAs<View>(2)->stride()->data();
@@ -234,7 +237,7 @@ virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vecto
         return NO_ERROR;
     }
     virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto AStride = cmd->view()->GetAs<View>(1)->stride()->data();
         auto BStride = cmd->view()->GetAs<View>(2)->stride()->data();
@@ -252,6 +255,8 @@ virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Ten
                 MetalBackend::setTensor(inputs[0], encoder, cmd->indexes()->size() + i);
             }
         }
+        // printf("loop_matmul out dequant BMNK: %d %d %d %d\n", mLoop->loopNumber(), size[0], size[2], size[1]);
+
         [encoder setBuffer:mParam offset:0 atIndex:cmd->indexes()->size() * 2];
         [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(totalSize, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
     }
@@ -394,23 +399,21 @@ kernel void set_copy(device T *out [[buffer(0)]],
 class MetalGather : public MetalExecution {
 private:
     const LoopParam* mLoop;
-    bool mNeedInit = false;
-    std::pair<MTLSize, MTLSize> mInitThreads;
+    int mCmdIndex = 0;
     id<MTLBuffer> mParam;
     id<MTLComputePipelineState> mPipeline;
-    id<MTLComputePipelineState> mInitPipeline;
-    id<MTLBuffer> mInitParam;
     std::vector<Tensor*> mTensors;
 public:
-    MetalGather(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) : MetalExecution(bn) {
+    MetalGather(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, int index = 0) : MetalExecution(bn) {
         mLoop = loop;
         auto mtbn = static_cast<MetalBackend *>(bn);
         auto context = (__bridge MNNMetalContext *)mtbn->context();
 
         mParam = [context newDeviceBuffer:sizeof(GatherInfo) access:CPUWriteOnly];
         bool useFp16 = mtbn->useFp16InsteadFp32();
         mTensors.resize(mLoop->tensorNumber());
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        mCmdIndex = index;
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         _setTensorStack(mTensors, inputs, outputs, mLoop);
         auto dstTensor = mTensors[cmd->indexes()->data()[0]];
 
@@ -435,40 +438,10 @@ kernel void set_copy(device T *out [[buffer(0)]],
             }
             mPipeline = pipeline;
         }
-
-        // scatter need init command pipeline
-        if(mLoop->initCommand() != nullptr){
-            mNeedInit = true;
-            std::string shader = "set_copy";
-            auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
-            if (cmd->op() == nullptr){
-                shader = "set_zero";
-            } else {
-                mInitParam = [context newDeviceBuffer:sizeof(InitInfo) access:CPUWriteOnly];
-            }
-            std::vector<std::string> keys = {
-                std::string([T UTF8String]),
-                "init_region",
-                shader
-            };
-            auto pipeline = mtbn->runtime()->findPipeline(keys);
-            if (nil == pipeline) {
-                MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
-                compileOptions.preprocessorMacros = @{
-                    @"T" : T,
-                };
-                pipeline = mtbn->makeComputePipelineWithSourceOption(gInitRegion, shader.c_str(), compileOptions);
-                mtbn->runtime()->insertPipeline(keys, pipeline);
-            }
-            if (nil == pipeline) {
-                MNN_ERROR("Create gather init pipeline error\n");
-            }
-            mInitPipeline = pipeline;
-        }
     }
     virtual ~MetalGather() = default;
     virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         _setTensorStack(mTensors, inputs, outputs, mLoop);
 
         auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
@@ -504,51 +477,11 @@ virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vecto
         param->totalSize[0] = inputSize;
         param->totalSize[1] = outputSize;
 
-        if(mNeedInit) {
-            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
-            auto data = reinterpret_cast<InitInfo*>([mInitParam contents]);
-
-            auto srcStride = initCmd->view()->GetAs<View>(1)->stride()->data();
-            auto dstStride = initCmd->view()->GetAs<View>(0)->stride()->data();
-            auto dataSize = initCmd->size()->data();
-            for (int i = 0; i < 3; ++i) {
-                data->srcStride[i] = srcStride[i];
-                data->dstStride[i] = dstStride[i];
-                data->size[i] = dataSize[i];
-            }
-
-            auto initDstTensor = mTensors[initCmd->indexes()->data()[0]];
-            auto initSrcTensor = mTensors[initCmd->indexes()->data()[1]];
-            auto initInputSize = initSrcTensor->usize() / initSrcTensor->buffer().type.bytes();
-            auto initOutputSize = initDstTensor->usize() / initDstTensor->buffer().type.bytes();
-            data->totalSize[0] = initInputSize;
-            data->totalSize[1] = initOutputSize;
-
-            auto backend = static_cast<MetalBackend *>(this->backend());
-            auto context = (__bridge MNNMetalContext *)backend->context();
-            mInitThreads = [context computeBestGroupAndLocal:mInitPipeline threads:MTLSizeMake(data->size[0], data->size[1], data->size[2])];
-        }
         return NO_ERROR;
     }
     virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs,
                           id<MTLComputeCommandEncoder> encoder) override {
-        if(mNeedInit) {
-            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
-            int x = initCmd->size()->data()[0];
-            int y = initCmd->size()->data()[1];
-            int z = initCmd->size()->data()[2];
-
-            [encoder setComputePipelineState:mInitPipeline];
-            auto dstTensor = mTensors[initCmd->indexes()->data()[0]];
-            auto srcTensor = mTensors[initCmd->indexes()->data()[1]];
-            MetalBackend::setTensor(dstTensor, encoder, 0);
-            MetalBackend::setTensor(srcTensor, encoder, 1);
-            [encoder setBuffer:mInitParam offset:0 atIndex:2];
-
-            [encoder dispatchThreadgroups:mInitThreads.first threadsPerThreadgroup:mInitThreads.second];
-        }
-
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
         auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
@@ -631,13 +564,14 @@ kernel void loop_binary(device T1* uOutput [[buffer(0)]], const device T0* uInpu
 
 class MetalBinaryBroadCast : public MetalExecution {
 public:
-    MetalBinaryBroadCast(const LoopParam* loop, Backend *bn, std::vector<Tensor*>&& tensors, NSString* CUSTOM) : MetalExecution(bn) {
+    MetalBinaryBroadCast(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& tensors, NSString* CUSTOM, int index = 0) : MetalExecution(bn) {
         mLoop = loop;
         auto mtbn = static_cast<MetalBackend *>(bn);
         auto context = (__bridge MNNMetalContext *)mtbn->context();
         mParam = mtbn->getConstBuffer(sizeof(BinaryBroadCastInfo));
-        mTensors = std::move(tensors);
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        mTensors = tensors;
+        mCmdIndex = index;
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto dstTensor = mTensors[cmd->indexes()->data()[0]];
         auto srcTensor = mTensors[cmd->indexes()->data()[1]];
         auto srcTensor1 = mTensors[cmd->indexes()->data()[2]];
@@ -672,7 +606,7 @@ kernel void loop_binary(device T1* uOutput [[buffer(0)]], const device T0* uInpu
     }
     virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
         _setTensorStack(mTensors, inputs, outputs, mLoop);
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto srcStride0 = cmd->view()->GetAs<View>(1)->stride()->data();
         auto srcStride1 = cmd->view()->GetAs<View>(2)->stride()->data();
@@ -694,7 +628,7 @@ virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vecto
     virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs,
                           id<MTLComputeCommandEncoder> encoder) override {
 
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto dstTensor = mTensors[cmd->indexes()->data()[0]];
         auto srcTensor = mTensors[cmd->indexes()->data()[1]];
         auto srcTensor1 = mTensors[cmd->indexes()->data()[2]];
@@ -711,6 +645,139 @@ virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Ten
     id<MTLBuffer> mParam;
     std::vector<Tensor*> mTensors;
     int mTotalSize;
+    int mCmdIndex = 0;
+};
+
+class MetalLoop : public MetalExecution {
+public:
+    MetalLoop(const LoopParam* loop, Backend *bn, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) : MetalExecution(bn) {
+        mLoop = loop;
+        auto mtbn = static_cast<MetalBackend *>(bn);
+        auto context = (__bridge MNNMetalContext *)mtbn->context();
+        mTensors.resize(mLoop->tensorNumber());
+        _setTensorStack(mTensors, inputs, outputs, mLoop);
+
+        // Init
+        if(mLoop->initCommand() != nullptr) {
+            mNeedInit = true;
+            std::string shader = "set_copy";
+            auto dstTensor = mTensors[mLoop->initCommand()->GetAs<RegionCommand>(0)->indexes()->data()[0]];
+            NSString* T = MetalCast::getScalarType(dstTensor->getType(), mtbn->useFp16InsteadFp32());
+
+            auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
+            if (cmd->op() == nullptr){
+                shader = "set_zero";
+            } else {
+                mInitParam = [context newDeviceBuffer:sizeof(InitInfo) access:CPUWriteOnly];
+            }
+            std::vector<std::string> keys = {
+                std::string([T UTF8String]),
+                "init_region",
+                shader
+            };
+            auto pipeline = mtbn->runtime()->findPipeline(keys);
+            if (nil == pipeline) {
+                MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
+                compileOptions.preprocessorMacros = @{
+                    @"T" : T,
+                };
+                pipeline = mtbn->makeComputePipelineWithSourceOption(gInitRegion, shader.c_str(), compileOptions);
+                mtbn->runtime()->insertPipeline(keys, pipeline);
+            }
+            if (nil == pipeline) {
+                MNN_ERROR("Create gather init pipeline error\n");
+            }
+            mInitPipeline = pipeline;
+        }
+
+        bool valid = true;
+        for (int i=0; i<loop->commands()->size(); ++i) {
+            auto cmd = loop->commands()->GetAs<RegionCommand>(i);
+            auto subop = cmd->op();
+            if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) {
+                mExecutions.emplace_back(std::make_shared<MetalGather>(loop, bn, inputs, outputs, i));
+            } else if (OpType_MatMul == subop->type() && loop->parallel()) {
+                mExecutions.emplace_back(std::make_shared<MetalBatchMatMul>(loop, bn, i));
+            } else if (OpType_BinaryOp == subop->type() && cmd->fuse() < 0 && 1 == loop->loopNumber()) {
+                mExecutions.emplace_back(std::make_shared<MetalBinaryBroadCast>(loop, bn, mTensors, MetalBinary::convert(cmd->op()->main_as_BinaryOp()->opType(), mTensors[cmd->indexes()->data()[1]]->getType().code == halide_type_float), i));
+            } else {
+                valid = false;
+                break;
+            }
+        }
+        if (!valid) {
+            mExecutions.clear();
+        }
+    }
+    virtual ~MetalLoop() = default;
+
+    virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
+        // Init
+        if(mNeedInit) {
+            _setTensorStack(mTensors, inputs, outputs, mLoop);
+            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
+            auto data = reinterpret_cast<InitInfo*>([mInitParam contents]);
+
+            auto srcStride = initCmd->view()->GetAs<View>(1)->stride()->data();
+            auto dstStride = initCmd->view()->GetAs<View>(0)->stride()->data();
+            auto dataSize = initCmd->size()->data();
+            for (int i = 0; i < 3; ++i) {
+                data->srcStride[i] = srcStride[i];
+                data->dstStride[i] = dstStride[i];
+                data->size[i] = dataSize[i];
+            }
+
+            auto initDstTensor = mTensors[initCmd->indexes()->data()[0]];
+            auto initSrcTensor = mTensors[initCmd->indexes()->data()[1]];
+            auto initInputSize = initSrcTensor->usize() / initSrcTensor->buffer().type.bytes();
+            auto initOutputSize = initDstTensor->usize() / initDstTensor->buffer().type.bytes();
+            data->totalSize[0] = initInputSize;
+            data->totalSize[1] = initOutputSize;
+
+            auto backend = static_cast<MetalBackend *>(this->backend());
+            auto context = (__bridge MNNMetalContext *)backend->context();
+            mInitThreads = [context computeBestGroupAndLocal:mInitPipeline threads:MTLSizeMake(data->size[0], data->size[1], data->size[2])];
+        }
+
+        // Loop commands
+        for (auto& exe : mExecutions) {
+            auto code = exe->onResize(inputs, outputs);
+            if (NO_ERROR != code) {
+                return code;
+            }
+        }
+        return NO_ERROR;
+    }
+    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
+        // Init
+        if(mNeedInit) {
+            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
+            [encoder setComputePipelineState:mInitPipeline];
+            auto dstTensor = mTensors[initCmd->indexes()->data()[0]];
+            auto srcTensor = mTensors[initCmd->indexes()->data()[1]];
+            MetalBackend::setTensor(dstTensor, encoder, 0);
+            MetalBackend::setTensor(srcTensor, encoder, 1);
+            [encoder setBuffer:mInitParam offset:0 atIndex:2];
+            [encoder dispatchThreadgroups:mInitThreads.first threadsPerThreadgroup:mInitThreads.second];
+        }
+        // Loop commands
+        for (auto& exe : mExecutions) {
+            exe->onEncode(inputs, outputs, encoder);
+        }
+    }
+    bool isValid() {
+        return !mExecutions.empty();
+    }
+private:
+    const LoopParam* mLoop;
+    std::vector<std::shared_ptr<MetalExecution>> mExecutions;
+    std::vector<Tensor*> mTensors;
+    // For Init
+    bool mNeedInit = false;
+    std::pair<MTLSize, MTLSize> mInitThreads;
+    id<MTLComputePipelineState> mInitPipeline;
+    id<MTLBuffer> mInitParam;
+
 };
 
 class MetalLoopCreator : public MetalBackend::Creator {
@@ -725,7 +792,7 @@ virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Ten
         if (1 == loop->commands()->size()) {
             auto cmd = loop->commands()->GetAs<RegionCommand>(0);
             auto subop = cmd->op();
-            if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) {
+            if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0 && nullptr == loop->initCommand()) {
                 return new MetalGather(loop, bn, inputs, outputs);
             }
             if (OpType_MatMul == subop->type() && loop->parallel() && nullptr == loop->initCommand()) {
@@ -741,10 +808,17 @@ virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Ten
                     MNN_ERROR("Metal Don't support binary - %d \n", cmd->op()->main_as_BinaryOp()->opType());
                     return nullptr;
                 }
-                return new MetalBinaryBroadCast(loop, bn, std::move(tensors), CUSTOM);
+                return new MetalBinaryBroadCast(loop, bn, tensors, CUSTOM);
             }
         }
+        // General Case
+        auto exe = new MetalLoop(loop, bn, inputs, outputs);
+        if (exe->isValid()) {
+            return exe;
+        }
+        delete exe;
         return nullptr;
+
     }
 };
 REGISTER_METAL_OP_CREATOR(MetalLoopCreator, OpType_While);
