Skip to content

Commit dcf0c07

Browse files
authored
remove unnecessary funcs for dslop outputs (#355)
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
1 parent 7879d8b commit dcf0c07

File tree

1 file changed

+4
-32
lines changed

1 file changed

+4
-32
lines changed

third_party/tle/triton_tle_raw.cc

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,6 @@ SmallVector<Value> flatten(TritonOpBuilder &builder,
2525
}
2626
} // namespace
2727

28-
static SmallVector<Type>
29-
aggregationTypes(TritonOpBuilder &builder,
30-
const SmallVector<Type> &unconvertTypes,
31-
const SmallVector<Type> &convertTypes) {
32-
SmallVector<Type> resultTypes;
33-
TypeRange tgts = convertTypes;
34-
for (Type singletype : unconvertTypes) {
35-
if (auto ptrType = dyn_cast<RankedTensorType>(singletype)) {
36-
size_t rank = ptrType.getRank();
37-
Type allocatedPtrTy = tgts[0];
38-
Type alignedPtrTy = tgts[1];
39-
Type offsetTy = tgts[2];
40-
Type sizeElemTy = tgts[3];
41-
Type strideElemTy = tgts[3 + rank];
42-
auto sizesArrayTy = LLVM::LLVMArrayType::get(sizeElemTy, rank);
43-
auto stridesArrayTy = LLVM::LLVMArrayType::get(strideElemTy, rank);
44-
SmallVector<Type> fieldTys = {
45-
allocatedPtrTy, alignedPtrTy, offsetTy, sizesArrayTy, stridesArrayTy,
46-
};
47-
resultTypes.push_back(LLVM::LLVMStructType::getLiteral(
48-
builder.getContext(), fieldTys, /*packed=*/false));
49-
} else {
50-
resultTypes.push_back(std::move(tgts.front()));
51-
tgts = tgts.drop_front();
52-
}
53-
}
54-
return resultTypes;
55-
}
5628
// Create a DSLRegionOp that wraps an LLVM function, performing type conversion
5729
// from Triton IR types to LLVM types based on EDSL function declarations.
5830
//
@@ -141,11 +113,11 @@ SmallVector<Value> createTLERawRegionByLLVMFunc(
141113
SmallVector<Value> operands =
142114
llvm::to_vector(llvm::concat<Value>(converted_outputs, converted_inputs));
143115

144-
SmallVector<Type> dslOutputTys = llvm::map_to_vector(
145-
converted_outputs, [](Value value) -> Type { return value.getType(); });
146-
auto outStructTy = aggregationTypes(self, outputTys, dslOutputTys);
116+
SmallVector<Type> returnTys = llvm::filter_to_vector(
117+
func.getFunctionType().getReturnTypes(),
118+
[](Type ty) -> bool { return !isa<LLVM::LLVMVoidType>(ty); });
147119
tle::DSLRegionOp dslRegionOp =
148-
self.create<tle::DSLRegionOp>(outStructTy, operands);
120+
self.create<tle::DSLRegionOp>(returnTys, operands);
149121
OpBuilder::InsertionGuard guard(builder);
150122
Region &body = dslRegionOp.getBody();
151123
SmallVector<Type> operandTys = llvm::map_to_vector(

0 commit comments

Comments
 (0)