Skip to content

Commit 0079ff7

Browse files
authored
[REFACTOR][TEST] Replace CompareBeforeAfter for pytest compact (#18711)
This PR refactors test infrastructure by removing the CompareBeforeAfter base class from tvm.testing and converting all dependent tests to use a simpler, more explicit pattern. We need this change as latest pytest do not allow calling fixture as inner patterns which the previous CompareBeforeAfter depend on.
1 parent efdf2b2 commit 0079ff7

File tree

45 files changed

+5562
-4712
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+5562
-4712
lines changed

python/tvm/testing/utils.py

Lines changed: 0 additions & 246 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def test_something():
7373
import pickle
7474
import platform
7575
import sys
76-
import textwrap
7776
import time
7877
import shutil
7978

@@ -1939,248 +1938,3 @@ def strtobool(val):
19391938
def main():
19401939
test_file = inspect.getsourcefile(sys._getframe(1))
19411940
sys.exit(pytest.main([test_file] + sys.argv[1:]))
1942-
1943-
1944-
class CompareBeforeAfter:
1945-
"""Utility for comparing before/after of TIR transforms
1946-
1947-
A standard framework for writing tests that take a TIR PrimFunc as
1948-
input, apply a transformation, then either compare against an
1949-
expected output or assert that the transformation raised an error.
1950-
A test should subclass CompareBeforeAfter, defining class members
1951-
`before` / `Before`, `transform`, and `expected` / `Expected`. CompareBeforeAfter will
1952-
then use these members to define a test method and test fixture.
1953-
1954-
`transform` may be one of the following.
1955-
1956-
- An instance of `tvm.ir.transform.Pass`
1957-
1958-
- A method that takes no arguments and returns a `tvm.ir.transform.Pass`
1959-
1960-
- A pytest fixture that returns a `tvm.ir.transform.Pass`
1961-
1962-
`before` / `Before` may be any one of the following.
1963-
1964-
- An instance of `tvm.tir.PrimFunc`. This is allowed, but is not
1965-
the preferred method, as any errors in constructing the
1966-
`PrimFunc` occur while collecting the test, preventing any other
1967-
tests in the same file from being run.
1968-
1969-
- An TVMScript function, without the ``@T.prim_func`` decoration.
1970-
The ``@T.prim_func`` decoration will be applied when running the
1971-
test, rather than at module import.
1972-
1973-
- A method that takes no arguments and returns a `tvm.tir.PrimFunc`
1974-
1975-
- A pytest fixture that returns a `tvm.tir.PrimFunc`
1976-
1977-
`expected` / `Expected` may be any one of the following. The type of
1978-
`expected` / `Expected` defines the test being performed. If `expected`
1979-
provides a `tvm.tir.PrimFunc`, the result of the transformation
1980-
must match `expected`. If `expected` is an exception, then the
1981-
transformation must raise that exception type.
1982-
1983-
- Any option supported for `before` / `Before`.
1984-
1985-
- The `Exception` class object, or a class object that inherits
1986-
from `Exception`.
1987-
1988-
- A method that takes no arguments and returns `Exception` or a
1989-
class object that inherits from `Exception`.
1990-
1991-
- A pytest fixture that returns `Exception` or an class object
1992-
that inherits from `Exception`.
1993-
1994-
Examples
1995-
--------
1996-
1997-
.. python::
1998-
1999-
class TestRemoveIf(tvm.testing.CompareBeforeAfter):
2000-
transform = tvm.tir.transform.Simplify()
2001-
2002-
def before(A: T.Buffer(1, "int32")):
2003-
if True:
2004-
A[0] = 42
2005-
else:
2006-
A[0] = 5
2007-
2008-
def expected(A: T.Buffer(1, "int32")):
2009-
A[0] = 42
2010-
2011-
"""
2012-
2013-
check_well_formed: bool = True
2014-
2015-
def __init_subclass__(cls):
2016-
assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1
2017-
assert (
2018-
len([getattr(cls, name) for name in ["expected", "Expected"] if hasattr(cls, name)])
2019-
<= 1
2020-
)
2021-
for name in ["before", "Before"]:
2022-
if hasattr(cls, name):
2023-
cls.before = cls._normalize_before(getattr(cls, name))
2024-
break
2025-
for name in ["expected", "Expected"]:
2026-
if hasattr(cls, name):
2027-
cls.expected = cls._normalize_expected(getattr(cls, name))
2028-
break
2029-
if hasattr(cls, "transform"):
2030-
cls.transform = cls._normalize_transform(cls.transform)
2031-
2032-
@classmethod
2033-
def _normalize_ir_module(cls, func):
2034-
if isinstance(func, (tvm.tir.PrimFunc, tvm.IRModule)):
2035-
2036-
def inner(self):
2037-
# pylint: disable=unused-argument
2038-
return func
2039-
2040-
elif cls._is_method(func):
2041-
2042-
def inner(self):
2043-
# pylint: disable=unused-argument
2044-
return func(self)
2045-
2046-
elif inspect.isclass(func):
2047-
2048-
def inner(self):
2049-
# pylint: disable=unused-argument
2050-
func_dict = {}
2051-
for name, method in func.__dict__.items():
2052-
if name.startswith("_"):
2053-
pass
2054-
elif isinstance(method, tvm.ir.function.BaseFunc):
2055-
func_dict[name] = method.with_attr("global_symbol", name)
2056-
else:
2057-
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
2058-
prim_func = tvm.script.from_source(
2059-
source_code, check_well_formed=self.check_well_formed
2060-
)
2061-
func_dict[name] = prim_func.with_attr("global_symbol", name)
2062-
return tvm.IRModule(func_dict)
2063-
2064-
else:
2065-
2066-
def inner(self):
2067-
# pylint: disable=unused-argument
2068-
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
2069-
return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed)
2070-
2071-
return pytest.fixture(inner)
2072-
2073-
@classmethod
2074-
def _normalize_before(cls, func):
2075-
if hasattr(func, "_pytestfixturefunction"):
2076-
return func
2077-
else:
2078-
return cls._normalize_ir_module(func)
2079-
2080-
@classmethod
2081-
def _normalize_expected(cls, func):
2082-
if hasattr(func, "_pytestfixturefunction"):
2083-
return func
2084-
2085-
elif inspect.isclass(func) and issubclass(func, Exception):
2086-
2087-
def inner(self):
2088-
# pylint: disable=unused-argument
2089-
return func
2090-
2091-
return pytest.fixture(inner)
2092-
2093-
else:
2094-
return cls._normalize_ir_module(func)
2095-
2096-
@classmethod
2097-
def _normalize_transform(cls, transform):
2098-
def apply(module_transform):
2099-
def inner(obj):
2100-
if isinstance(obj, tvm.IRModule):
2101-
return module_transform(obj)
2102-
elif isinstance(obj, tvm.tir.PrimFunc):
2103-
mod = tvm.IRModule({"main": obj})
2104-
mod = module_transform(mod)
2105-
return mod["main"]
2106-
else:
2107-
raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}")
2108-
2109-
return inner
2110-
2111-
if hasattr(transform, "_pytestfixturefunction"):
2112-
if not hasattr(cls, "_transform_orig"):
2113-
cls._transform_orig = transform
2114-
2115-
def inner(self, _transform_orig):
2116-
# pylint: disable=unused-argument
2117-
return apply(_transform_orig)
2118-
2119-
elif isinstance(transform, tvm.ir.transform.Pass):
2120-
2121-
def inner(self):
2122-
# pylint: disable=unused-argument
2123-
return apply(transform)
2124-
2125-
elif cls._is_method(transform):
2126-
2127-
def inner(self):
2128-
# pylint: disable=unused-argument
2129-
return apply(transform(self))
2130-
2131-
else:
2132-
raise TypeError(
2133-
"Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
2134-
)
2135-
2136-
return pytest.fixture(inner)
2137-
2138-
@staticmethod
2139-
def _is_method(func):
2140-
return callable(func) and "self" in inspect.signature(func).parameters
2141-
2142-
def test_compare(self, before, expected, transform):
2143-
"""Unit test to compare the expected TIR PrimFunc to actual"""
2144-
2145-
if inspect.isclass(expected) and issubclass(expected, Exception):
2146-
with pytest.raises(expected):
2147-
after = transform(before)
2148-
2149-
# This portion through pytest.fail isn't strictly
2150-
# necessary, but gives a better error message that
2151-
# includes the before/after.
2152-
before_str = before.script(name="before")
2153-
after_str = after.script(name="after")
2154-
2155-
pytest.fail(
2156-
msg=(
2157-
f"Expected {expected.__name__} to be raised from transformation, "
2158-
f"instead received TIR\n:{before_str}\n{after_str}"
2159-
)
2160-
)
2161-
2162-
elif isinstance(expected, (tvm.tir.PrimFunc, tvm.ir.IRModule)):
2163-
after = transform(before)
2164-
2165-
try:
2166-
# overwrite global symbol so it doesn't come up in the comparison
2167-
if isinstance(after, tvm.tir.PrimFunc):
2168-
after = after.with_attr("global_symbol", "main")
2169-
expected = expected.with_attr("global_symbol", "main")
2170-
tvm.ir.assert_structural_equal(after, expected)
2171-
except ValueError as err:
2172-
before_str = before.script(name="before")
2173-
after_str = after.script(name="after")
2174-
expected_str = expected.script(name="expected")
2175-
raise ValueError(
2176-
f"TIR after transformation did not match expected:\n"
2177-
f"{before_str}\n{after_str}\n{expected_str}"
2178-
) from err
2179-
2180-
else:
2181-
raise TypeError(
2182-
f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
2183-
f"to return either `Exception`, an `Exception` subclass, "
2184-
f"or an instance of `tvm.tir.PrimFunc`. "
2185-
f"Instead, received {type(expected)}."
2186-
)

