3636import java .util .Arrays ;
3737import java .util .IdentityHashMap ;
3838import java .util .List ;
39+ import java .util .concurrent .atomic .AtomicReference ;
3940import java .util .stream .Collectors ;
4041import java .util .stream .Stream ;
4142import javax .annotation .Nullable ;
@@ -490,7 +491,7 @@ Procedure getRegisterAliasWriteProc(AliasDefinition definition,
490491 // If the [overwrite source:] annotation is set, we instead either zero or sign extend the
491492 // write value to overwrite the whole source register.
492493 writeValue = switch (overwriteMode ) {
493- case null -> sliceWriteValue (writeValue ,
494+ case null -> staticSliceWriteValue (writeValue ,
494495 new ReadRegTensorNode (reg , indices , sourceRegType , null ), List .of (slice ));
495496 case "zero" -> zeroExtend (writeValue , sourceRegType );
496497 case "sign" -> signExtend (writeValue , sourceRegType );
@@ -810,21 +811,28 @@ private ExpressionNode readTensorResourceConcatinated(RegisterResource resource,
810811 private List <WriteResourceNode > writeTensorResourceSliced (RegisterResource resource ,
811812 List <ExpressionNode > indices ,
812813 ExpressionNode value ,
813- List <Constant .BitSlice > slices ) {
814+ List <Constant .BitSlice > slices ,
815+ @ Nullable ExpressionNode dynamicIndex ) {
814816 // No multiple writes and slices needed
815817 if (resource .indexTypes ().size () <= indices .size ()) {
816818 switch (resource ) {
817819 case RegisterTensor register -> {
818- var slicedValue = sliceWriteValue (value ,
819- new ReadRegTensorNode (register , new NodeList <>(indices ), register .resultType (), null ),
820- slices );
820+ var resourceRead =
821+ new ReadRegTensorNode (register , new NodeList <>(indices ), register .resultType (), null );
822+ var slicedValue = dynamicIndexWriteValue (
823+ staticSliceWriteValue (value , resourceRead , slices ),
824+ resourceRead ,
825+ dynamicIndex );
821826 return List .of (
822827 new WriteRegTensorNode (register , new NodeList <>(indices ), slicedValue , null , null ));
823828 }
824829 case ArtificialResource register -> {
825- var slicedValue = sliceWriteValue (value ,
826- new ReadArtificialResNode (register , new NodeList <>(indices ), register .resultType ()),
827- slices );
830+ var resourceRead =
831+ new ReadArtificialResNode (register , new NodeList <>(indices ), register .resultType ());
832+ var slicedValue =
833+ dynamicIndexWriteValue (staticSliceWriteValue (value , resourceRead , slices ),
834+ resourceRead ,
835+ dynamicIndex );
828836 return List .of (
829837 new WriteArtificialResNode (register , new NodeList <>(indices ), slicedValue ));
830838 }
@@ -1591,7 +1599,8 @@ public SubgraphContext visit(AssignmentStatement statement) {
15911599
15921600 vadl .ast .Definition targetDef ;
15931601 List <CallIndexExpr .Arguments > argGroups = List .of ();
1594- List <Constant .BitSlice > slices = new ArrayList <>();
1602+ List <Constant .BitSlice > staticSlices = new ArrayList <>();
1603+ AtomicReference <ExpressionNode > dynamicIndexExpr = new AtomicReference <>();
15951604
15961605 // the MEM<xyz>(...) value
15971606 @ Nullable Integer callSize = null ;
@@ -1600,12 +1609,16 @@ public SubgraphContext visit(AssignmentStatement statement) {
16001609 targetDef = (vadl .ast .Definition ) callTarget .computedTarget ();
16011610 argGroups = callTarget .args ();
16021611 callTarget .slices ().forEach (s -> {
1603- slices .add (requireNonNull (s .computedstaticBitSlice ));
1612+ if (s .computedstaticBitSlice != null ) {
1613+ staticSlices .add (requireNonNull (s .computedstaticBitSlice ));
1614+ } else {
1615+ dynamicIndexExpr .set (fetch (s .values .getFirst ()));
1616+ }
16041617 });
16051618 // add all slices that come from format field accesses
16061619 callTarget .subCalls .forEach (s -> {
16071620 if (s .computedBitSlice != null ) {
1608- slices .add (s .computedBitSlice );
1621+ staticSlices .add (s .computedBitSlice );
16091622 }
16101623 });
16111624
@@ -1623,36 +1636,41 @@ public SubgraphContext visit(AssignmentStatement statement) {
16231636 .collect (Collectors .toCollection (NodeList ::new ));
16241637 var viamTargetDef = viamLowering .fetch (targetDef ).orElseThrow ();
16251638
1639+ var dynamicIndex = dynamicIndexExpr .get ();
1640+
16261641 // No need to call getViamType here as the viam definitions should already have that.
16271642 var writeNodes = switch (viamTargetDef ) {
16281643 case RegisterTensor regDef -> writeTensorResourceSliced (
1629- regDef , argExprs , value , slices
1644+ regDef , argExprs , value , staticSlices , dynamicIndex
16301645 );
16311646
16321647 case ArtificialResource aliasDef -> writeTensorResourceSliced (
1633- aliasDef , argExprs , value , slices
1648+ aliasDef , argExprs , value , staticSlices , dynamicIndex
16341649 );
16351650
16361651 case Memory memDef -> {
16371652 var words = callSize != null ? callSize : 1 ;
16381653 // slice the written value before writing it
1639- var slicedValue = sliceWriteValue (value ,
1640- new ReadMemNode (memDef , words , argExprs .getFirst (),
1641- ((BitsType ) memDef .resultType ()).scaleBy (words )), slices );
1654+ var resourceRead = new ReadMemNode (memDef , words , argExprs .getFirst (),
1655+ ((BitsType ) memDef .resultType ()).scaleBy (words ));
1656+ var slicedValue = dynamicIndexWriteValue (staticSliceWriteValue (value ,
1657+ resourceRead , staticSlices ), resourceRead , dynamicIndex );
16421658 yield List .of ((vadl .viam .graph .Node ) new WriteMemNode (
16431659 memDef , callSize != null ? callSize : 1 ,
16441660 argExprs .getFirst (), slicedValue
16451661 ));
16461662 }
16471663
16481664 // FIXME: Adjust value based on counter position
1649- case Counter counterDef ->
1650- List .of (new WriteRegTensorNode (counterDef .registerTensor (), argExprs ,
1651- // slice the written value before writing it
1652- sliceWriteValue (value ,
1653- new ReadRegTensorNode (counterDef .registerTensor (), argExprs ,
1654- counterDef .registerTensor ().resultType (), null ), slices ),
1655- null , null ));
1665+ case Counter counterDef -> {
1666+ var resourceRead = new ReadRegTensorNode (counterDef .registerTensor (), argExprs ,
1667+ counterDef .registerTensor ().resultType (), null );
1668+ yield List .of (new WriteRegTensorNode (counterDef .registerTensor (), argExprs ,
1669+ // slice the written value before writing it
1670+ dynamicIndexWriteValue (staticSliceWriteValue (value ,
1671+ resourceRead , staticSlices ), resourceRead , dynamicIndex ),
1672+ null , null ));
1673+ }
16561674
16571675 case StageOutput output -> List .of (
16581676 new WriteStageOutputNode (output , value )
@@ -1661,7 +1679,6 @@ public SubgraphContext visit(AssignmentStatement statement) {
16611679 default -> throw new IllegalStateException ("Unexpected target: " + viamTargetDef );
16621680 };
16631681
1664-
16651682 for (var writeNode : writeNodes ) {
16661683 writeNode .setSourceLocationIfNotSet (statement .target .location ());
16671684 }
@@ -1670,6 +1687,33 @@ public SubgraphContext visit(AssignmentStatement statement) {
16701687 writeNodes .stream ().map (n -> (vadl .viam .graph .Node ) n ).toList ());
16711688 }
16721689
1690+ /**
1691+ * Method that prepares the value so that it can be used for a dynamic write of a resource.
1692+ *
1693+ * @param value value that is being written (right side of assignment)
1694+ * @param entireRead resource value before value is written
1695+ * @param index the dynamic expression of the index.
1696+ * @return that incorporates the written value into the resource.
1697+ */
1698+ private ExpressionNode dynamicIndexWriteValue (ExpressionNode value , ReadResourceNode entireRead ,
1699+ @ Nullable ExpressionNode index ) {
1700+ if (index == null ) {
1701+ return value ;
1702+ }
1703+
1704+ ExpressionNode mask = Constant .Value .of (1 , entireRead .type ()).toNode ();
1705+ mask = BuiltInTable .LSL .call (mask , index );
1706+ mask = BuiltInTable .NOT .call (mask );
1707+
1708+ return BuiltInTable .OR .call (
1709+ BuiltInTable .AND .call (
1710+ entireRead ,
1711+ mask
1712+ ),
1713+ BuiltInTable .LSL .call (new ZeroExtendNode (value , entireRead .type ()), index )
1714+ );
1715+ }
1716+
16731717 /**
16741718 * Method that prepares the value so it can be written to a subset region of a resource.
16751719 * The entire resource before writing the value is given by the entireRead node.
@@ -1685,9 +1729,9 @@ public SubgraphContext visit(AssignmentStatement statement) {
16851729 * The example above has one bit-slice with two parts
16861730 * @return expression that incorporates the written value into the resource.
16871731 */
1688- private ExpressionNode sliceWriteValue (ExpressionNode value ,
1689- ReadResourceNode entireRead ,
1690- List <Constant .BitSlice > slices ) {
1732+ private ExpressionNode staticSliceWriteValue (ExpressionNode value ,
1733+ ReadResourceNode entireRead ,
1734+ List <Constant .BitSlice > slices ) {
16911735 if (slices .isEmpty ()) {
16921736 return value ;
16931737 }
0 commit comments