@@ -162,9 +162,10 @@ static void _setTensorStack(std::vector<Tensor*>& result, const std::vector<Tens
     id<MTLComputePipelineState> mPipeline;
     std::vector<Tensor*> mTensors;
     bool mHasBias = false;
+    int mCmdIndex = 0;
 
 public:
-    MetalBatchMatMul(const LoopParam* loop, Backend *bn) : MetalExecution(bn) {
+    MetalBatchMatMul(const LoopParam* loop, Backend *bn, int index = 0) : MetalExecution(bn) {
         mLoop = loop;
         auto mtbn = static_cast<MetalBackend *>(bn);
         auto context = (__bridge MNNMetalContext *)mtbn->context();
@@ -180,7 +181,7 @@ static void _setTensorStack(std::vector<Tensor*>& result, const std::vector<Tens
             std::string([T UTF8String]),
             "matmulunit"
         };
-        auto cmd = loop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = loop->commands()->GetAs<RegionCommand>(index);
         mHasBias = cmd->indexes()->size() > 3;
         if (mHasBias) {
             keys.emplace_back("BIAS");
@@ -200,17 +201,19 @@ static void _setTensorStack(std::vector<Tensor*>& result, const std::vector<Tens
             }
             pipeline = mtbn->makeComputePipelineWithSourceOption(gMatMulUnitTemplate, "loop_matmul", compileOptions);
             mtbn->runtime()->insertPipeline(keys, pipeline);
+
         }
         if (nil == pipeline) {
             MNN_ERROR("Create batch matmul pipeline error\n");
         }
         mPipeline = pipeline;
         mTensors.resize(mLoop->tensorNumber());
+        mCmdIndex = index;
     }
     virtual ~MetalBatchMatMul() = default;
     virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
         _setTensorStack(mTensors, inputs, outputs, mLoop);
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto AStride = cmd->view()->GetAs<View>(1)->stride()->data();
         auto BStride = cmd->view()->GetAs<View>(2)->stride()->data();
@@ -234,7 +237,7 @@ virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vecto
         return NO_ERROR;
     }
     virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto AStride = cmd->view()->GetAs<View>(1)->stride()->data();
         auto BStride = cmd->view()->GetAs<View>(2)->stride()->data();
@@ -252,6 +255,8 @@ virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Ten
                 MetalBackend::setTensor(inputs[0], encoder, cmd->indexes()->size() + i);
             }
         }
