@@ -73,7 +73,6 @@ def test_something():
7373import pickle
7474import platform
7575import sys
76- import textwrap
7776import time
7877import shutil
7978
@@ -1939,248 +1938,3 @@ def strtobool(val):
19391938def main ():
19401939 test_file = inspect .getsourcefile (sys ._getframe (1 ))
19411940 sys .exit (pytest .main ([test_file ] + sys .argv [1 :]))
1942-
1943-
1944- class CompareBeforeAfter :
1945- """Utility for comparing before/after of TIR transforms
1946-
1947- A standard framework for writing tests that take a TIR PrimFunc as
1948- input, apply a transformation, then either compare against an
1949- expected output or assert that the transformation raised an error.
1950- A test should subclass CompareBeforeAfter, defining class members
1951- `before` / `Before`, `transform`, and `expected` / `Expected`. CompareBeforeAfter will
1952- then use these members to define a test method and test fixture.
1953-
1954- `transform` may be one of the following.
1955-
1956- - An instance of `tvm.ir.transform.Pass`
1957-
1958- - A method that takes no arguments and returns a `tvm.ir.transform.Pass`
1959-
1960- - A pytest fixture that returns a `tvm.ir.transform.Pass`
1961-
1962- `before` / `Before` may be any one of the following.
1963-
1964- - An instance of `tvm.tir.PrimFunc`. This is allowed, but is not
1965- the preferred method, as any errors in constructing the
1966- `PrimFunc` occur while collecting the test, preventing any other
1967- tests in the same file from being run.
1968-
1969- - An TVMScript function, without the ``@T.prim_func`` decoration.
1970- The ``@T.prim_func`` decoration will be applied when running the
1971- test, rather than at module import.
1972-
1973- - A method that takes no arguments and returns a `tvm.tir.PrimFunc`
1974-
1975- - A pytest fixture that returns a `tvm.tir.PrimFunc`
1976-
1977- `expected` / `Expected` may be any one of the following. The type of
1978- `expected` / `Expected` defines the test being performed. If `expected`
1979- provides a `tvm.tir.PrimFunc`, the result of the transformation
1980- must match `expected`. If `expected` is an exception, then the
1981- transformation must raise that exception type.
1982-
1983- - Any option supported for `before` / `Before`.
1984-
1985- - The `Exception` class object, or a class object that inherits
1986- from `Exception`.
1987-
1988- - A method that takes no arguments and returns `Exception` or a
1989- class object that inherits from `Exception`.
1990-
1991- - A pytest fixture that returns `Exception` or an class object
1992- that inherits from `Exception`.
1993-
1994- Examples
1995- --------
1996-
1997- .. python::
1998-
1999- class TestRemoveIf(tvm.testing.CompareBeforeAfter):
2000- transform = tvm.tir.transform.Simplify()
2001-
2002- def before(A: T.Buffer(1, "int32")):
2003- if True:
2004- A[0] = 42
2005- else:
2006- A[0] = 5
2007-
2008- def expected(A: T.Buffer(1, "int32")):
2009- A[0] = 42
2010-
2011- """
2012-
2013- check_well_formed : bool = True
2014-
2015- def __init_subclass__ (cls ):
2016- assert len ([getattr (cls , name ) for name in ["before" , "Before" ] if hasattr (cls , name )]) <= 1
2017- assert (
2018- len ([getattr (cls , name ) for name in ["expected" , "Expected" ] if hasattr (cls , name )])
2019- <= 1
2020- )
2021- for name in ["before" , "Before" ]:
2022- if hasattr (cls , name ):
2023- cls .before = cls ._normalize_before (getattr (cls , name ))
2024- break
2025- for name in ["expected" , "Expected" ]:
2026- if hasattr (cls , name ):
2027- cls .expected = cls ._normalize_expected (getattr (cls , name ))
2028- break
2029- if hasattr (cls , "transform" ):
2030- cls .transform = cls ._normalize_transform (cls .transform )
2031-
2032- @classmethod
2033- def _normalize_ir_module (cls , func ):
2034- if isinstance (func , (tvm .tir .PrimFunc , tvm .IRModule )):
2035-
2036- def inner (self ):
2037- # pylint: disable=unused-argument
2038- return func
2039-
2040- elif cls ._is_method (func ):
2041-
2042- def inner (self ):
2043- # pylint: disable=unused-argument
2044- return func (self )
2045-
2046- elif inspect .isclass (func ):
2047-
2048- def inner (self ):
2049- # pylint: disable=unused-argument
2050- func_dict = {}
2051- for name , method in func .__dict__ .items ():
2052- if name .startswith ("_" ):
2053- pass
2054- elif isinstance (method , tvm .ir .function .BaseFunc ):
2055- func_dict [name ] = method .with_attr ("global_symbol" , name )
2056- else :
2057- source_code = "@T.prim_func\n " + textwrap .dedent (inspect .getsource (method ))
2058- prim_func = tvm .script .from_source (
2059- source_code , check_well_formed = self .check_well_formed
2060- )
2061- func_dict [name ] = prim_func .with_attr ("global_symbol" , name )
2062- return tvm .IRModule (func_dict )
2063-
2064- else :
2065-
2066- def inner (self ):
2067- # pylint: disable=unused-argument
2068- source_code = "@T.prim_func\n " + textwrap .dedent (inspect .getsource (func ))
2069- return tvm .script .from_source (source_code , check_well_formed = self .check_well_formed )
2070-
2071- return pytest .fixture (inner )
2072-
2073- @classmethod
2074- def _normalize_before (cls , func ):
2075- if hasattr (func , "_pytestfixturefunction" ):
2076- return func
2077- else :
2078- return cls ._normalize_ir_module (func )
2079-
2080- @classmethod
2081- def _normalize_expected (cls , func ):
2082- if hasattr (func , "_pytestfixturefunction" ):
2083- return func
2084-
2085- elif inspect .isclass (func ) and issubclass (func , Exception ):
2086-
2087- def inner (self ):
2088- # pylint: disable=unused-argument
2089- return func
2090-
2091- return pytest .fixture (inner )
2092-
2093- else :
2094- return cls ._normalize_ir_module (func )
2095-
2096- @classmethod
2097- def _normalize_transform (cls , transform ):
2098- def apply (module_transform ):
2099- def inner (obj ):
2100- if isinstance (obj , tvm .IRModule ):
2101- return module_transform (obj )
2102- elif isinstance (obj , tvm .tir .PrimFunc ):
2103- mod = tvm .IRModule ({"main" : obj })
2104- mod = module_transform (mod )
2105- return mod ["main" ]
2106- else :
2107- raise TypeError (f"Expected IRModule or PrimFunc, but received { type (obj )} " )
2108-
2109- return inner
2110-
2111- if hasattr (transform , "_pytestfixturefunction" ):
2112- if not hasattr (cls , "_transform_orig" ):
2113- cls ._transform_orig = transform
2114-
2115- def inner (self , _transform_orig ):
2116- # pylint: disable=unused-argument
2117- return apply (_transform_orig )
2118-
2119- elif isinstance (transform , tvm .ir .transform .Pass ):
2120-
2121- def inner (self ):
2122- # pylint: disable=unused-argument
2123- return apply (transform )
2124-
2125- elif cls ._is_method (transform ):
2126-
2127- def inner (self ):
2128- # pylint: disable=unused-argument
2129- return apply (transform (self ))
2130-
2131- else :
2132- raise TypeError (
2133- "Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
2134- )
2135-
2136- return pytest .fixture (inner )
2137-
2138- @staticmethod
2139- def _is_method (func ):
2140- return callable (func ) and "self" in inspect .signature (func ).parameters
2141-
2142- def test_compare (self , before , expected , transform ):
2143- """Unit test to compare the expected TIR PrimFunc to actual"""
2144-
2145- if inspect .isclass (expected ) and issubclass (expected , Exception ):
2146- with pytest .raises (expected ):
2147- after = transform (before )
2148-
2149- # This portion through pytest.fail isn't strictly
2150- # necessary, but gives a better error message that
2151- # includes the before/after.
2152- before_str = before .script (name = "before" )
2153- after_str = after .script (name = "after" )
2154-
2155- pytest .fail (
2156- msg = (
2157- f"Expected { expected .__name__ } to be raised from transformation, "
2158- f"instead received TIR\n :{ before_str } \n { after_str } "
2159- )
2160- )
2161-
2162- elif isinstance (expected , (tvm .tir .PrimFunc , tvm .ir .IRModule )):
2163- after = transform (before )
2164-
2165- try :
2166- # overwrite global symbol so it doesn't come up in the comparison
2167- if isinstance (after , tvm .tir .PrimFunc ):
2168- after = after .with_attr ("global_symbol" , "main" )
2169- expected = expected .with_attr ("global_symbol" , "main" )
2170- tvm .ir .assert_structural_equal (after , expected )
2171- except ValueError as err :
2172- before_str = before .script (name = "before" )
2173- after_str = after .script (name = "after" )
2174- expected_str = expected .script (name = "expected" )
2175- raise ValueError (
2176- f"TIR after transformation did not match expected:\n "
2177- f"{ before_str } \n { after_str } \n { expected_str } "
2178- ) from err
2179-
2180- else :
2181- raise TypeError (
2182- f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
2183- f"to return either `Exception`, an `Exception` subclass, "
2184- f"or an instance of `tvm.tir.PrimFunc`. "
2185- f"Instead, received { type (expected )} ."
2186- )
0 commit comments