Skip to content

Commit ab246db

Browse files
patrick-rivosrbertran
authored andcommitted
Add preliminary RISC-V vector support (Assembly only)
Signed-off-by: Patrick O'Neill <patrick@rivosinc.com>
1 parent b3164c6 commit ab246db

File tree

11 files changed

+2853
-39
lines changed

11 files changed

+2853
-39
lines changed

src/microprobe/code/ins.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# Built-in modules
2222
import copy
2323
from itertools import product
24-
from typing import TYPE_CHECKING, Callable, List
24+
from typing import TYPE_CHECKING, Callable, Dict, List
2525

2626
# Third party modules
2727
import six
@@ -1604,7 +1604,8 @@ def __init__(self):
16041604
self._generic_type = None
16051605
self._label = None
16061606
self._mem_operands = []
1607-
self._operands = RejectingOrderedDict()
1607+
self._operands: Dict[str,
1608+
InstructionOperandValue] = RejectingOrderedDict()
16081609

16091610
def set_arch_type(self, instrtype):
16101611
"""
@@ -1613,7 +1614,8 @@ def set_arch_type(self, instrtype):
16131614
16141615
"""
16151616
self._arch_type = instrtype
1616-
self._operands = RejectingOrderedDict()
1617+
self._operands: Dict[str,
1618+
InstructionOperandValue] = RejectingOrderedDict()
16171619
self._mem_operands = []
16181620
self._allowed_regs = []
16191621
self._address = None

src/microprobe/passes/initialization/__init__.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
"""
1717

1818
# Futures
19-
from __future__ import absolute_import, print_function
19+
from __future__ import absolute_import, print_function, annotations
2020

2121
# Built-in modules
22+
from typing import TYPE_CHECKING
2223

2324
# Third party modules
2425
from six.moves import zip
@@ -36,6 +37,10 @@
3637

3738
# Local modules
3839

40+
# Type hinting
41+
if TYPE_CHECKING:
42+
from microprobe.code.benchmark import Benchmark
43+
from microprobe.target import Target
3944

4045
# Constants
4146
LOG = get_logger(__name__)
@@ -222,6 +227,7 @@ def __init__(self, *args, **kwargs):
222227
skip_unknown = kwargs.get("skip_unknown", False)
223228
warn_unknown = kwargs.get("warn_unknown", False)
224229
self._force_code = kwargs.get("force_code", False)
230+
self.lmul = kwargs.get("lmul", 1)
225231

226232
if len(args) == 1:
227233
self._reg_dict = dict([
@@ -250,7 +256,7 @@ def __init__(self, *args, **kwargs):
250256
self._fp_value,
251257
v_value)
252258

253-
def __call__(self, building_block, target):
259+
def __call__(self, building_block: Benchmark, target: Target):
254260
"""
255261
256262
:param building_block:
@@ -259,26 +265,26 @@ def __call__(self, building_block, target):
259265
"""
260266
if not self._skip_unknown:
261267
for register_name in self._reg_dict:
262-
if register_name not in list(target.registers.keys()):
268+
if register_name not in list(target.isa.registers.keys()):
263269
raise MicroprobeCodeGenerationError(
264270
"Unknown register name: '%s'. Unable to set it" %
265271
register_name)
266272

267273
if self._warn_unknown:
268274
for register_name in self._reg_dict:
269-
if register_name not in list(target.registers.keys()):
275+
if register_name not in list(target.isa.registers.keys()):
270276
print_warning(
271277
"Unknown register name: '%s'. Unable to set it" %
272278
register_name)
273279

274-
regs = sorted(target.registers.values(),
280+
regs = sorted(target.isa.registers.values(),
275281
key=lambda x: self._priolist.index(x.name)
276282
if x.name in self._priolist else 314159)
277283

278284
#
279285
# Make sure scratch registers are set last
280286
#
281-
for reg in target.scratch_registers:
287+
for reg in target.isa.scratch_registers:
282288
if reg in regs:
283289
regs.remove(reg)
284290
regs.append(reg)
@@ -294,25 +300,39 @@ def __call__(self, building_block, target):
294300
self._reg_dict.pop(reg.name)
295301
force_direct = True
296302

297-
if (reg in building_block.context.reserved_registers and
298-
not self._force_reserved):
303+
if reg.name == "LMUL":
304+
building_block.add_init(
305+
target.isa.set_register(reg, self.lmul,
306+
building_block.context))
307+
building_block.context.set_register_value(reg, self.lmul)
308+
continue
309+
310+
all_vec_regs = set([f"V{i}" for i in range(0, 32)])
311+
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self.lmul)])
312+
313+
if reg.name in all_vec_regs - lmul_allowed_regs:
314+
# Skip vector registers ignored by lmul
315+
continue
316+
317+
if (reg in building_block.context.reserved_registers
318+
and not self._force_reserved):
299319
LOG.debug("Skip reserved - %s", reg)
300320
continue
301-
elif (reg in target.control_registers and
302-
(value is None or self._skip_control)):
321+
elif (reg in target.isa.control_registers
322+
and (value is None or self._skip_control)):
303323
LOG.debug("Skip control - %s", reg)
304324
continue
305325

306326
if value is None:
307-
if reg.used_for_vector_arithmetic:
327+
if reg.type.used_for_vector_arithmetic:
308328
if self._vect_value is not None:
309329
value = self._vect_value
310330
elemsize = self._vect_elemsize
311331
else:
312332
LOG.debug("Skip no vector default value provided - %s",
313333
reg)
314334
continue
315-
elif reg.used_for_float_arithmetic:
335+
elif reg.type.used_for_float_arithmetic:
316336
if self._fp_value is not None:
317337
value = self._fp_value
318338
else:
@@ -332,10 +352,10 @@ def __call__(self, building_block, target):
332352
if isinstance(value, int):
333353
value = value & ((2**reg.size)-1)
334354

335-
if reg.used_for_float_arithmetic:
355+
if reg.type.used_for_float_arithmetic:
336356
value = ieee_float_to_int64(float(value))
337357

338-
elif reg.used_for_vector_arithmetic:
358+
elif reg.type.used_for_vector_arithmetic:
339359
if isinstance(value, float):
340360
if elemsize != 64:
341361
raise MicroprobeCodeGenerationError(
@@ -360,13 +380,13 @@ def __call__(self, building_block, target):
360380
else:
361381
LOG.debug("Direct set of '%s' to '0x%x'", reg, value)
362382
except MicroprobeCodeGenerationError:
363-
building_block.add_init(target.set_register(
383+
building_block.add_init(target.isa.set_register(
364384
reg, value, building_block.context))
365385
LOG.debug("Set '%s' to '0x%x'", reg, value)
366386
except MicroprobeDuplicatedValueError:
367387
LOG.debug("Skip already set - %s", reg)
368388
else:
369-
building_block.add_init(target.set_register(
389+
building_block.add_init(target.isa.set_register(
370390
reg, value, building_block.context))
371391
building_block.context.set_register_value(reg, value)
372392

src/microprobe/target/isa/instruction.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,6 +1777,13 @@ def assembly(self, args, dissabled_fields=None):
17771777
"," + field.name + ")",
17781778
"," + next_operand_value().representation + ")", 1)
17791779

1780+
elif assembly_str.find(" " + field.name + ".t") >= 0:
1781+
assembly_str = assembly_str.replace(
1782+
", " + field.name + ".t",
1783+
", " + next_operand_value().representation + ".t",
1784+
1,
1785+
)
1786+
17801787
else:
17811788
LOG.debug(
17821789
"%s",

src/microprobe/target/isa/operand.py

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
"""
1717

