Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions python/torch_mlir/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .passmanager import PassManager
from .ir import StringAttr

ERROR_MSG_MAX_LENGTH = 10000

class TensorPlaceholder:
"""A class that represents a formal parameter of a given shape and dtype.
Expand Down Expand Up @@ -90,9 +91,10 @@ def run_pipeline_with_repro_report(
):
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
module_name = get_module_name_for_debug_dump(module)
original_stderr = sys.stderr
saved_stderr_fd = os.dup(2)
read_fd, write_fd = os.pipe()
os.dup2(write_fd, 2)
try:
sys.stderr = StringIO()
asm_for_error_report = module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True
)
Expand Down Expand Up @@ -122,9 +124,10 @@ def run_pipeline_with_repro_report(
# Put something descriptive here even if description is empty.
description = description or f"{module_name} compile"

stderr_out = os.read(read_fd, ERROR_MSG_MAX_LENGTH).decode()
message = f"""\
{description} failed with the following diagnostics:
{sys.stderr.getvalue()}
{stderr_out}

python exception: {e}

Expand All @@ -135,7 +138,10 @@ def run_pipeline_with_repro_report(
trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
raise TorchMlirCompilerError(trimmed_message) from None
finally:
sys.stderr = original_stderr
os.dup2(saved_stderr_fd, 2)
os.close(saved_stderr_fd)
os.close(write_fd)
os.close(read_fd)


class OutputType(Enum):
Expand Down