Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/circt/Dialect/LLHD/LLHDOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def SigStructExtractOp : LLHDOp<"sig.struct_extract", [
```
}];

let arguments = (ins RefTypeOf<StructType>:$input, StrAttr:$field);
let arguments = (ins RefTypeOf<AnyTypeOf<[StructType, UnionType]>>:$input, StrAttr:$field);
let results = (outs RefTypeOf<AnyType>:$result);

let assemblyFormat = [{
Expand Down
95 changes: 92 additions & 3 deletions lib/Conversion/MooreToCore/MooreToCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ static hw::ModulePortInfo getModulePortInfo(const TypeConverter &typeConverter,

for (auto port : moduleTy.getPorts()) {
Type portTy = typeConverter.convertType(port.type);
if (!portTy)
return hw::ModulePortInfo({});
if (port.dir == hw::ModulePort::Direction::Output) {
ports.push_back(
hw::PortInfo({{port.name, portTy, port.dir}, resultNum++, {}}));
Expand Down Expand Up @@ -271,9 +273,9 @@ struct SVModuleOpConversion : public OpConversionPattern<SVModuleOp> {
rewriter.setInsertionPoint(op);

// Create the hw.module to replace moore.module
auto hwModuleOp =
hw::HWModuleOp::create(rewriter, op.getLoc(), op.getSymNameAttr(),
getModulePortInfo(*typeConverter, op));
auto portInfo = getModulePortInfo(*typeConverter, op);
auto hwModuleOp = hw::HWModuleOp::create(rewriter, op.getLoc(),
op.getSymNameAttr(), portInfo);
// Make hw.module have the same visibility as the moore.module.
// The entry/top level module is public, otherwise is private.
SymbolTable::setSymbolVisibility(hwModuleOp,
Expand Down Expand Up @@ -1357,6 +1359,44 @@ struct StructExtractRefOpConversion
}
};

struct UnionCreateOpConversion : public OpConversionPattern<UnionCreateOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(UnionCreateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(op.getResult().getType());
rewriter.replaceOpWithNewOp<hw::UnionCreateOp>(
op, resultType, adaptor.getFieldNameAttr(), adaptor.getInput());
return success();
}
};

struct UnionExtractOpConversion : public OpConversionPattern<UnionExtractOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(UnionExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<hw::UnionExtractOp>(op, adaptor.getInput(),
adaptor.getFieldNameAttr());
return success();
}
};

struct UnionExtractRefOpConversion
: public OpConversionPattern<UnionExtractRefOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(UnionExtractRefOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<llhd::SigStructExtractOp>(
op, adaptor.getInput(), adaptor.getFieldNameAttr());
return success();
}
};

struct ReduceAndOpConversion : public OpConversionPattern<ReduceAndOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
Expand Down Expand Up @@ -2353,6 +2393,38 @@ static void populateTypeConversion(TypeConverter &typeConverter) {
return hw::StructType::get(type.getContext(), fields);
});

// UnionType -> hw::UnionType
typeConverter.addConversion([&](UnionType type) -> std::optional<Type> {
SmallVector<hw::UnionType::FieldInfo> fields;
for (auto field : type.getMembers()) {
hw::UnionType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
if (!info.type)
return {};
info.name = field.name;
info.offset = 0; // packed union, all fields start at bit 0
fields.push_back(info);
}
auto result = hw::UnionType::get(type.getContext(), fields);
return result;
});

// UnpackedUnionType -> hw::UnionType
typeConverter.addConversion(
[&](UnpackedUnionType type) -> std::optional<Type> {
SmallVector<hw::UnionType::FieldInfo> fields;
for (auto field : type.getMembers()) {
hw::UnionType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
if (!info.type)
return {};
info.name = field.name;
info.offset = 0;
fields.push_back(info);
}
return hw::UnionType::get(type.getContext(), fields);
});

// Conversion of CHandle to LLVMPointerType
typeConverter.addConversion([&](ChandleType type) -> std::optional<Type> {
return LLVM::LLVMPointerType::get(type.getContext());
Expand Down Expand Up @@ -2407,6 +2479,20 @@ static void populateTypeConversion(TypeConverter &typeConverter) {
return hw::StructType::get(type.getContext(), fields);
});

typeConverter.addConversion([&](hw::UnionType type) -> std::optional<Type> {
SmallVector<hw::UnionType::FieldInfo> fields;
for (auto field : type.getElements()) {
hw::UnionType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
if (!info.type)
return {};
info.name = field.name;
info.offset = field.offset;
fields.push_back(info);
}
return hw::UnionType::get(type.getContext(), fields);
});

typeConverter.addTargetMaterialization(
[&](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
Expand Down Expand Up @@ -2475,6 +2561,9 @@ static void populateOpConversion(ConversionPatternSet &patterns,
StructExtractRefOpConversion,
ExtractRefOpConversion,
StructCreateOpConversion,
UnionCreateOpConversion,
UnionExtractOpConversion,
UnionExtractRefOpConversion,
ConditionalOpConversion,
ArrayCreateOpConversion,
YieldOpConversion,
Expand Down
37 changes: 29 additions & 8 deletions lib/Dialect/LLHD/IR/LLHDOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,16 +475,28 @@ LogicalResult llhd::SigStructExtractOp::inferReturnTypes(
mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
typename SigStructExtractOp::Adaptor adaptor(operands, attrs, properties,
regions);
Type type = cast<hw::StructType>(
cast<RefType>(adaptor.getInput().getType()).getNestedType())
.getFieldType(adaptor.getField());
if (!type) {
auto nestedType = cast<RefType>(adaptor.getInput().getType()).getNestedType();
Type fieldType;

// Support both StructType and UnionType
if (auto structType = dyn_cast<hw::StructType>(nestedType)) {
fieldType = structType.getFieldType(adaptor.getField());
} else if (auto unionType = dyn_cast<hw::UnionType>(nestedType)) {
fieldType = unionType.getFieldType(adaptor.getField());
} else {
context->getDiagEngine().emit(loc.value_or(UnknownLoc()),
DiagnosticSeverity::Error)
<< "expected struct or union type";
return failure();
}

if (!fieldType) {
context->getDiagEngine().emit(loc.value_or(UnknownLoc()),
DiagnosticSeverity::Error)
<< "invalid field name specified";
return failure();
}
results.push_back(RefType::get(type));
results.push_back(RefType::get(fieldType));
return success();
}

Expand All @@ -495,9 +507,18 @@ bool SigStructExtractOp::canRewire(
const DataLayout &dataLayout) {
if (slot.ptr != getInput())
return false;
auto index =
cast<hw::StructType>(cast<RefType>(getInput().getType()).getNestedType())
.getFieldIndex(getFieldAttr());

auto nestedType = cast<RefType>(getInput().getType()).getNestedType();
std::optional<uint32_t> index;

// Support both StructType and UnionType
if (auto structType = dyn_cast<hw::StructType>(nestedType))
index = structType.getFieldIndex(getFieldAttr());
else if (auto unionType = dyn_cast<hw::UnionType>(nestedType))
index = unionType.getFieldIndex(getFieldAttr());
else
return false;

if (!index)
return false;
auto indexAttr = IntegerAttr::get(IndexType::get(getContext()), *index);
Expand Down
14 changes: 11 additions & 3 deletions lib/Dialect/Moore/MooreOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,12 +1010,20 @@ LogicalResult UnionCreateOp::verify() {
return TypeSwitch<Type, LogicalResult>(getType())
.Case<UnionType, UnpackedUnionType>([this](auto &type) {
auto members = type.getMembers();
auto resultType = getType();
auto inputType = getInput().getType();
auto fieldName = getFieldName();
for (const auto &member : members)
if (member.name == fieldName && member.type == resultType)
if (member.name == fieldName && member.type == inputType)
return success();
emitOpError("input type must match the union field type");
for (const auto &member : members) {
if (member.name == fieldName) {
emitOpError() << "input type " << inputType
<< " does not match union field '" << fieldName
<< "' type " << member.type;
return failure();
}
}
emitOpError() << "field '" << fieldName << "' not found in union type";
return failure();
})
.Default([this](auto &) {
Expand Down
15 changes: 15 additions & 0 deletions test/Conversion/MooreToCore/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,21 @@ moore.module @Struct(in %a : !moore.i32, in %b : !moore.i32, in %arg0 : !moore.s
moore.output %0, %3, %4 : !moore.i32, !moore.struct<{exp_bits: i32, man_bits: i32}>, !moore.struct<{exp_bits: i32, man_bits: i32}>
}

// CHECK-LABEL: hw.module @Union
moore.module @Union(in %a : !moore.i32, in %arg0 : !moore.union<{x: i32, y: i32}>, in %arg1 : !moore.ref<union<{x: i32, y: i32}>>, out o : !moore.i32, out p : !moore.union<{x: i32, y: i32}>) {
// CHECK: hw.union_extract %arg0["x"] : !hw.union<x: i32, y: i32>
%0 = moore.union_extract %arg0, "x" : !moore.union<{x: i32, y: i32}> -> !moore.i32

// CHECK: llhd.sig.struct_extract %arg1["x"] : <!hw.union<x: i32, y: i32>>
%ref = moore.union_extract_ref %arg1, "x" : <union<{x: i32, y: i32}>> -> <i32>
moore.assign %ref, %0 : !moore.i32

// CHECK: hw.union_create "x", %a : !hw.union<x: i32, y: i32>
%1 = moore.union_create %a {fieldName = "x"} : !moore.i32 -> union<{x: i32, y: i32}>

moore.output %0, %1 : !moore.i32, !moore.union<{x: i32, y: i32}>
}

// CHECK-LABEL: func @ArrayCreate
// CHECK-SAME: () -> !hw.array<2xi8>
func.func @ArrayCreate() -> !moore.array<2x!moore.i8> {
Expand Down
8 changes: 8 additions & 0 deletions test/Dialect/LLHD/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ hw.module @sigStructExtract(in %arg0 : !llhd.ref<!hw.struct<foo: i1, bar: i2, ba
%1 = llhd.sig.struct_extract %arg0["baz"] : <!hw.struct<foo: i1, bar: i2, baz: i3>>
}

// CHECK-LABEL: @sigUnionExtract
hw.module @sigUnionExtract(in %arg0 : !llhd.ref<!hw.union<foo: i1, bar: i2, baz: i3>>) {
// CHECK-NEXT: %{{.*}} = llhd.sig.struct_extract %arg0["foo"] : <!hw.union<foo: i1, bar: i2, baz: i3>>
%0 = llhd.sig.struct_extract %arg0["foo"] : <!hw.union<foo: i1, bar: i2, baz: i3>>
// CHECK-NEXT: %{{.*}} = llhd.sig.struct_extract %arg0["baz"] : <!hw.union<foo: i1, bar: i2, baz: i3>>
%1 = llhd.sig.struct_extract %arg0["baz"] : <!hw.union<foo: i1, bar: i2, baz: i3>>
}

// CHECK-LABEL: @checkSigInst
hw.module @checkSigInst() {
// CHECK: %[[CI1:.*]] = hw.constant
Expand Down
7 changes: 7 additions & 0 deletions test/Dialect/LLHD/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ hw.module @extract_element_tuple_index_out_of_bounds(in %tup : !llhd.ref<!hw.str

// -----

hw.module @extract_element_union_index_out_of_bounds(in %union : !llhd.ref<!hw.union<foo: i1, bar: i2, baz: i3>>) {
// expected-error @+1 {{invalid field name specified}}
%0 = llhd.sig.struct_extract %union["foobar"] : <!hw.union<foo: i1, bar: i2, baz: i3>>
}

// -----

hw.module @YieldFromFinal(in %arg0: i42) {
llhd.final {
// expected-error @below {{'llhd.halt' op has 1 yield operands, but enclosing 'llhd.final' returns 0}}
Expand Down
14 changes: 14 additions & 0 deletions test/Dialect/Moore/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,17 @@ moore.global_variable @Foo : !moore.i9001
moore.global_variable @Foo : !moore.i42 init {
llvm.unreachable
}

// -----

// UnionCreateOp verifier: input type mismatch
%0 = moore.constant 42 : i16
// expected-error @below {{op input type '!moore.i16' does not match union field 'x' type '!moore.i32'}}
moore.union_create %0 {fieldName = "x"} : !moore.i16 -> union<{x: i32, y: i32}>

// -----

// UnionCreateOp verifier: field not found
%0 = moore.constant 42 : i32
// expected-error @below {{op field 'z' not found in union type}}
moore.union_create %0 {fieldName = "z"} : !moore.i32 -> union<{x: i32, y: i32}>