Skip to content

Commit b167b33

Browse files
authored
Merge pull request #708 from OpenVADL/feature/add-dynamic-indexing
frontend: Add dynamic indexing
2 parents e113ece + 4c48969 commit b167b33

File tree

5 files changed

+164
-66
lines changed

5 files changed

+164
-66
lines changed

sys/aarch64/sve.vadl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -848,23 +848,23 @@ instruction set architecture AArch64SVEandSME extending AArch64Base = {
848848

849849
// common sve and sme instructions *********************************************
850850

851-
// $InstrBHSD (SVEArithmeticPredInstr ; (SVEpredADD ; "add" ; 0b0000'0000 ; add ))
852-
// $InstrBHSD (SVEArithmeticPredInstr ; (SVEpredSUB ; "sub" ; 0b0000'1000 ; sub ))
853-
// $InstrBHSD (SVEArithmeticPredInstr ; (SVEpredMUL ; "mul" ; 0b1000'0000 ; mul ))
854-
// $InstrSD (SVEArithmeticPredInstr ; (SVEpredSDIV ; "sdiv" ; 0b1010'0000 ; sdiv ))
855-
// $InstrSD (SVEArithmeticPredInstr ; (SVEpredUDIV ; "udiv" ; 0b1010'1000 ; udiv ))
856-
//
857-
// $InstrBHSD (SVEArithmeticUnprInstr ; (SVEunprADD ; "add" ; 0b000'000 ; add ))
858-
// $InstrBHSD (SVEArithmeticUnprInstr ; (SVEunprSUB ; "sub" ; 0b000'001 ; sub ))
859-
// $InstrBHSD (SVEArithmeticUnprInstr ; (SVEunprMUL ; "mul" ; 0b011'000 ; mul ))
860-
//
861-
// $InstrBHSD (SVEAddExtReductionInstr ; (SVEfoldSADD ; "saddv" ; 0b0'0000'0001 ; SIntX ))
862-
// $InstrBHSD (SVEAddExtReductionInstr ; (SVEfoldUADD ; "uaddv" ; 0b0'0000'1001 ; UIntX ))
863-
// $InstrBHSD (SVEReductionInstr ; (SVEfoldAND ; "andv" ; 0b0'1101'0001 ; and ))
864-
// $InstrBHSD (SVEReductionInstr ; (SVEfoldOR ; "orv" ; 0b0'1100'0001 ; or ))
865-
// $InstrBHSD (SVEReductionInstr ; (SVEfoldXOR ; "eorv" ; 0b0'1100'1001 ; xor ))
866-
//
867-
// $SIMDMoveInstr ( (SIMDUMOV ; "umov" ; 0b00'1111 ))
851+
$InstrBHSD (SVEArithmeticPredInstr ; (SVEpredADD ; "add" ; 0b0000'0000 ; add ))
852+
$InstrBHSD (SVEArithmeticPredInstr ; (SVEpredSUB ; "sub" ; 0b0000'1000 ; sub ))
853+
$InstrBHSD (SVEArithmeticPredInstr ; (SVEpredMUL ; "mul" ; 0b1000'0000 ; mul ))
854+
$InstrSD (SVEArithmeticPredInstr ; (SVEpredSDIV ; "sdiv" ; 0b1010'0000 ; sdiv ))
855+
$InstrSD (SVEArithmeticPredInstr ; (SVEpredUDIV ; "udiv" ; 0b1010'1000 ; udiv ))
856+
857+
$InstrBHSD (SVEArithmeticUnprInstr ; (SVEunprADD ; "add" ; 0b000'000 ; add ))
858+
$InstrBHSD (SVEArithmeticUnprInstr ; (SVEunprSUB ; "sub" ; 0b000'001 ; sub ))
859+
$InstrBHSD (SVEArithmeticUnprInstr ; (SVEunprMUL ; "mul" ; 0b011'000 ; mul ))
860+
861+
$InstrBHSD (SVEAddExtReductionInstr ; (SVEfoldSADD ; "saddv" ; 0b0'0000'0001 ; SIntX ))
862+
$InstrBHSD (SVEAddExtReductionInstr ; (SVEfoldUADD ; "uaddv" ; 0b0'0000'1001 ; UIntX ))
863+
$InstrBHSD (SVEReductionInstr ; (SVEfoldAND ; "andv" ; 0b0'1101'0001 ; and ))
864+
$InstrBHSD (SVEReductionInstr ; (SVEfoldOR ; "orv" ; 0b0'1100'0001 ; or ))
865+
$InstrBHSD (SVEReductionInstr ; (SVEfoldXOR ; "eorv" ; 0b0'1100'1001 ; xor ))
866+
867+
$SIMDMoveInstr ( (SIMDUMOV ; "umov" ; 0b00'1111 ))
868868

869869
$SVEAddVLInstr ( (SVEADDVL ; "addvl" ; 0b0'1010 ))
870870
$SVEIncrInstr ( (SVEINC ; "inc" ; 0b111'000 ; inc ))

vadl/main/vadl/ast/BehaviorLowering.java

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.util.Arrays;
3737
import java.util.IdentityHashMap;
3838
import java.util.List;
39+
import java.util.concurrent.atomic.AtomicReference;
3940
import java.util.stream.Collectors;
4041
import java.util.stream.Stream;
4142
import javax.annotation.Nullable;
@@ -490,7 +491,7 @@ Procedure getRegisterAliasWriteProc(AliasDefinition definition,
490491
// If the [overwrite source:] annotation is set, we instead either zero or sign extend the
491492
// write value to overwrite the whole source register.
492493
writeValue = switch (overwriteMode) {
493-
case null -> sliceWriteValue(writeValue,
494+
case null -> staticSliceWriteValue(writeValue,
494495
new ReadRegTensorNode(reg, indices, sourceRegType, null), List.of(slice));
495496
case "zero" -> zeroExtend(writeValue, sourceRegType);
496497
case "sign" -> signExtend(writeValue, sourceRegType);
@@ -810,21 +811,28 @@ private ExpressionNode readTensorResourceConcatinated(RegisterResource resource,
810811
private List<WriteResourceNode> writeTensorResourceSliced(RegisterResource resource,
811812
List<ExpressionNode> indices,
812813
ExpressionNode value,
813-
List<Constant.BitSlice> slices) {
814+
List<Constant.BitSlice> slices,
815+
@Nullable ExpressionNode dynamicIndex) {
814816
// No multiple writes and slices needed
815817
if (resource.indexTypes().size() <= indices.size()) {
816818
switch (resource) {
817819
case RegisterTensor register -> {
818-
var slicedValue = sliceWriteValue(value,
819-
new ReadRegTensorNode(register, new NodeList<>(indices), register.resultType(), null),
820-
slices);
820+
var resourceRead =
821+
new ReadRegTensorNode(register, new NodeList<>(indices), register.resultType(), null);
822+
var slicedValue = dynamicIndexWriteValue(
823+
staticSliceWriteValue(value, resourceRead, slices),
824+
resourceRead,
825+
dynamicIndex);
821826
return List.of(
822827
new WriteRegTensorNode(register, new NodeList<>(indices), slicedValue, null, null));
823828
}
824829
case ArtificialResource register -> {
825-
var slicedValue = sliceWriteValue(value,
826-
new ReadArtificialResNode(register, new NodeList<>(indices), register.resultType()),
827-
slices);
830+
var resourceRead =
831+
new ReadArtificialResNode(register, new NodeList<>(indices), register.resultType());
832+
var slicedValue =
833+
dynamicIndexWriteValue(staticSliceWriteValue(value, resourceRead, slices),
834+
resourceRead,
835+
dynamicIndex);
828836
return List.of(
829837
new WriteArtificialResNode(register, new NodeList<>(indices), slicedValue));
830838
}
@@ -1591,7 +1599,8 @@ public SubgraphContext visit(AssignmentStatement statement) {
15911599

15921600
vadl.ast.Definition targetDef;
15931601
List<CallIndexExpr.Arguments> argGroups = List.of();
1594-
List<Constant.BitSlice> slices = new ArrayList<>();
1602+
List<Constant.BitSlice> staticSlices = new ArrayList<>();
1603+
AtomicReference<ExpressionNode> dynamicIndexExpr = new AtomicReference<>();
15951604

15961605
// the MEM<xyz>(...) value
15971606
@Nullable Integer callSize = null;
@@ -1600,12 +1609,16 @@ public SubgraphContext visit(AssignmentStatement statement) {
16001609
targetDef = (vadl.ast.Definition) callTarget.computedTarget();
16011610
argGroups = callTarget.args();
16021611
callTarget.slices().forEach(s -> {
1603-
slices.add(requireNonNull(s.computedstaticBitSlice));
1612+
if (s.computedstaticBitSlice != null) {
1613+
staticSlices.add(requireNonNull(s.computedstaticBitSlice));
1614+
} else {
1615+
dynamicIndexExpr.set(fetch(s.values.getFirst()));
1616+
}
16041617
});
16051618
// add all slices that come from format field accesses
16061619
callTarget.subCalls.forEach(s -> {
16071620
if (s.computedBitSlice != null) {
1608-
slices.add(s.computedBitSlice);
1621+
staticSlices.add(s.computedBitSlice);
16091622
}
16101623
});
16111624

@@ -1623,36 +1636,41 @@ public SubgraphContext visit(AssignmentStatement statement) {
16231636
.collect(Collectors.toCollection(NodeList::new));
16241637
var viamTargetDef = viamLowering.fetch(targetDef).orElseThrow();
16251638

1639+
var dynamicIndex = dynamicIndexExpr.get();
1640+
16261641
// No need to call getViamType here as the viam definitions should already have that.
16271642
var writeNodes = switch (viamTargetDef) {
16281643
case RegisterTensor regDef -> writeTensorResourceSliced(
1629-
regDef, argExprs, value, slices
1644+
regDef, argExprs, value, staticSlices, dynamicIndex
16301645
);
16311646

16321647
case ArtificialResource aliasDef -> writeTensorResourceSliced(
1633-
aliasDef, argExprs, value, slices
1648+
aliasDef, argExprs, value, staticSlices, dynamicIndex
16341649
);
16351650

16361651
case Memory memDef -> {
16371652
var words = callSize != null ? callSize : 1;
16381653
// slice the written value before writing it
1639-
var slicedValue = sliceWriteValue(value,
1640-
new ReadMemNode(memDef, words, argExprs.getFirst(),
1641-
((BitsType) memDef.resultType()).scaleBy(words)), slices);
1654+
var resourceRead = new ReadMemNode(memDef, words, argExprs.getFirst(),
1655+
((BitsType) memDef.resultType()).scaleBy(words));
1656+
var slicedValue = dynamicIndexWriteValue(staticSliceWriteValue(value,
1657+
resourceRead, staticSlices), resourceRead, dynamicIndex);
16421658
yield List.of((vadl.viam.graph.Node) new WriteMemNode(
16431659
memDef, callSize != null ? callSize : 1,
16441660
argExprs.getFirst(), slicedValue
16451661
));
16461662
}
16471663

16481664
// FIXME: Adjust value based on counter position
1649-
case Counter counterDef ->
1650-
List.of(new WriteRegTensorNode(counterDef.registerTensor(), argExprs,
1651-
// slice the written value before writing it
1652-
sliceWriteValue(value,
1653-
new ReadRegTensorNode(counterDef.registerTensor(), argExprs,
1654-
counterDef.registerTensor().resultType(), null), slices),
1655-
null, null));
1665+
case Counter counterDef -> {
1666+
var resourceRead = new ReadRegTensorNode(counterDef.registerTensor(), argExprs,
1667+
counterDef.registerTensor().resultType(), null);
1668+
yield List.of(new WriteRegTensorNode(counterDef.registerTensor(), argExprs,
1669+
// slice the written value before writing it
1670+
dynamicIndexWriteValue(staticSliceWriteValue(value,
1671+
resourceRead, staticSlices), resourceRead, dynamicIndex),
1672+
null, null));
1673+
}
16561674

16571675
case StageOutput output -> List.of(
16581676
new WriteStageOutputNode(output, value)
@@ -1661,7 +1679,6 @@ public SubgraphContext visit(AssignmentStatement statement) {
16611679
default -> throw new IllegalStateException("Unexpected target: " + viamTargetDef);
16621680
};
16631681

1664-
16651682
for (var writeNode : writeNodes) {
16661683
writeNode.setSourceLocationIfNotSet(statement.target.location());
16671684
}
@@ -1670,6 +1687,33 @@ public SubgraphContext visit(AssignmentStatement statement) {
16701687
writeNodes.stream().map(n -> (vadl.viam.graph.Node) n).toList());
16711688
}
16721689

1690+
/**
1691+
* Method that prepares the value so that it can be used for a dynamic write of a resource.
1692+
*
1693+
* @param value value that is being written (right side of assignment)
1694+
* @param entireRead resource value before value is written
1695+
* @param index the dynamic expression of the index.
1696+
* @return that incorporates the written value into the resource.
1697+
*/
1698+
private ExpressionNode dynamicIndexWriteValue(ExpressionNode value, ReadResourceNode entireRead,
1699+
@Nullable ExpressionNode index) {
1700+
if (index == null) {
1701+
return value;
1702+
}
1703+
1704+
ExpressionNode mask = Constant.Value.of(1, entireRead.type()).toNode();
1705+
mask = BuiltInTable.LSL.call(mask, index);
1706+
mask = BuiltInTable.NOT.call(mask);
1707+
1708+
return BuiltInTable.OR.call(
1709+
BuiltInTable.AND.call(
1710+
entireRead,
1711+
mask
1712+
),
1713+
BuiltInTable.LSL.call(new ZeroExtendNode(value, entireRead.type()), index)
1714+
);
1715+
}
1716+
16731717
/**
16741718
* Method that prepares the value so it can be written to a subset region of a resource.
16751719
* The entire resource before writing the value is given by the entireRead node.
@@ -1685,9 +1729,9 @@ public SubgraphContext visit(AssignmentStatement statement) {
16851729
* The example above has one bit-slice with two parts
16861730
* @return expression that incorporates the written value into the resource.
16871731
*/
1688-
private ExpressionNode sliceWriteValue(ExpressionNode value,
1689-
ReadResourceNode entireRead,
1690-
List<Constant.BitSlice> slices) {
1732+
private ExpressionNode staticSliceWriteValue(ExpressionNode value,
1733+
ReadResourceNode entireRead,
1734+
List<Constant.BitSlice> slices) {
16911735
if (slices.isEmpty()) {
16921736
return value;
16931737
}

vadl/main/vadl/ast/TypeChecker.java

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3106,6 +3106,8 @@ public Void visit(UnaryExpr expr) {
31063106
}
31073107

31083108
private Constant.BitSlice.Part checkSliceRange(RangeExpr range, BitsType valueType) {
3109+
check(range.from);
3110+
check(range.to);
31093111
int from = constantEvaluator.eval(range.from).value().intValueExact();
31103112
int to = constantEvaluator.eval(range.to).value().intValueExact();
31113113

@@ -3131,8 +3133,22 @@ private Constant.BitSlice.Part checkSliceRange(RangeExpr range, BitsType valueTy
31313133
return new Constant.BitSlice.Part(from, to);
31323134
}
31333135

3136+
/**
3137+
* Checks if the index is valid for the given bits type. If the index is static a bits-slice part
3138+
* is returned.
3139+
*
3140+
* @param indexExpr The index expression to check.
3141+
* @param valueType The bits type to check against.
3142+
* @return The index slice if static, null otherwise.
3143+
*/
3144+
@Nullable
31343145
private Constant.BitSlice.Part checkIndexSlice(Expr indexExpr, BitsType valueType) {
31353146
check(indexExpr);
3147+
if (!constantEvaluator.isConstant(indexExpr)) {
3148+
// The index can also be dynamic computed in which we don't assign anything.
3149+
return null;
3150+
}
3151+
31363152
int sliceIndex = constantEvaluator.eval(indexExpr).value().intValueExact();
31373153
if (sliceIndex >= valueType.bitWidth()) {
31383154
addErrorAndStopChecking(error("Invalid Index", indexExpr)
@@ -3188,23 +3204,40 @@ private void visitSliceIndexCall(CallIndexExpr expr, Type typeBeforeSlice,
31883204
parts.add(part);
31893205
}
31903206

3191-
var bitSlice = new Constant.BitSlice(parts.toArray(new Constant.BitSlice.Part[0]));
3192-
if (bitSlice.hasOverlappingParts()) {
3193-
// FIXME: Currently, we don't allow overlapping slices for both slices on read values
3194-
// and write targets.
3195-
// In the future we might want to allow overlapping slices on read values.
3196-
// For written values (`X(1, 1) := 2`) this must not be allowed, as the same value
3197-
// position is written twice.
3198-
addErrorAndStopChecking(error("Overlapping slice parts", slice.location)
3199-
.locationDescription(slice.location, "Some parts of the slice are overlapping.")
3200-
.note("Slices must have distinct, non-overlapping parts.")
3201-
.build());
3202-
}
3207+
var hasDynamicSlice = parts.stream().anyMatch(p -> p == null);
3208+
if (hasDynamicSlice) {
3209+
// FIXME: Implement this
3210+
// Dynamic slices cannot be stacked because of a VIAM constraint
3211+
if (parts.size() > 1) {
3212+
addErrorAndStopChecking(error("Invalid Slice", expr)
3213+
.description("Dynamic slices cannot be stacked.")
3214+
.build());
3215+
}
32033216

3204-
currType = Type.bits(bitSlice.bitSize());
3205-
slice.computedstaticBitSlice = bitSlice;
3206-
slice.type = currType;
3207-
expr.type = currType;
3217+
// Dynamic slices can only result in a single bit for now.
3218+
currType = Type.bits(1);
3219+
slice.type = currType;
3220+
expr.type = currType;
3221+
3222+
} else {
3223+
var bitSlice = new Constant.BitSlice(parts.toArray(new Constant.BitSlice.Part[0]));
3224+
if (bitSlice.hasOverlappingParts()) {
3225+
// FIXME: Currently, we don't allow overlapping slices for both slices on read values
3226+
// and write targets.
3227+
// In the future we might want to allow overlapping slices on read values.
3228+
// For written values (`X(1, 1) := 2`) this must not be allowed, as the same value
3229+
// position is written twice.
3230+
addErrorAndStopChecking(error("Overlapping slice parts", slice.location)
3231+
.locationDescription(slice.location, "Some parts of the slice are overlapping.")
3232+
.note("Slices must have distinct, non-overlapping parts.")
3233+
.build());
3234+
}
3235+
3236+
currType = Type.bits(bitSlice.bitSize());
3237+
slice.computedstaticBitSlice = bitSlice;
3238+
slice.type = currType;
3239+
expr.type = currType;
3240+
}
32083241
}
32093242
if (currType instanceof TensorType currTensoType) {
32103243
if (slice.values.size() != 1) {
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
instruction set architecture Tensor = {
2+
using Index = Bits<4>
3+
register X : Bits<32>
4+
format F : Bits<16> = {opcode : Bits<4>, rs2: Index, rs1: Index, rd: Index}
5+
6+
instruction ReadIndex : F = forall i: Bits<4> in 0 .. 3 do X := X(i) as Bits<32>
7+
encoding ReadIndex = {opcode = 0b1100}
8+
assembly ReadIndex = (mnemonic, " ", register(rd), ",", register(rs1), ",", register(rs2))
9+
10+
instruction WriteIndex : F = forall i: Bits<4> in 0 .. 3 do X(i) := 0
11+
encoding WriteIndex = {opcode = 0b1000}
12+
assembly WriteIndex = (mnemonic, " ", register(rd), ",", register(rs1), ",", register(rs2))
13+
}
14+
15+
16+
// Reported Diagnostics:
17+
//
18+
// No diagnostics were reported, the input was correctly parsed, typechecked and lowered.
19+
//
20+
//
21+
// Part of the class vadl.ast.DiagnosticsTest

vadl/test/resources/diagnostics/typechecker/tensorValuesTest.vadl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ constant j = i as Dim_2_a
199199
// . : ' Identifier name: "x" type: Bits<64>
200200
// . : ' ArgsIndices
201201
// . : ' | RangeExpr type: null
202-
// . : ' | . IntegerLiteral literal: 15 (15) type: null
203-
// . : ' | . IntegerLiteral literal: 0 (0) type: null
202+
// . : ' | . IntegerLiteral literal: 15 (15) type: Const<15>
203+
// . : ' | . IntegerLiteral literal: 0 (0) type: Const<0>
204204
// ConstantDefinition name: "g"
205205
// . LetExpr type: Bits<16>
206206
// . : Identifier name: "x" type: Bits<64>
@@ -216,8 +216,8 @@ constant j = i as Dim_2_a
216216
// . : ' Identifier name: "x" type: Bits<64>
217217
// . : ' ArgsIndices
218218
// . : ' | RangeExpr type: null
219-
// . : ' | . IntegerLiteral literal: 63 (63) type: null
220-
// . : ' | . IntegerLiteral literal: 48 (48) type: null
219+
// . : ' | . IntegerLiteral literal: 63 (63) type: Const<63>
220+
// . : ' | . IntegerLiteral literal: 48 (48) type: Const<48>
221221
// ConstantDefinition name: "h"
222222
// . LetExpr type: Bits<16>
223223
// . : Identifier name: "x" type: Bits<64>
@@ -233,8 +233,8 @@ constant j = i as Dim_2_a
233233
// . : ' Identifier name: "x" type: Bits<64>
234234
// . : ' ArgsIndices
235235
// . : ' | RangeExpr type: null
236-
// . : ' | . IntegerLiteral literal: 63 (63) type: null
237-
// . : ' | . IntegerLiteral literal: 48 (48) type: null
236+
// . : ' | . IntegerLiteral literal: 63 (63) type: Const<63>
237+
// . : ' | . IntegerLiteral literal: 48 (48) type: Const<48>
238238
// ConstantDefinition name: "i"
239239
// . CastExpr type: Bits<64>
240240
// . : BinaryLiteral literal: 0xfedc'da98'7654'3210 (18364793728865153552) type: Const<18364793728865153552>

0 commit comments

Comments
 (0)