@@ -25,34 +25,6 @@ SmallVector<Value> flatten(TritonOpBuilder &builder,
2525}
2626} // namespace
2727
28- static SmallVector<Type>
29- aggregationTypes (TritonOpBuilder &builder,
30- const SmallVector<Type> &unconvertTypes,
31- const SmallVector<Type> &convertTypes) {
32- SmallVector<Type> resultTypes;
33- TypeRange tgts = convertTypes;
34- for (Type singletype : unconvertTypes) {
35- if (auto ptrType = dyn_cast<RankedTensorType>(singletype)) {
36- size_t rank = ptrType.getRank ();
37- Type allocatedPtrTy = tgts[0 ];
38- Type alignedPtrTy = tgts[1 ];
39- Type offsetTy = tgts[2 ];
40- Type sizeElemTy = tgts[3 ];
41- Type strideElemTy = tgts[3 + rank];
42- auto sizesArrayTy = LLVM::LLVMArrayType::get (sizeElemTy, rank);
43- auto stridesArrayTy = LLVM::LLVMArrayType::get (strideElemTy, rank);
44- SmallVector<Type> fieldTys = {
45- allocatedPtrTy, alignedPtrTy, offsetTy, sizesArrayTy, stridesArrayTy,
46- };
47- resultTypes.push_back (LLVM::LLVMStructType::getLiteral (
48- builder.getContext (), fieldTys, /* packed=*/ false ));
49- } else {
50- resultTypes.push_back (std::move (tgts.front ()));
51- tgts = tgts.drop_front ();
52- }
53- }
54- return resultTypes;
55- }
5628// Create a DSLRegionOp that wraps an LLVM function, performing type conversion
5729// from Triton IR types to LLVM types based on EDSL function declarations.
5830//
@@ -141,11 +113,11 @@ SmallVector<Value> createTLERawRegionByLLVMFunc(
141113 SmallVector<Value> operands =
142114 llvm::to_vector (llvm::concat<Value>(converted_outputs, converted_inputs));
143115
144- SmallVector<Type> dslOutputTys = llvm::map_to_vector (
145- converted_outputs, [](Value value) -> Type { return value. getType (); });
146- auto outStructTy = aggregationTypes (self, outputTys, dslOutputTys );
116+ SmallVector<Type> returnTys = llvm::filter_to_vector (
117+ func. getFunctionType (). getReturnTypes (),
118+ [](Type ty) -> bool { return !isa<LLVM::LLVMVoidType>(ty); } );
147119 tle::DSLRegionOp dslRegionOp =
148- self.create <tle::DSLRegionOp>(outStructTy , operands);
120+ self.create <tle::DSLRegionOp>(returnTys , operands);
149121 OpBuilder::InsertionGuard guard (builder);
150122 Region &body = dslRegionOp.getBody ();
151123 SmallVector<Type> operandTys = llvm::map_to_vector (
0 commit comments