+        // printf("loop_matmul out dequant BMNK: %d %d %d %d\n", mLoop->loopNumber(), size[0], size[2], size[1]);
+
         [encoder setBuffer:mParam offset:0 atIndex:cmd->indexes()->size() * 2];
         [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(totalSize, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
     }
@@ -394,23 +399,21 @@ kernel void set_copy(device T *out [[buffer(0)]],
 class MetalGather : public MetalExecution {
 private:
     const LoopParam* mLoop;
-    bool mNeedInit = false;
-    std::pair<MTLSize, MTLSize> mInitThreads;
+    int mCmdIndex = 0;
     id<MTLBuffer> mParam;
     id<MTLComputePipelineState> mPipeline;
-    id<MTLComputePipelineState> mInitPipeline;
-    id<MTLBuffer> mInitParam;
     std::vector<Tensor*> mTensors;
 public:
-    MetalGather(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) : MetalExecution(bn) {
+    MetalGather(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, int index = 0) : MetalExecution(bn) {
         mLoop = loop;
         auto mtbn = static_cast<MetalBackend *>(bn);
         auto context = (__bridge MNNMetalContext *)mtbn->context();
 
         mParam = [context newDeviceBuffer:sizeof(GatherInfo) access:CPUWriteOnly];
         bool useFp16 = mtbn->useFp16InsteadFp32();
         mTensors.resize(mLoop->tensorNumber());
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        mCmdIndex = index;
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         _setTensorStack(mTensors, inputs, outputs, mLoop);
         auto dstTensor = mTensors[cmd->indexes()->data()[0]];
 
@@ -435,40 +438,10 @@ kernel void set_copy(device T *out [[buffer(0)]],
             }
             mPipeline = pipeline;
         }
-
-        // scatter need init command pipeline
-        if (mLoop->initCommand() != nullptr) {
-            mNeedInit = true;
-            std::string shader = "set_copy";
-            auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
-            if (cmd->op() == nullptr) {
-                shader = "set_zero";
-            } else {
-                mInitParam = [context newDeviceBuffer:sizeof(InitInfo) access:CPUWriteOnly];
-            }
-            std::vector<std::string> keys = {
-                std::string([T UTF8String]),
-                "init_region",
-                shader
-            };
-            auto pipeline = mtbn->runtime()->findPipeline(keys);
-            if (nil == pipeline) {
-                MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
-                compileOptions.preprocessorMacros = @{
-                    @"T" : T,
-                };
-                pipeline = mtbn->makeComputePipelineWithSourceOption(gInitRegion, shader.c_str(), compileOptions);
-                mtbn->runtime()->insertPipeline(keys, pipeline);
-            }
-            if (nil == pipeline) {
-                MNN_ERROR("Create gather init pipeline error\n");
-            }
-            mInitPipeline = pipeline;
-        }
     }
     virtual ~MetalGather() = default;
     virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         _setTensorStack(mTensors, inputs, outputs, mLoop);
 
         auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
@@ -504,51 +477,11 @@ virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vecto
         param->totalSize[0] = inputSize;
         param->totalSize[1] = outputSize;
 
-        if (mNeedInit) {
-            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
-            auto data = reinterpret_cast<InitInfo*>([mInitParam contents]);
-
-            auto srcStride = initCmd->view()->GetAs<View>(1)->stride()->data();
-            auto dstStride = initCmd->view()->GetAs<View>(0)->stride()->data();
-            auto dataSize = initCmd->size()->data();
-            for (int i = 0; i < 3; ++i) {
-                data->srcStride[i] = srcStride[i];
-                data->dstStride[i] = dstStride[i];
-                data->size[i] = dataSize[i];
-            }
-
-            auto initDstTensor = mTensors[initCmd->indexes()->data()[0]];
-            auto initSrcTensor = mTensors[initCmd->indexes()->data()[1]];
-            auto initInputSize = initSrcTensor->usize() / initSrcTensor->buffer().type.bytes();
-            auto initOutputSize = initDstTensor->usize() / initDstTensor->buffer().type.bytes();
-            data->totalSize[0] = initInputSize;
-            data->totalSize[1] = initOutputSize;
-
-            auto backend = static_cast<MetalBackend *>(this->backend());
-            auto context = (__bridge MNNMetalContext *)backend->context();
-            mInitThreads = [context computeBestGroupAndLocal:mInitPipeline threads:MTLSizeMake(data->size[0], data->size[1], data->size[2])];
-        }
         return NO_ERROR;
     }
     virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs,
                           id<MTLComputeCommandEncoder> encoder) override {
-        if (mNeedInit) {
-            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
-            int x = initCmd->size()->data()[0];
-            int y = initCmd->size()->data()[1];
-            int z = initCmd->size()->data()[2];
-
-            [encoder setComputePipelineState:mInitPipeline];
-            auto dstTensor = mTensors[initCmd->indexes()->data()[0]];
-            auto srcTensor = mTensors[initCmd->indexes()->data()[1]];
-            MetalBackend::setTensor(dstTensor, encoder, 0);
-            MetalBackend::setTensor(srcTensor, encoder, 1);
-            [encoder setBuffer:mInitParam offset:0 atIndex:2];
-
-            [encoder dispatchThreadgroups:mInitThreads.first threadsPerThreadgroup:mInitThreads.second];
-        }
-
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
         auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
@@ -631,13 +564,14 @@ kernel void loop_binary(device T1* uOutput [[buffer(0)]], const device T0* uInpu
 
 class MetalBinaryBroadCast : public MetalExecution {
 public:
-    MetalBinaryBroadCast(const LoopParam* loop, Backend *bn, std::vector<Tensor*>&& tensors, NSString* CUSTOM) : MetalExecution(bn) {
+    MetalBinaryBroadCast(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& tensors, NSString* CUSTOM, int index = 0) : MetalExecution(bn) {
         mLoop = loop;
         auto mtbn = static_cast<MetalBackend *>(bn);
         auto context = (__bridge MNNMetalContext *)mtbn->context();
         mParam = mtbn->getConstBuffer(sizeof(BinaryBroadCastInfo));
-        mTensors = std::move(tensors);
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        mTensors = tensors;
+        mCmdIndex = index;
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto dstTensor = mTensors[cmd->indexes()->data()[0]];
         auto srcTensor = mTensors[cmd->indexes()->data()[1]];
         auto srcTensor1 = mTensors[cmd->indexes()->data()[2]];
@@ -672,7 +606,7 @@ kernel void loop_binary(device T1* uOutput [[buffer(0)]], const device T0* uInpu
     }
     virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
         _setTensorStack(mTensors, inputs, outputs, mLoop);
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto size = cmd->size()->data();
         auto srcStride0 = cmd->view()->GetAs<View>(1)->stride()->data();
         auto srcStride1 = cmd->view()->GetAs<View>(2)->stride()->data();
@@ -694,7 +628,7 @@ virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vecto
     virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs,
                           id<MTLComputeCommandEncoder> encoder) override {
 
-        auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
+        auto cmd = mLoop->commands()->GetAs<RegionCommand>(mCmdIndex);
         auto dstTensor = mTensors[cmd->indexes()->data()[0]];
         auto srcTensor = mTensors[cmd->indexes()->data()[1]];
         auto srcTensor1 = mTensors[cmd->indexes()->data()[2]];
@@ -711,6 +645,139 @@ virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Ten
     id<MTLBuffer> mParam;
     std::vector<Tensor*> mTensors;
     int mTotalSize;
+    int mCmdIndex = 0;
+};
+
+class MetalLoop : public MetalExecution {
+public:
+    MetalLoop(const LoopParam* loop, Backend *bn, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) : MetalExecution(bn) {
+        mLoop = loop;
+        auto mtbn = static_cast<MetalBackend *>(bn);
+        auto context = (__bridge MNNMetalContext *)mtbn->context();
+        mTensors.resize(mLoop->tensorNumber());
+        _setTensorStack(mTensors, inputs, outputs, mLoop);
+
+        // Init
+        if (mLoop->initCommand() != nullptr) {
+            mNeedInit = true;
+            std::string shader = "set_copy";
+            auto dstTensor = mTensors[mLoop->initCommand()->GetAs<RegionCommand>(0)->indexes()->data()[0]];
+            NSString* T = MetalCast::getScalarType(dstTensor->getType(), mtbn->useFp16InsteadFp32());
+
+            auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
+            if (cmd->op() == nullptr) {
+                shader = "set_zero";
+            } else {
+                mInitParam = [context newDeviceBuffer:sizeof(InitInfo) access:CPUWriteOnly];
+            }
+            std::vector<std::string> keys = {
+                std::string([T UTF8String]),
+                "init_region",
+                shader
+            };
+            auto pipeline = mtbn->runtime()->findPipeline(keys);
+            if (nil == pipeline) {
+                MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
+                compileOptions.preprocessorMacros = @{
+                    @"T" : T,
+                };
+                pipeline = mtbn->makeComputePipelineWithSourceOption(gInitRegion, shader.c_str(), compileOptions);
+                mtbn->runtime()->insertPipeline(keys, pipeline);
+            }
+            if (nil == pipeline) {
+                MNN_ERROR("Create gather init pipeline error\n");
+            }
+            mInitPipeline = pipeline;
+        }
+
+        bool valid = true;
+        for (int i = 0; i < loop->commands()->size(); ++i) {
+            auto cmd = loop->commands()->GetAs<RegionCommand>(i);
+            auto subop = cmd->op();
+            if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) {
+                mExecutions.emplace_back(std::make_shared<MetalGather>(loop, bn, inputs, outputs, i));
+            } else if (OpType_MatMul == subop->type() && loop->parallel()) {
+                mExecutions.emplace_back(std::make_shared<MetalBatchMatMul>(loop, bn, i));
+            } else if (OpType_BinaryOp == subop->type() && cmd->fuse() < 0 && 1 == loop->loopNumber()) {
+                mExecutions.emplace_back(std::make_shared<MetalBinaryBroadCast>(loop, bn, mTensors, MetalBinary::convert(cmd->op()->main_as_BinaryOp()->opType(), mTensors[cmd->indexes()->data()[1]]->getType().code == halide_type_float), i));
+            } else {
+                valid = false;
+                break;
+            }
+        }
+        if (!valid) {
+            mExecutions.clear();
+        }
+    }
+    virtual ~MetalLoop() = default;
+
+    virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
+        // Init
+        if (mNeedInit) {
+            _setTensorStack(mTensors, inputs, outputs, mLoop);
+            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
+            auto data = reinterpret_cast<InitInfo*>([mInitParam contents]);
+
+            auto srcStride = initCmd->view()->GetAs<View>(1)->stride()->data();
+            auto dstStride = initCmd->view()->GetAs<View>(0)->stride()->data();
+            auto dataSize = initCmd->size()->data();
+            for (int i = 0; i < 3; ++i) {
+                data->srcStride[i] = srcStride[i];
+                data->dstStride[i] = dstStride[i];
+                data->size[i] = dataSize[i];
+            }
+
+            auto initDstTensor = mTensors[initCmd->indexes()->data()[0]];
+            auto initSrcTensor = mTensors[initCmd->indexes()->data()[1]];
+            auto initInputSize = initSrcTensor->usize() / initSrcTensor->buffer().type.bytes();
+            auto initOutputSize = initDstTensor->usize() / initDstTensor->buffer().type.bytes();
+            data->totalSize[0] = initInputSize;
+            data->totalSize[1] = initOutputSize;
+
+            auto backend = static_cast<MetalBackend *>(this->backend());
+            auto context = (__bridge MNNMetalContext *)backend->context();
+            mInitThreads = [context computeBestGroupAndLocal:mInitPipeline threads:MTLSizeMake(data->size[0], data->size[1], data->size[2])];
+        }
+
+        // Loop commands
+        for (auto& exe : mExecutions) {
+            auto code = exe->onResize(inputs, outputs);
+            if (NO_ERROR != code) {
+                return code;
+            }
+        }
+        return NO_ERROR;
+    }
+    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
+        // Init
+        if (mNeedInit) {
+            auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
+            [encoder setComputePipelineState:mInitPipeline];
+            auto dstTensor = mTensors[initCmd->indexes()->data()[0]];
+            auto srcTensor = mTensors[initCmd->indexes()->data()[1]];
+            MetalBackend::setTensor(dstTensor, encoder, 0);
+            MetalBackend::setTensor(srcTensor, encoder, 1);
+            [encoder setBuffer:mInitParam offset:0 atIndex:2];
+            [encoder dispatchThreadgroups:mInitThreads.first threadsPerThreadgroup:mInitThreads.second];
+        }
+        // Loop commands
+        for (auto& exe : mExecutions) {
+            exe->onEncode(inputs, outputs, encoder);
+        }
+    }
+    bool isValid() {
+        return !mExecutions.empty();
+    }
+private:
+    const LoopParam* mLoop;
+    std::vector<std::shared_ptr<MetalExecution>> mExecutions;
+    std::vector<Tensor*> mTensors;
+    // For Init
+    bool mNeedInit = false;
+    std::pair<MTLSize, MTLSize> mInitThreads;
+    id<MTLComputePipelineState> mInitPipeline;
+    id<MTLBuffer> mInitParam;
+
 };
 
 class MetalLoopCreator : public MetalBackend::Creator {
@@ -725,7 +792,7 @@ virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Ten
         if (1 == loop->commands()->size()) {
             auto cmd = loop->commands()->GetAs<RegionCommand>(0);
             auto subop = cmd->op();
-            if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) {
+            if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0 && nullptr == loop->initCommand()) {
                 return new MetalGather(loop, bn, inputs, outputs);
             }
             if (OpType_MatMul == subop->type() && loop->parallel() && nullptr == loop->initCommand()) {
@@ -741,10 +808,17 @@ virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Ten
                     MNN_ERROR("Metal Don't support binary - %d\n", cmd->op()->main_as_BinaryOp()->opType());
                     return nullptr;
                 }
-                return new MetalBinaryBroadCast(loop, bn, std::move(tensors), CUSTOM);
+                return new MetalBinaryBroadCast(loop, bn, tensors, CUSTOM);
             }
         }
+        // General Case
+        auto exe = new MetalLoop(loop, bn, inputs, outputs);
+        if (exe->isValid()) {
+            return exe;
+        }
+        delete exe;
         return nullptr;
+
     }
 };
 REGISTER_METAL_OP_CREATOR(MetalLoopCreator, OpType_While);