1818
# Futures
19-
from __future__ import absolute_import, print_function
19+
from __future__ import absolute_import, print_function, annotations
2020

2121
# Built-in modules
2222
import abc
2323
import os
2424
import random
25+
from typing import Dict, List, TYPE_CHECKING, cast
2526

2627
# Third party modules
2728
import six
@@ -39,6 +40,10 @@
3940
from microprobe.utils.typeguard_decorator import typeguard_testsuite
4041
from microprobe.utils.yaml import read_yaml
4142

43+
# Type hinting
44+
if TYPE_CHECKING:
45+
from microprobe.code.context import Context
46+
4247
# Constants
4348
SCHEMA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "schemas",
4449
"operand.yaml")
@@ -285,7 +290,7 @@ class OperandDescriptor:
285290
286291
"""
287292

288-
def __init__(self, mtype, is_input, is_output):
293+
def __init__(self, mtype: Operand, is_input, is_output):
289294
"""
290295
291296
:param mtype:
@@ -312,7 +317,7 @@ def is_output(self):
312317
"""Is output flag (:class:`~.bool`) """
313318
return self._is_output
314319

315-
def set_type(self, new_type):
320+
def set_type(self, new_type: Operand):
316321
"""
317322
318323
:param new_type:
@@ -616,7 +621,14 @@ def copy(self):
616621
raise NotImplementedError
617622

618623
@abc.abstractmethod
619-
def values(self):
624+
def values(self) -> List[Register]:
625+
"""Return the possible value of the operand."""
626+
raise NotImplementedError
627+
628+
# TODO: Consider making filtered_values into values.
629+
def filtered_values(
630+
self, context: Context, fieldname: str
631+
) -> List[Register]:
620632
"""Return the possible value of the operand."""
621633
raise NotImplementedError
622634

@@ -767,8 +779,14 @@ class OperandReg(Operand):
767779
768780
"""
769781

