torchonnx is a compiler-based tool that converts ONNX models (`.onnx` files) into native PyTorch models (`.pth` files for parameters and `.py` files for model structure).
It has been extensively tested on VNNCOMP 2024 benchmarks, including Vision Transformers, CNNs, and other complex neural network architectures.
While PyTorch provides the torch.onnx module to convert PyTorch models to ONNX, the reverse process—converting ONNX models back to PyTorch—is not officially supported. This tool addresses this gap for several key reasons:
- Version Fragmentation: The ONNX model format evolves across versions, with different versions supporting different operations. This creates significant compatibility challenges when working with models from various sources.
- Framework Inconsistencies: There are numerous inconsistencies between ONNX and PyTorch models in naming conventions, parameter handling, and operational semantics. PyTorch does not officially support the reverse conversion, likely considering it unnecessary for its ecosystem.
- Neural Network Verification Requirements: For the Neural Network Verification (NNV) community, ONNX has become the unified model format. Being able to work with these models natively in PyTorch is essential for research and verification tasks.
- Code Quality and Maintainability: ONNX's computational graph representation does not always align with logical groupings that make sense in PyTorch. We need a tool that generates clean, maintainable PyTorch code.
- Dynamic Batch Dimension: Many ONNX models are exported with a hardcoded batch size (typically `batch_size=1`), using operations like `reshape(1, ...)`. TorchONNX generates code that supports dynamic batch dimensions, making models compatible with arbitrary batch sizes (see the sketch after this list).
- Vectorization Support: For neural network verification and adversarial robustness testing, efficient batch processing via `torch.vmap` is essential. TorchONNX provides a `vmap_mode` that generates vectorization-compatible code, enabling efficient parallel evaluation across multiple inputs.
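To illustrate the dynamic-batch point, the snippet below contrasts a hardcoded reshape with the batch-agnostic form. This is a minimal sketch of the pattern, not literal torchonnx output:

```python
import torch

x = torch.randn(32, 48, 4, 4)  # batch of 32, not the exported batch_size=1

# Naive translation of an ONNX Reshape with a baked-in batch size:
# y = x.reshape(1, 16, 48)  # breaks for any batch size other than 1

# Batch-agnostic form, in the style the generated code uses:
y = x.reshape(x.shape[0], -1, 48)  # works for any leading batch dimension
```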
While other tools exist for ONNX-to-PyTorch conversion, most fall short in performance and code quality. The most well-known tool, `onnx2pytorch`, serves as a runtime wrapper rather than a true compiler: its `forward` method iterates over ONNX nodes at runtime instead of generating static PyTorch code, and its parameter conversion is inefficient.
torchonnx takes a different approach: it is a true compiler that generates clean, efficient PyTorch code. The tool converts ONNX models into two separate files:
- A .py file defining the neural network structure as native PyTorch code
- A .pth file containing the model parameters as a state dictionary
This design eliminates runtime overhead and produces code that is readable, maintainable, and performs identically to hand-written PyTorch models.
- Zero Runtime Overhead: Generated PyTorch code has no ONNX dependencies and runs at native PyTorch speed
- Static Code Generation: All operations are compiled to clean Python code, not interpreted at runtime
- Optimized Parameter Handling: Intelligent tracking eliminates unused parameters, reducing model size
- Cached Constants: Constant tensors are registered as buffers for efficient device management
- Idiomatic PyTorch: Uses native PyTorch operations, type conversions, and best practices throughout
- Complete Type Hints: All generated code includes full type annotations for Python 3.10+
- Clean Structure: Human-readable modules with proper naming, documentation, and organization
- No Dead Code: Automatic elimination of unused operations, parameters, and buffers
- Code Optimization: Post-processing removes default arguments and converts to positional arguments
- Formatted Output: All code formatted with `black` for consistency
- Pure Python Implementation: No compiled dependencies, easy to inspect and modify
- Modular Architecture: Clean 6-stage compiler pipeline with separation of concerns
- Easy to Extend: Add new operations or modify existing ones without breaking the codebase
- Well-Documented: reStructuredText docstrings with `:param:` and `:return:` annotations
- VNNCOMP 2024 Benchmarks: Extensively tested on official neural network verification competition benchmarks
- Diverse Model Coverage: Successfully converts Vision Transformers, CNNs, MLPs, and complex architectures
- Validated Output: Generated models produce numerically identical results to original ONNX models
TorchONNX supports generating `torch.vmap`-compatible code for efficient batched evaluation. This is particularly important for neural network verification and adversarial robustness testing, where many inputs need to be evaluated in parallel.
Standard ONNX models often contain operations that are incompatible with `torch.vmap`:
- In-place operations: Operations like `index_put_` break vmap's functional requirements
- Dynamic `.item()` calls: Converting tensor values to Python scalars is not vmap-compatible
- Input-dependent control flow: Conditional behavior based on input values
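For example, a function that calls `.item()` works in eager mode but fails under `torch.vmap` (a minimal sketch):

```python
import torch

def uses_item(v):
    n = int(v[0].item())  # converts a traced tensor value to a Python scalar
    return v * n

x = torch.arange(6.0).reshape(2, 3)
print(uses_item(x[0]))  # fine in eager mode

try:
    torch.vmap(uses_item)(x)  # raises: .item() is not vmap-compatible
except RuntimeError as e:
    print("vmap failed:", e)
```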
When `vmap_mode=True` is specified during conversion:
- Functional helpers: Uses `torch.scatter` instead of in-place `index_put_`
- Tensor-based indexing: Uses `torch.gather` with pre-computed slice lengths instead of `.item()` calls
- Validity flag propagation: Tracks empty/out-of-bounds slices and propagates validity to downstream operations
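The functional rewrite looks roughly like this (an illustrative sketch, not the exact generated code):

```python
import torch

def functional_update(row):
    # Out-of-place torch.scatter replaces an in-place `row[idx] = value`,
    # which would lower to index_put_ and may break under vmap.
    idx = torch.zeros(1, dtype=torch.int64)
    src = torch.full((1,), 99.0)
    return torch.scatter(row, 0, idx, src)

x = torch.ones(4, 5)
out = torch.vmap(functional_update)(x)  # element 0 of every row becomes 99.0
```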
The cctsdb_yolo_2023 benchmark from VNNCOMP presents a particularly challenging case for vectorization:
- Input-dependent dynamic slicing: Slice indices are computed from input values
- Out-of-bounds handling: When slice indices exceed array bounds, standard mode returns empty tensors while vmap mode must return fixed-shape tensors
TorchONNX solves this with validity flag propagation:
- `dynamic_slice` returns `(result, valid_flag)`, where `valid_flag=0` indicates out-of-bounds
- Validity flags are accumulated across multiple slice operations
- `scatter_nd` receives the validity flag and returns the original data unchanged when `valid=0`
This ensures identical outputs between standard and vmap modes for all inputs, including edge cases.
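The pattern can be pictured with this one-dimensional sketch (illustrative only; the generated helper also handles multiple axes, steps, and flag accumulation):

```python
import torch

def dynamic_slice_with_flag(data, start, length):
    # Illustrative 1-D sketch of the validity-flag pattern.
    start = torch.as_tensor(start)
    dim = data.shape[0]
    valid = ((start >= 0) & (start + length <= dim)).to(data.dtype)
    # Clamp so gather always receives legal, fixed-shape indices.
    idx = torch.clamp(start + torch.arange(length), 0, dim - 1)
    result = torch.gather(data, 0, idx) * valid  # zeroed out when invalid
    return result, valid

data = torch.arange(10.0)
print(dynamic_slice_with_flag(data, 3, 4))  # (tensor([3., 4., 5., 6.]), tensor(1.))
print(dynamic_slice_with_flag(data, 8, 4))  # out of bounds -> valid flag is 0
```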
```python
from torchonnx import TorchONNX

converter = TorchONNX(verbose=True)

# Default conversion (vmap_mode=True by default)
converter.convert("model.onnx", target_py_path="model.py")

# Explicitly disable vmap mode if needed (for legacy compatibility)
converter.convert("model.onnx", target_py_path="model_legacy.py", vmap_mode=False)
```

```python
import torch

# Load vmap-compatible model
from model_vmap import Model

model = Model()
model.load_state_dict(torch.load("model.pth"))

# Batch of inputs
batch_inputs = torch.randn(100, 3, 64, 64)

# Vectorized evaluation
vmapped_model = torch.vmap(model)
batch_outputs = vmapped_model(batch_inputs)
```

TorchONNX implements a 6-stage compiler pipeline that transforms ONNX models into optimized PyTorch code:
Loads and normalizes ONNX models to a consistent format:
- Model validation using ONNX checker
- Opset version conversion (target: opset 20)
- Shape inference using ONNX shape inference or shapeonnx
- Metadata cleanup
Key Files: `normalize/normalize.py`, `normalize/utils.py`
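Conceptually, this stage performs steps similar to the following (a sketch using the public `onnx` APIs, not the tool's actual code):

```python
import onnx
from onnx import shape_inference, version_converter

model = onnx.load("model.onnx")
onnx.checker.check_model(model)                       # validate structure
model = version_converter.convert_version(model, 20)  # normalize to opset 20
model = shape_inference.infer_shapes(model)           # annotate static shapes
```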
Extracts pure structural information from ONNX graph:
- Builds `ModelIR` containing a list of `NodeIR` instances
- Captures graph topology, tensor shapes, and initializers
- No semantic interpretation at this stage (pure structural representation)

Key Files: `build/builder.py`, `build/types.py`
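The structural IR can be pictured as plain dataclasses like these (a hypothetical sketch; the actual definitions live in `build/types.py` and differ in detail):

```python
from dataclasses import dataclass, field

@dataclass
class NodeIR:
    op_type: str                 # ONNX op name, e.g. "Conv" or "MatMul"
    inputs: list[str]            # input tensor names
    outputs: list[str]           # output tensor names
    attrs: dict[str, object] = field(default_factory=dict)

@dataclass
class ModelIR:
    nodes: list[NodeIR]                  # topologically ordered graph nodes
    shapes: dict[str, tuple[int, ...]]   # tensor name -> inferred shape
    initializers: dict[str, object]      # tensor name -> raw weight data
```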
Transforms structural IR into semantic IR with PyTorch types:
- Classifies initializers into parameters (trainable), constants (buffers), and arguments (literals)
- Maps ONNX operations to PyTorch types (layers, functions, operators)
- Resolves tensor data types and shapes
- Builds `SemanticModelIR` with typed inputs (`VariableInfo`, `ParameterInfo`, `ConstantInfo`, `ArgumentInfo`)
Key Files:
- `analyze/builder.py` - Main semantic IR builder
- `analyze/types.py` - Semantic type definitions
- `analyze/tensor_classifier.py` - Tensor classification logic
- `analyze/type_mapping/` - ONNX to PyTorch type mappings
- `analyze/attr_extractor.py` - ONNX attribute extraction
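The classification step can be sketched as follows (hypothetical logic for illustration only; the real rules live in `analyze/tensor_classifier.py`):

```python
def classify_initializer(tensor, consumer_ops):
    # Hypothetical sketch: `consumer_ops` is the set of ONNX op types
    # that consume this initializer.
    if consumer_ops & {"Conv", "Gemm", "BatchNormalization"}:
        return "parameter"  # trainable weight -> layer parameter
    if tensor.ndim == 0 or tensor.size == 1:
        return "argument"   # scalar literal -> inlined into generated code
    return "constant"       # everything else -> registered buffer
```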
Optimizes semantic IR before code generation:
- Constant folding (future)
- Dead code elimination (future)
- Operation fusion (future)
Key Files: `optimize/optimizer.py`
Generates PyTorch module code from semantic IR:
- `__init__` method: Parameter/constant registration and layer construction
- `forward` method: Operation-by-operation code generation using handlers
- State dict: Parameter and constant tensors
- Import statements and module structure
Key Files:
- `generate/code_generator.py` - Main orchestrator
- `generate/_init_gen.py` - `__init__` method generation
- `generate/_forward_gen.py` - `forward` method generation
- `generate/_state_dict_gen.py` - State dict building
- `generate/_templates.py` - Code templates
- `generate/_handlers/` - Operation-specific code generators

Operation Handlers:
- `_layers.py` - Layer handlers (`nn.Conv2d`, `nn.Linear`, etc.)
- `_operators.py` - Operator handlers (`torch.add`, `torch.matmul`, etc.)
- `_operations.py` - Function handlers (reshape, concat, slice, etc.)
- `_registry.py` - Handler registration system
Post-processes generated code for cleanliness:
- Removes unused buffer registrations using regex parsing
- Removes default arguments from layer constructors (e.g., `bias=True` → removed)
- Removes default arguments from functions (e.g., `F.relu(x, inplace=False)` → `F.relu(x)`)
- Converts named arguments to positional where appropriate (e.g., `nn.Conv2d(in_channels=3, out_channels=64)` → `nn.Conv2d(3, 64)`)
- Filters the state dict to exclude removed buffers
Key Files:
- `simplify/_optimizer.py` - Main optimizer orchestrator
- `simplify/_line_optimizer.py` - Line-by-line optimization
- `simplify/_rules.py` - Optimization rules and patterns
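Putting these rules together, a generated line is rewritten roughly as follows (illustrative):

```python
# Before simplification (as emitted by the code generator):
self.conv2d1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, bias=True)
x1 = F.relu(x0, inplace=False)

# After simplification (defaults removed, named arguments made positional):
self.conv2d1 = nn.Conv2d(3, 64, 3)
x1 = F.relu(x0)
```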
```
torchonnx/
├── torchonnx/
│   ├── __init__.py                # Exports TorchONNX class
│   ├── _torchonnx.py              # TorchONNX class (main API)
│   ├── normalize/                 # Stage 1: ONNX normalization
│   │   ├── __init__.py
│   │   ├── normalize.py           # Model preprocessing
│   │   └── utils.py               # ONNX utilities
│   ├── build/                     # Stage 2: Structural IR
│   │   ├── __init__.py
│   │   ├── builder.py             # IR builder
│   │   └── types.py               # NodeIR, ModelIR types
│   ├── analyze/                   # Stage 3: Semantic IR
│   │   ├── __init__.py
│   │   ├── builder.py             # Semantic IR builder
│   │   ├── types.py               # Semantic type definitions
│   │   ├── tensor_classifier.py   # Tensor classification
│   │   ├── attr_extractor.py      # Attribute extraction
│   │   └── type_mapping/          # ONNX → PyTorch mappings
│   │       ├── _layers.py         # Layer type mappings
│   │       └── _operations.py     # Operation type mappings
│   ├── optimize/                  # Stage 4: IR optimization
│   │   ├── __init__.py
│   │   └── optimizer.py           # IR-level optimizations
│   ├── generate/                  # Stage 5: Code generation
│   │   ├── __init__.py
│   │   ├── code_generator.py      # Main code generator
│   │   ├── _init_gen.py           # __init__ generation
│   │   ├── _forward_gen.py        # forward() generation
│   │   ├── _state_dict_gen.py     # State dict building
│   │   ├── _templates.py          # Code templates
│   │   ├── _utils.py              # Helper utilities
│   │   └── _handlers/             # Operation-specific handlers
│   │       ├── __init__.py
│   │       ├── _registry.py       # Handler registry
│   │       ├── _layers.py         # Layer handlers
│   │       ├── _operators.py      # Operator handlers
│   │       └── _operations.py     # Function handlers
│   └── simplify/                  # Stage 6: Code optimization
│       ├── __init__.py
│       ├── _optimizer.py          # Main optimizer
│       ├── _line_optimizer.py     # Line optimizations
│       └── _rules.py              # Optimization rules
├── tests/                         # Testing infrastructure
│   ├── benchmarks/                # Original ONNX files
│   ├── baselines/                 # Expected outputs
│   ├── results/                   # Generated outputs
│   ├── analyze_model_nodes.py     # Model node analyzer
│   ├── build_benchmarks.py        # Benchmark builder
│   ├── test_benchmarks.py         # VNNCOMP 2024 tests
│   └── utils.py                   # Test utilities
└── README.md
```
Note: TorchONNX is not published on PyPI. Install locally from source.
Clone the repository and install in development mode:
```bash
git clone https://github.com/ZhongkuiMa/torchonnx.git
cd torchonnx
pip install -e .
```

Requirements:
- Python >= 3.11
- PyTorch 2.3.1
- ONNX 1.16.0
- ONNXRuntime 1.22.0
- NumPy 1.26.4

For development with linting and testing tools:

```bash
pip install -e ".[dev]"
```

This installs additional dependencies:
- pytest >= 7.0
- pytest-cov >= 4.0
- ruff >= 0.14.0
- mypy >= 1.0
```python
from torchonnx import TorchONNX

if __name__ == "__main__":
    # Create converter instance
    converter = TorchONNX(verbose=True)

    # Convert ONNX model to PyTorch
    converter.convert(
        onnx_path="model.onnx",
        benchmark_name="mymodel",  # Optional: for module naming
        target_py_path="model.py",  # Optional: defaults to model.py
        target_pth_path="model.pth",  # Optional: defaults to model.pth
    )
```

The following example demonstrates conversion of a Vision Transformer (ViT) model from VNNCOMP 2023. Note that you should use slimonnx to simplify the model first, as the original may contain unsupported operations.
You can visualize the ONNX computational graph using netron.app.
```python
from torchonnx import TorchONNX

if __name__ == "__main__":
    file_path = "../nets/ibp_3_3_8_v22_simplified.onnx"
    converter = TorchONNX(verbose=True)
    converter.convert(file_path)
```

The following shows the generated PyTorch code for the ViT model. Note the clean structure, proper parameter registration, and readable forward pass:
__all__ = ["Vit2023Ibp338"]
import torch
import torch.nn as nn
def dynamic_slice(data, starts, ends, axes=None, steps=None):
"""Dynamic slice helper for ONNX Slice operation."""
# Ensure tensor
starts = torch.as_tensor(starts, device=data.device)
ends = torch.as_tensor(ends, device=data.device)
if axes is None:
axes = torch.arange(starts.numel(), device=data.device)
else:
axes = torch.as_tensor(axes, device=data.device)
if steps is None:
steps = torch.ones_like(starts)
else:
steps = torch.as_tensor(steps, device=data.device)
# Normalize negative starts/ends
dims = torch.as_tensor(data.shape, device=data.device)
# axes tells where to read dim size
dim_sizes = dims[axes]
starts = torch.where(starts < 0, dim_sizes + starts, starts)
ends = torch.where(ends < 0, dim_sizes + ends, ends)
# Clip to bounds (ONNX semantics)
# Use tensors for both min and max to avoid type mismatch
zero = torch.zeros_like(dim_sizes)
starts = torch.clamp(starts, min=zero, max=dim_sizes)
ends = torch.clamp(ends, min=zero, max=dim_sizes)
# Build index tuple dynamically
index = [slice(None)] * data.ndim
for i in range(axes.shape[0]):
ax = axes[i].item()
idx = torch.arange(starts[i], ends[i], steps[i], device=data.device)
index[ax] = idx
return data[tuple(index)]
class Vit2023Ibp338(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("c4", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c6", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c7", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c8", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c9", torch.empty([17, 48], dtype=torch.float32))
        self.register_buffer("c11", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c12", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c13", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c14", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c15", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c16", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c18", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c19", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c20", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c22", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c23", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c24", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c26", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c27", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c28", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c31", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c32", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c33", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c34", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c35", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c36", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c37", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c38", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c40", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c41", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c42", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c43", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c44", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c45", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c47", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c48", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c49", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c51", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c52", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c53", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c55", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c56", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c57", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c60", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c61", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c62", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c63", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c64", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c65", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c66", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c67", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c69", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c70", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c71", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c72", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c73", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c74", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c76", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c77", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c78", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c80", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c81", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c82", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c84", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c85", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c86", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c89", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c90", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c91", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c92", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c93", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c94", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c95", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c96", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c97", torch.empty([1], dtype=torch.int64))
        self.conv2d1 = nn.Conv2d(3, 48, 8, stride=8)
        self.batchnorm2d1 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten1 = nn.Flatten(3)
        self.softmax1 = nn.Softmax(-1)
        self.batchnorm2d2 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu1 = nn.ReLU()
        self.batchnorm2d3 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten2 = nn.Flatten(3)
        self.softmax2 = nn.Softmax(-1)
        self.batchnorm2d4 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu2 = nn.ReLU()
        self.batchnorm2d5 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten3 = nn.Flatten(3)
        self.softmax3 = nn.Softmax(-1)
        self.batchnorm2d6 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu3 = nn.ReLU()
        self.batchnorm2d7 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.linear1 = nn.Linear(48, 10)
    def forward(self, x0):
        x1 = torch.tensor(x0.shape, dtype=torch.int64)
        x2 = x1[0]
        x3 = self.conv2d1(x0)
        x4 = torch.tensor(x3.shape, dtype=torch.int64)
        x5 = x4[0:2]
        x6 = torch.cat([x5, self.c4])
        x7 = x3.reshape([int(x) for x in x6.tolist()])
        x8 = x7.permute((0, 2, 1))
        x9 = x2.unsqueeze(0)
        x10 = torch.cat([x9, self.c6, self.c7])
        x11 = torch.full(x10.tolist(), 0.0, dtype=torch.float32)
        x12 = x11 + self.c8
        x13 = torch.cat([x12, x8], dim=1)
        x14 = x13 + self.c9
        x15 = x14.permute((0, 2, 1))
        x16 = self.batchnorm2d1(x15.unsqueeze(2)).squeeze(2)
        x17 = x16.permute((0, 2, 1))
        x18 = torch.tensor(x17.shape, dtype=torch.int64)
        x19 = x18[0]
        x20 = x17 @ self.c11
        x21 = self.c12 + x20
        x22 = x17 @ self.c13
        x23 = self.c14 + x22
        x24 = x17 @ self.c15
        x25 = self.c16 + x24
        x26 = x19.unsqueeze(0)
        x27 = torch.cat([x26, self.c18, self.c19, self.c20])
        x28 = x19.unsqueeze(0)
        x29 = torch.cat([x28, self.c22, self.c23, self.c24])
        x30 = x19.unsqueeze(0)
        x31 = torch.cat([x30, self.c26, self.c27, self.c28])
        x32 = x21.reshape([int(x) for x in x27.tolist()])
        x33 = x32.permute((0, 2, 1, 3))
        x34 = x23.reshape([int(x) for x in x29.tolist()])
        x35 = x25.reshape([int(x) for x in x31.tolist()])
        x36 = x35.permute((0, 2, 1, 3))
        x37 = x34.permute((0, 2, 3, 1))
        x38 = x33 @ x37
        x39 = x38 * 0.25
        x40 = torch.tensor(x39.shape, dtype=torch.int64)
        x41 = self.flatten1(x39)
        x42 = self.softmax1(x41)
        x43 = x42.reshape([int(x) for x in x40.tolist()])
        x44 = x43 @ x36
        x45 = x44.permute((0, 2, 1, 3))
        x46 = x19.unsqueeze(0)
        x47 = torch.cat([x46, self.c31, self.c32])
        x48 = x45.reshape([int(x) for x in x47.tolist()])
        x49 = x48 @ self.c33
        x50 = self.c34 + x49
        x51 = x50 + x14
        x52 = x51.permute((0, 2, 1))
        x53 = self.batchnorm2d2(x52.unsqueeze(2)).squeeze(2)
        x54 = x53.permute((0, 2, 1))
        x55 = x54 @ self.c35
        x56 = self.c36 + x55
        x57 = self.relu1(x56)
        x58 = x57 @ self.c37
        x59 = self.c38 + x58
        x60 = x59 + x51
        x61 = x60.permute((0, 2, 1))
        x62 = self.batchnorm2d3(x61.unsqueeze(2)).squeeze(2)
        x63 = x62.permute((0, 2, 1))
        x64 = torch.tensor(x63.shape, dtype=torch.int64)
        x65 = x64[0]
        x66 = x63 @ self.c40
        x67 = self.c41 + x66
        x68 = x63 @ self.c42
        x69 = self.c43 + x68
        x70 = x63 @ self.c44
        x71 = self.c45 + x70
        x72 = x65.unsqueeze(0)
        x73 = torch.cat([x72, self.c47, self.c48, self.c49])
        x74 = x65.unsqueeze(0)
        x75 = torch.cat([x74, self.c51, self.c52, self.c53])
        x76 = x65.unsqueeze(0)
        x77 = torch.cat([x76, self.c55, self.c56, self.c57])
        x78 = x67.reshape([int(x) for x in x73.tolist()])
        x79 = x78.permute((0, 2, 1, 3))
        x80 = x69.reshape([int(x) for x in x75.tolist()])
        x81 = x71.reshape([int(x) for x in x77.tolist()])
        x82 = x81.permute((0, 2, 1, 3))
        x83 = x80.permute((0, 2, 3, 1))
        x84 = x79 @ x83
        x85 = x84 * 0.25
        x86 = torch.tensor(x85.shape, dtype=torch.int64)
        x87 = self.flatten2(x85)
        x88 = self.softmax2(x87)
        x89 = x88.reshape([int(x) for x in x86.tolist()])
        x90 = x89 @ x82
        x91 = x90.permute((0, 2, 1, 3))
        x92 = x65.unsqueeze(0)
        x93 = torch.cat([x92, self.c60, self.c61])
        x94 = x91.reshape([int(x) for x in x93.tolist()])
        x95 = x94 @ self.c62
        x96 = self.c63 + x95
        x97 = x96 + x60
        x98 = x97.permute((0, 2, 1))
        x99 = self.batchnorm2d4(x98.unsqueeze(2)).squeeze(2)
        x100 = x99.permute((0, 2, 1))
        x101 = x100 @ self.c64
        x102 = self.c65 + x101
        x103 = self.relu2(x102)
        x104 = x103 @ self.c66
        x105 = self.c67 + x104
        x106 = x105 + x97
        x107 = x106.permute((0, 2, 1))
        x108 = self.batchnorm2d5(x107.unsqueeze(2)).squeeze(2)
        x109 = x108.permute((0, 2, 1))
        x110 = torch.tensor(x109.shape, dtype=torch.int64)
        x111 = x110[0]
        x112 = x109 @ self.c69
        x113 = self.c70 + x112
        x114 = x109 @ self.c71
        x115 = self.c72 + x114
        x116 = x109 @ self.c73
        x117 = self.c74 + x116
        x118 = x111.unsqueeze(0)
        x119 = torch.cat([x118, self.c76, self.c77, self.c78])
        x120 = x111.unsqueeze(0)
        x121 = torch.cat([x120, self.c80, self.c81, self.c82])
        x122 = x111.unsqueeze(0)
        x123 = torch.cat([x122, self.c84, self.c85, self.c86])
        x124 = x113.reshape([int(x) for x in x119.tolist()])
        x125 = x124.permute((0, 2, 1, 3))
        x126 = x115.reshape([int(x) for x in x121.tolist()])
        x127 = x117.reshape([int(x) for x in x123.tolist()])
        x128 = x127.permute((0, 2, 1, 3))
        x129 = x126.permute((0, 2, 3, 1))
        x130 = x125 @ x129
        x131 = x130 * 0.25
        x132 = torch.tensor(x131.shape, dtype=torch.int64)
        x133 = self.flatten3(x131)
        x134 = self.softmax3(x133)
        x135 = x134.reshape([int(x) for x in x132.tolist()])
        x136 = x135 @ x128
        x137 = x136.permute((0, 2, 1, 3))
        x138 = x111.unsqueeze(0)
        x139 = torch.cat([x138, self.c89, self.c90])
        x140 = x137.reshape([int(x) for x in x139.tolist()])
        x141 = x140 @ self.c91
        x142 = self.c92 + x141
        x143 = x142 + x106
        x144 = x143.permute((0, 2, 1))
        x145 = self.batchnorm2d6(x144.unsqueeze(2)).squeeze(2)
        x146 = x145.permute((0, 2, 1))
        x147 = x146 @ self.c93
        x148 = self.c94 + x147
        x149 = self.relu3(x148)
        x150 = x149 @ self.c95
        x151 = self.c96 + x150
        x152 = x151 + x143
        x153 = torch.mean(x152, self.c97.tolist(), keepdim=False)
        x154 = self.batchnorm2d7(x153.unsqueeze(2).unsqueeze(3)).squeeze(2).squeeze(2)
        x155 = self.linear1(x154)
        return x155
```
Main API for ONNX to PyTorch conversion.
Constructor:

```python
TorchONNX(verbose: bool = False)
```

Parameters:
- `verbose`: Enable detailed logging during conversion (default: `False`)
Methods:

```python
convert(
    onnx_path: str,
    benchmark_name: str | None = None,
    target_py_path: str = "model.py",
    target_pth_path: str = "model.pth",
    vmap_mode: bool = True
) -> None
```

Converts an ONNX model to PyTorch.

Parameters:
- `onnx_path`: Path to input ONNX model
- `benchmark_name`: Optional name for the module (defaults to filename)
- `target_py_path`: Output `.py` file path for model structure (default: `"model.py"`)
- `target_pth_path`: Output `.pth` file path for model parameters (default: `"model.pth"`)
- `vmap_mode`: Generate `torch.vmap`-compatible code for vectorized evaluation (default: `True`)
```python
@staticmethod
preprocess(onnx_path: str, target_opset: int = 20) -> onnx.ModelProto
```

Preprocesses the ONNX model (normalization, shape inference, validation).

Parameters:
- `onnx_path`: Path to input ONNX model
- `target_opset`: Target ONNX opset version (default: 20)

Returns:
- Normalized ONNX model
TorchONNX is extensively tested on the VNNCOMP 2024 benchmarks, the official benchmark suite for neural network verification competitions. The test suite includes:
- Vision Transformers (ViT): Complex transformer architectures with attention mechanisms
- Convolutional Neural Networks: Various CNN architectures from traffic sign detection to autonomous control
- Feedforward Networks: MLPs with various activation functions and normalizations
- Hybrid Architectures: Models combining multiple architectural patterns
All converted models are validated to produce numerically identical outputs to their original ONNX counterparts, ensuring correctness across diverse model types and operations.
To test with VNNCOMP 2024 benchmarks, clone the vnncomp2024 repository and ensure the following structure:
```
torchonnx/
├── torchonnx/
├── README.md
├── tests/
└── ...
vnncomp2024/
├── benchmarks/
└── ...
```

Then run the test suite:

```bash
cd torchonnx/tests
python test_benchmarks.py
```

The tool implements most commonly used operations in feedforward neural networks and transformers:
- Convolution: Conv1d, Conv2d, ConvTranspose1d, ConvTranspose2d
- Pooling: MaxPool2d, AvgPool2d, AdaptiveAvgPool2d
- Normalization: BatchNorm2d (with automatic dimension handling)
- Activation: ReLU, LeakyReLU, Sigmoid, Tanh, Softmax, ELU, GELU
- Linear: Linear
- Dropout: Dropout
- Upsampling: Upsample
- Shape Operations: Flatten
- Convolution: F.conv, F.conv_transpose
- Linear: F.linear
- Pooling: F.interpolate
- Padding: F.pad
- Concatenation: torch.cat
- Indexing: torch.gather, scatter_nd
- Reduction: torch.mean, torch.sum, torch.min, torch.max, torch.argmax
- Clipping: torch.clamp
- Conditional: torch.where
- Generation: torch.full, torch.arange
- Arithmetic: add (+), sub (-), mul (*), div (/), matmul (@), pow (pow), neg (neg)
- Comparison: equal (==)
- Shape: reshape, permute, squeeze, unsqueeze, shape, expand, cast
- Slicing: slice, split
- Math: sign, cos, sin, floor
Transformer-based architectures are decomposed into basic operations and handled correctly.
TorchONNX is extensively tested to ensure correctness and reliability. The test suite includes unit tests, integration tests, performance benchmarks, and VNNCOMP 2024 benchmark validation.
Latest Test Run (2026-01-03):
| Test Suite | Passed | Skipped | Warnings | Time |
|---|---|---|---|---|
| Unit Tests | 853 | 8 | 5 | 2.53s |
| Benchmark Tests | 5 | 54 | 4 | 1.00s |
| Total | 858 | 62 | 9 | 3.53s |
Coverage:
- Python 3.11 and 3.12
- All 6 pipeline stages (normalize, build, analyze, optimize, generate, simplify)
- End-to-end conversion tests with numerical accuracy validation
- Error handling for edge cases and invalid inputs
- Performance benchmarks on complex models
```
tests/
├── test_units/test_torchonnx/         # Unit tests (853 tests)
│   ├── test_normalize.py              # ONNX normalization and shape inference
│   ├── test_build.py                  # Structural IR construction
│   ├── test_analyze.py                # Semantic IR and type mapping
│   ├── test_generate.py               # Code generation and forward methods
│   ├── test_simplify.py               # Code optimization and formatting
│   ├── test_pipeline.py               # End-to-end conversion pipeline
│   ├── test_integration.py            # Integration with PyTorch models
│   ├── test_error_handling.py         # Error handling and edge cases
│   ├── test_attr_validation.py        # Attribute validation rules
│   ├── test_conv_operations.py        # Convolution operation details
│   ├── test_operation_handlers.py     # Operation handler validation
│   ├── test_optimize.py               # IR optimization passes
│   ├── test_remaining_gaps.py         # Code formatting edge cases
│   └── fixtures/                      # Test fixtures and utilities
└── test_benchmarks/                   # Benchmark tests (5 tests, 54 skipped)
    ├── test_performance.py            # Performance benchmarks
    ├── test_vmap_mode.py              # vmap-compatible code generation
    ├── test_torchonnx.py              # Model conversion benchmarks
    ├── test_torchonnx_regression.py   # Regression baseline tests
    ├── test_vnncomp2024_benchmarks.py # VNNCOMP 2024 benchmark validation
    ├── build_benchmarks.py            # Benchmark data generation
    └── baselines/                     # Expected outputs for benchmarks
```
Unit Tests (853 tests):
- Normalize Stage: 67 tests - ONNX model loading, preprocessing, shape inference
- Build Stage: 11 tests - Structural IR construction and validation
- Analyze Stage: 73 tests - Semantic analysis, type mapping, attribute extraction
- Generate Stage: 162 tests - Code generation, forward methods, state dict handling
- Simplify Stage: 47 tests - Code formatting, optimization, file headers
- Pipeline Tests: 54 tests - End-to-end conversion, numerical accuracy
- Integration Tests: 46 tests - Real PyTorch models, error handling
- Operation Handlers: 52 tests - Individual operation validation
- Error Handling: 58 tests - Invalid inputs, edge cases, type errors
Benchmark Tests (5 tests):
- Performance: 4 tests - normalize, build, format_code, model_creation
- vmap Compatibility: 1 test - vmap-compatible code generation validation
Skipped Tests (62 tests):
- Benchmark data tests (54 skipped): Require external benchmark data via `build_benchmarks.py`
- Memory-intensive tests (1 skipped): Excluded from default test runs
- Complex model tests (3 skipped): Intentionally skipped for stability
- vmap mode tests (4 skipped): Require the full benchmark environment
Run all tests:

```bash
pytest tests/
```

Run unit tests only:

```bash
pytest tests/test_units/test_torchonnx/ -v
```

Run benchmark tests only:

```bash
pytest tests/test_benchmarks/ -v
```

Run with verbose output and short tracebacks:

```bash
pytest tests/ -v --tb=short
```

Run with coverage report:

```bash
pytest tests/ --cov=src/torchonnx --cov-report=term-missing --cov-report=html
```

Run a specific test file:

```bash
pytest tests/test_units/test_torchonnx/test_pipeline.py -v
```

Run a specific test class:

```bash
pytest tests/test_units/test_torchonnx/test_pipeline.py::TestConvertAPI -v
```

Run a specific test:

```bash
pytest tests/test_units/test_torchonnx/test_pipeline.py::TestConvertAPI::test_convert_numerical_accuracy -v
```

Run performance benchmarks:

```bash
pytest tests/test_benchmarks/test_performance.py -v
```

Tests run automatically via GitHub Actions:
- Schedule: Daily at 8 AM UTC
- Manual: Workflow dispatch trigger available
- Matrix: Python 3.11 and 3.12 on ubuntu-latest
- Coverage: Codecov integration for coverage reporting
- Linting: Ruff (check and format)
- Type Checking: MyPy static analysis
See `.github/workflows/unit-tests.yml` for the full CI/CD configuration.
- Pass Rate: 99.3% (858 passed / 920 total tests)
- Test Execution: ~3.5 seconds for full suite
- Code Coverage: Comprehensive coverage of all pipeline stages
- Error Handling: 58 dedicated error handling tests
- Numerical Accuracy: Validated against ONNX Runtime and PyTorch
- Type Safety: Full type hint coverage with MyPy validation
- ShapeONNX: Advanced shape inference for ONNX models. SlimONNX uses ShapeONNX for shape-dependent optimizations.
- TorchVNNLIB: PyTorch library for neural network verification, often used in conjunction with SlimONNX for model verification tasks. It converts VNNLIB data files to `.pth` format for PyTorch or `.npz` format for NumPy.
- SlimONNX: ONNX model simplification tool that removes redundant operations and optimizes the graph before conversion.
- VNN-COMP: International Verification of Neural Networks Competition. TorchONNX is tested on all VNN-COMP 2024 benchmarks.
- ONNX Simplifier: Alternative ONNX optimization tool with different optimization strategies.
Unsupported Operations:
- Use SlimONNX to simplify model first
- Check the "Supported Operations" list above
- Open an issue for missing operations
Shape Inference Errors:
- Ensure input ONNX model is valid and complete
- Try using ShapeONNX externally to infer shapes
- Check ONNX opset version compatibility
Invalid ONNX Model:
- Validate the model using the ONNX checker: `onnx.checker.check_model(model)`
- Ensure all inputs have defined shapes in the ONNX graph
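For reference, running the checker standalone looks like this (standard `onnx` API):

```python
import onnx

model = onnx.load("model.onnx")
onnx.checker.check_model(model)  # raises if the model is structurally invalid
```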
Import Errors:
- Verify PyTorch 2.x is installed and compatible with your system
- Check that all dependencies listed in Installation section are installed
- Ensure correct Python path and virtual environment
Numerical Differences:
- Verify the original ONNX model produces expected outputs in `onnxruntime`
- Check for floating-point precision differences (use `allclose` with a tolerance; see the snippet below)
- Inspect generated code for potential precision issues
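A direct comparison can be done along these lines (a sketch assuming the generated `model.py`/`model.pth` from the usage example and a placeholder input shape):

```python
import numpy as np
import onnxruntime as ort
import torch

from model import Model  # generated by torchonnx (see Usage)

x = torch.randn(1, 3, 64, 64)  # substitute your model's actual input shape

# Reference output from ONNX Runtime
sess = ort.InferenceSession("model.onnx")
ref = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})[0]

# Output from the converted PyTorch model
net = Model()
net.load_state_dict(torch.load("model.pth"))
net.eval()
with torch.no_grad():
    out = net(x).numpy()

print(np.allclose(out, ref, rtol=1e-5, atol=1e-6))
```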
vmap Compatibility Issues:
- Disable vmap mode if conversion fails: `converter.convert(..., vmap_mode=False)`
- Check for in-place operations or `.item()` calls in the generated code
- Some operations may not be vmap-compatible; see the vmap Mode section
Slow Conversion:
- Large models may take time to normalize and analyze
- Enable verbose mode to monitor progress: `TorchONNX(verbose=True)`
- Simplify the model using SlimONNX first to reduce complexity
Large Generated Files:
- Many constants embedded in code; this is normal
- Use parameter quantization in ONNX before conversion if size is critical
- Generated .pth files contain only parameters, not code
See CONTRIBUTING.md for development setup, testing procedures, code quality standards, and pull request guidelines.
Initial public release with clean commit history
- Complete ONNX to PyTorch conversion pipeline
- Support for 50+ ONNX operations
- vmap-compatible code generation mode
- Extensive testing on VNNCOMP 2024 benchmarks
- 853 passing tests with comprehensive coverage
- Clean public API with proper visibility controls
- Full CI/CD pipeline with GitHub Actions
This project is licensed under the MIT License. See the LICENSE file for details.