tests/python/dlight/test_cpu_gemv.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,9 @@
2323
from tvm.target import Target
2424

2525

26-
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
27-
@pytest.fixture
28-
def transform(self):
29-
def transform(mod):
30-
with Target("llvm"):
31-
return dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod)
32-
33-
return transform
34-
35-
36-
class TestGEMV(BaseBeforeAfter):
26+
def test_gemv_basic():
3727
# fmt: off
38-
39-
@T.prim_func
28+
@T.prim_func(private=True)
4029
def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle):
4130
T.func_attr({"tir.noalias": True})
4231
n = T.int32()
@@ -81,7 +70,7 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l
8170
T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
8271
var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
8372

84-
@T.prim_func
73+
@T.prim_func(private=True)
8574
def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle):
8675
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
8776
n = T.int32()
@@ -114,6 +103,11 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p
114103

115104
# fmt: on
116105

106+
mod = tvm.IRModule({"main": before})
107+
with Target("llvm"):
108+
mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod)
109+
tvm.ir.assert_structural_equal(mod["main"], expected)
110+
117111

118112
def test_decode_gemv_256_threads():
119113
# fmt: off

tests/python/dlight/test_gpu_conv.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,16 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=missing-docstring
18-
import pytest
19-
18+
import tvm
2019
import tvm.testing
2120
from tvm import dlight as dl
2221
from tvm.script import tir as T
2322
from tvm.target import Target
2423

