def export_engine(
    self,
    input_size: Union[List, Tuple] = None,
    dtype: Optional[str] = "fp32",
    batch_size: Optional[int] = 1,
    verbose: Optional[bool] = False,
) -> str:
    """Build a serialized TensorRT engine for the model and write it to disk.

    Unless ``self.model_path`` already points at an ``.onnx`` file, the model
    is first exported to ONNX through ``self.export2onnx``. The ONNX file is
    then parsed and compiled by TensorRT, and the serialized engine is written
    to ``<model stem>_<dtype>.engine`` (a path relative to the current working
    directory).

    Args:
        input_size: Forwarded to ``self.export2onnx`` when an ONNX export is
            required.
        dtype: Engine precision, matched case-insensitively. ``"fp32"`` and
            ``"fp16"`` are supported; ``"int8"`` is recognised but not
            implemented. ``None`` is treated as ``"fp32"``.
        batch_size: Forwarded to ``self.export2onnx`` when an ONNX export is
            required.
        verbose: Use the VERBOSE TensorRT logger severity instead of WARNING.

    Returns:
        The file name of the written engine.

    Raises:
        ValueError: If ``self.device`` is not a CUDA device, or ``dtype`` is
            not a known precision.
        NotImplementedError: If ``dtype`` is ``"int8"``.
        RuntimeError: If the ONNX file cannot be parsed or the engine cannot
            be built.
    """
    # ``self.device`` is normally a ``torch.device``, which never compares
    # equal to a plain string; inspect its ``type`` instead. Plain strings
    # such as "cuda:0" are accepted as well.
    device_type = getattr(self.device, "type", self.device)
    if (
        not isinstance(device_type, str)
        or device_type.split(":", 1)[0].lower() != "cuda"
    ):
        raise ValueError(
            "TensorRT requires GPU export, but no device was specified. \n"
            "Please explicitly specify a GPU device (e.g., device=cuda:0) to proceed."
        )

    # Validate the precision up front so that a bad request fails before the
    # (potentially slow) ONNX export and TensorRT parsing are attempted.
    dtype_label = "fp32" if dtype is None else str(dtype)
    precision = dtype_label.lower()
    if precision == "int8":
        raise NotImplementedError("INT8 calibration is not yet implemented.")
    if precision not in ("fp32", "fp16"):
        raise ValueError(
            f"Unsupported TensorRT dtype {dtype_label!r}; expected 'fp32' or 'fp16'."
        )

    # Imported lazily: TensorRT is an optional dependency.
    import tensorrt as trt

    if Path(self.model_path).suffix.lower() == ".onnx":
        onnx_path = str(self.model_path)
    else:
        onnx_path = str(self.export2onnx(input_size, batch_size=batch_size))

    severity = (
        trt.Logger.Severity.VERBOSE if verbose else trt.Logger.Severity.WARNING
    )
    trt_logger = trt.Logger(severity)

    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, trt_logger)

    if not parser.parse_from_file(onnx_path):
        # Surface the parser's own diagnostics instead of discarding them.
        details = "; ".join(
            str(parser.get_error(i)) for i in range(parser.num_errors)
        )
        message = f"Failed to load ONNX file {onnx_path}"
        if details:
            message = f"{message}: {details}"
        raise RuntimeError(message)

    config = builder.create_builder_config()
    if precision == "fp16":
        config.set_flag(trt.BuilderFlag.FP16)

    # Keep the try body to the single call that can raise, and chain the
    # original exception so the TensorRT traceback is preserved.
    try:
        engine = builder.build_serialized_network(network, config)
    except Exception as e:
        raise RuntimeError(f"Engine serialization failed: {str(e)}") from e
    if not engine:
        raise RuntimeError("Failed to build TensorRT engine.")

    engine_f = f"{Path(self.model_path).stem}_{dtype_label}.engine"
    with open(engine_f, "wb") as f:
        f.write(engine)

    LOGGER.info(f"TRT Engine saved to file: {engine_f}")
    return engine_f