770-
def __init__(self, name, descr, regs, address_base, address_index,
771-
floating_point, vector):
782+
def __init__(self,
783+
name: str,
784+
descr: str,
785+
regs: List[Register] | Dict[Register, List[Register]],
786+
address_base,
787+
address_index: int,
788+
floating_point: bool | None,
789+
vector: bool | None):
772790
"""
773791
774792
:param name:
@@ -783,7 +801,7 @@ def __init__(self, name, descr, regs, address_base, address_index,
783801
super(OperandReg, self).__init__(name, descr)
784802

785803
if isinstance(regs, list):
786-
self._regs = OrderedDict()
804+
self._regs: Dict[Register, List[Register]] = OrderedDict()
787805
for reg in regs:
788806
self._regs[reg] = [reg]
789807
else:
@@ -809,6 +827,53 @@ def values(self):
809827
"""
810828
return list(self._regs.keys())
811829

830+
def filtered_values(self, context: Context, fieldname: str):
831+
lmul = cast(int | None, context.get_registername_value("LMUL"))
832+
833+
if lmul is None or not fieldname.startswith("v"):
834+
return self.values()
835+
elif fieldname in ["vd", "vmd", "vrs1", "vrs2", "vmask"]:
836+
lmul *= 1
837+
elif fieldname in ["vdd", "vdmd", "vdrs1", "vdrs2", "vnd", "vnmd"]:
838+
lmul *= 2
839+
elif fieldname in []:
840+
lmul *= 4
841+
elif fieldname in []:
842+
lmul *= 8
843+
else:
844+
raise ValueError(f"Unhandled LMUL operand name: {fieldname}")
845+
846+
regs = list(self._regs.keys())
847+
848+
class LMULRegs:
849+
lmul1 = regs
850+
lmul2 = [
851+
reg
852+
for reg in self._regs.keys()
853+
if reg.name in set([f"V{i}" for i in range(0, 32, 2)])
854+
]
855+
lmul4 = [
856+
reg
857+
for reg in self._regs.keys()
858+
if reg.name in set([f"V{i}" for i in range(0, 32, 4)])
859+
]
860+
lmul8 = [
861+
reg
862+
for reg in self._regs.keys()
863+
if reg.name in set([f"V{i}" for i in range(0, 32, 8)])
864+
]
865+
866+
if lmul == 1:
867+
return LMULRegs.lmul1
868+
elif lmul == 2:
869+
return LMULRegs.lmul2
870+
elif lmul == 4:
871+
return LMULRegs.lmul4
872+
elif lmul == 8:
873+
return LMULRegs.lmul8
874+
else:
875+
raise ValueError(f"Unhandled LMUL value: {lmul}")
876+
812877
def representation(self, value):
813878
"""
814879
@@ -927,6 +992,11 @@ def values(self):
927992
]
928993
return self._computed_values
929994

995+
def filtered_values(
996+
self, context: Context, fieldname: str
997+
):
998+
return super().filtered_values(context, fieldname)
999+
9301000
def set_valid_values(self, values):
9311001
"""
9321002
@@ -1083,6 +1153,11 @@ def values(self):
10831153
"""
10841154
return self._values
10851155

1156+
def filtered_values(
1157+
self, context: Context, fieldname: str
1158+
):
1159+
return super().filtered_values(context, fieldname)
1160+
10861161
def representation(self, value):
10871162
"""
10881163
@@ -1177,6 +1252,11 @@ def values(self):
11771252
"""
11781253
return [self._value]
11791254

1255+
def filtered_values(
1256+
self, context: Context, fieldname: str
1257+
):
1258+
return super().filtered_values(context, fieldname)
1259+
11801260
def representation(self, value):
11811261
"""
11821262
@@ -1285,6 +1365,11 @@ def values(self):
12851365
"""
12861366
return [self._reg]
12871367

1368+
def filtered_values(
1369+
self, context: Context, fieldname: str
1370+
):
1371+
return super().filtered_values(context, fieldname)
1372+
12881373
def random_value(self):
12891374
"""Return a random possible value for the operand.
12901375
@@ -1393,6 +1478,11 @@ def values(self):
13931478
"""
13941479
return [self._mindispl << self._shift]
13951480

1481+
def filtered_values(
1482+
self, context: Context, fieldname: str
1483+
):
1484+
return super().filtered_values(context, fieldname)
1485+
13961486
def random_value(self):
13971487
"""Return a random possible value for the operand.
13981488

0 commit comments

Comments
 (0)