2524

26-
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
27-
@pytest.fixture
28-
def transform(self):
29-
def transform(mod):
30-
with Target("nvidia/geforce-gtx-1080-ti"):
31-
# Use Matmul rule for Conv for now
32-
return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
33-
34-
return transform
35-
36-
37-
class TestConv3d(BaseBeforeAfter):
25+
def test_conv3d():
3826
# fmt: off
39-
@T.prim_func
27+
@T.prim_func(private=True)
4028
def before(
4129
A: T.Buffer((14308, 3, 2, 14, 14), "float16"),
4230
W: T.Buffer((1280, 3, 2, 14, 14), "float16"),
@@ -54,7 +42,7 @@ def before(
5442
C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0)
5543
C[v_nn, v_ff, v_yy, v_xx, v_zz] += pad_A[v_nn, v_rc, v_yy * 2 + v_ry, v_xx * 14 + v_rx, v_zz * 14 + v_rz]* W[v_ff, v_rc, v_ry, v_rx, v_rz]
5644

57-
@T.prim_func
45+
@T.prim_func(private=True)
5846
def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")):
5947
T.func_attr({"tir.is_scheduled": True})
6048
# with T.sblock("root"):
@@ -113,6 +101,11 @@ def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3
113101
C[v1, v2, 0, 0, 0] = C_reindex_pad_local[v0, v1, v2]
114102
# fmt: on
115103

104+
mod = tvm.IRModule({"main": before})
105+
with Target("nvidia/geforce-gtx-1080-ti"):
106+
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
107+
tvm.ir.assert_structural_equal(mod["main"], expected)
108+
116109

117110
if __name__ == "__main__":
118111
tvm.testing.main()

0 commit comments

Comments
 (0)