From c6ea822bd0d64a89d9177e27172369603d90cba2 Mon Sep 17 00:00:00 2001 From: Anant Sakhare Date: Sat, 14 Dec 2024 17:28:12 +0530 Subject: [PATCH 1/3] added basic engine export setup --- .gitignore | 4 +- trolo/export/exporter.py | 96 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 4dffa4b..fdf451c 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,6 @@ test.py wandb/ env_exp/ *.jpeg -*.png \ No newline at end of file +*.png +*.bin +*.xml \ No newline at end of file diff --git a/trolo/export/exporter.py b/trolo/export/exporter.py index 0c7deb5..e8e1a8e 100644 --- a/trolo/export/exporter.py +++ b/trolo/export/exporter.py @@ -1,6 +1,7 @@ from typing import Dict, Union, Optional, List, Tuple import os +import sys from pathlib import Path import numpy as np @@ -56,6 +57,7 @@ def __init__( self.config = self.load_config(config) self.device = torch.device(infer_device(device)) + LOGGER.info(f"{self.device}") self.model = self.load_model(self.model_path) self.model.to(self.device) self.model.eval() @@ -126,8 +128,14 @@ def export( fp16=fp16, ) + elif export_format.lower().strip() == "engine" or export_format.lower().strip() =="tensorrt" : + exported_path = self.export_engine( + input_size=input_size, + dtype="fp32", + ) + if not os.path.exists(exported_path): - LOGGER.error(f"Failed to export model to ONNX: {exported_path}") + LOGGER.error(f"Failed to export model: {exported_path}") LOGGER.info(f"Model exported to {exported_path}") @@ -185,7 +193,7 @@ def export2onnx( def export_openvino( self, input_size : Union[List, Tuple] = None, - dynamic : Optional [bool] = False, + verbose : Optional [bool] = False, batch_size : Optional[int] = 1, fp16 : Optional[bool] = False ) -> str: @@ -203,3 +211,87 @@ def export_openvino( ov.runtime.save_model(ov_model, output_path, compress_to_fp16=fp16) return output_path + + def export_engine( + input_size : Union[List, Tuple] = None, + dtpye : Optional [str] = "fp32", + batch_size : Optional[int] = 1, + ): + # chec device + if self.device is None or self.device == "cpu": + raise ValueError( + "TensorRT requires GPU export, but no device was specified. Please explicitly specify a GPU device (e.g., device=cuda:0) to proceed." + ) + import tensorrt as trt + # check file + if not self.model_path.endswith("onnx"): + exported_path = self.export2onnx(input_size, batch_size=batch_size ) + exported_path = self.model_path + + filename, file_ext = os.path.splitext(self.model_path) + + # check dtpype + if dtpye.lower() == "int8": + trt_dtype = trt.DataType.INT8 + raise ValueError("Currently we do not supprot the int8 conversion & calibration") + elif dtpye.lower() == "fp16": + trt_dtype = trt.DataType.HALF + + elif dtpye.lower() == "fp32": + trt_dtype = trt.DataType.FLOAT + else: + raise ValueError(f"Unsupported data type {dtype}") + + if int(trt.__version__[0]) < 8: + raise RuntimeError( + f"Incompatible TensorRT version detected! The required version is 8 or higher, " + f"but your current version is {trt.__version__}. Please upgrade TensorRT to proceed." + ) + + net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + # TODO :: INT8 Support needed + + if verbose: + trt_logger = trt.Logger(trt.LOGGER.Verbose) + else: + trt_logger = trt.Logger() + + builder = trt.Builder(trt_logger) + network = builder.create_network(net_flags) + parser = trt.OnnxParser(network, trt_logger) + + if not parser.parse_from_file(exported_path): + raise RuntimeError(f"Failed to load ONNX file {exported_path}") + + inputs = [network.get_input(i) for i in range(network.num_inputs)] + outputs = [network.get_output(i) for i in range(network.num_outputs)] + + config = builder.create_builder_config() + config.max_workspace_size = 2 << 30 + if trt_dtype == trt.DataType.HALF: + config.flags |= 1 << int(trt.BuilderFlag.FP16) + + # TODO :- Implement INT8 + + engine = builder.build_engine(network, config) + + if not engine: + _, _, tb = sys.exc_info() + traceback.print_tb(tb) + + tb_info = traceback.extract_tb(tb) + if tb_info: + _, line, _, text = tb_info[-1] + raise AssertionError( + f"Parsing failed on line {line} in statement: {text}" + ) + else: + raise AssertionError("Engine creation failed, no traceback available.") + + engine_f = f"{filename}-{str(dtpye)}.engine" + with open(engine_f, "wb") as f: + f.write(engine.serialize()) + + LOGGER.info(f"TRT Engine saved to file :{engine_f}") + + \ No newline at end of file From 799f523a8d8e33e23e537dac0812fa6bacc2e289 Mon Sep 17 00:00:00 2001 From: Anant Sakhare Date: Sat, 14 Dec 2024 18:21:36 +0530 Subject: [PATCH 2/3] minor fixes --- trolo/export/exporter.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/trolo/export/exporter.py b/trolo/export/exporter.py index e8e1a8e..3753b06 100644 --- a/trolo/export/exporter.py +++ b/trolo/export/exporter.py @@ -3,6 +3,7 @@ import os import sys from pathlib import Path +import traceback import numpy as np import torch @@ -131,7 +132,7 @@ def export( elif export_format.lower().strip() == "engine" or export_format.lower().strip() =="tensorrt" : exported_path = self.export_engine( input_size=input_size, - dtype="fp32", + dtype="fp32" ) if not os.path.exists(exported_path): @@ -213,9 +214,11 @@ def export_openvino( return output_path def export_engine( + self, input_size : Union[List, Tuple] = None, - dtpye : Optional [str] = "fp32", + dtype : Optional [str] = "fp32", batch_size : Optional[int] = 1, + verbose : Optional[bool] = False ): # chec device if self.device is None or self.device == "cpu": @@ -226,23 +229,24 @@ def export_engine( # check file if not self.model_path.endswith("onnx"): exported_path = self.export2onnx(input_size, batch_size=batch_size ) - exported_path = self.model_path + else: + exported_path = self.model_path filename, file_ext = os.path.splitext(self.model_path) # check dtpype - if dtpye.lower() == "int8": + if dtype.lower() == "int8": trt_dtype = trt.DataType.INT8 raise ValueError("Currently we do not supprot the int8 conversion & calibration") - elif dtpye.lower() == "fp16": + elif dtype.lower() == "fp16": trt_dtype = trt.DataType.HALF - elif dtpye.lower() == "fp32": + elif dtype.lower() == "fp32": trt_dtype = trt.DataType.FLOAT else: raise ValueError(f"Unsupported data type {dtype}") - if int(trt.__version__[0]) < 8: + if int(trt.__version__[0]) > 8: raise RuntimeError( f"Incompatible TensorRT version detected! The required version is 8 or higher, " f"but your current version is {trt.__version__}. Please upgrade TensorRT to proceed." @@ -267,15 +271,14 @@ def export_engine( outputs = [network.get_output(i) for i in range(network.num_outputs)] config = builder.create_builder_config() - config.max_workspace_size = 2 << 30 if trt_dtype == trt.DataType.HALF: config.flags |= 1 << int(trt.BuilderFlag.FP16) # TODO :- Implement INT8 - engine = builder.build_engine(network, config) + engine = builder.build_serialized_network(network, config) - if not engine: + if not engine: _, _, tb = sys.exc_info() traceback.print_tb(tb) @@ -288,10 +291,11 @@ def export_engine( else: raise AssertionError("Engine creation failed, no traceback available.") - engine_f = f"{filename}-{str(dtpye)}.engine" + engine_f = f"{filename}_{str(dtype)}.engine" with open(engine_f, "wb") as f: - f.write(engine.serialize()) + f.write(engine) LOGGER.info(f"TRT Engine saved to file :{engine_f}") + return engine_f \ No newline at end of file From 6cf5ce6cc6716c9b001bff560581f9947cbcf3e1 Mon Sep 17 00:00:00 2001 From: Anant Sakhare Date: Wed, 8 Jan 2025 00:07:43 +0530 Subject: [PATCH 3/3] updated --- trolo/export/exporter.py | 99 ++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 66 deletions(-) diff --git a/trolo/export/exporter.py b/trolo/export/exporter.py index 3753b06..7f90b9f 100644 --- a/trolo/export/exporter.py +++ b/trolo/export/exporter.py @@ -213,89 +213,56 @@ def export_openvino( ov.runtime.save_model(ov_model, output_path, compress_to_fp16=fp16) return output_path - def export_engine( + def export_engine( self, - input_size : Union[List, Tuple] = None, - dtype : Optional [str] = "fp32", - batch_size : Optional[int] = 1, - verbose : Optional[bool] = False - ): - # chec device + input_size: Union[List, Tuple] = None, + dtype: Optional[str] = "fp32", + batch_size: Optional[int] = 1, + verbose: Optional[bool] = False, + ): + # Check device if self.device is None or self.device == "cpu": raise ValueError( - "TensorRT requires GPU export, but no device was specified. Please explicitly specify a GPU device (e.g., device=cuda:0) to proceed." + "TensorRT requires GPU export, but no device was specified. Please explicitly specify a GPU device (e.g., device=cuda:0) to proceed." ) - import tensorrt as trt - # check file - if not self.model_path.endswith("onnx"): - exported_path = self.export2onnx(input_size, batch_size=batch_size ) - else: - exported_path = self.model_path - - filename, file_ext = os.path.splitext(self.model_path) - # check dtpype - if dtype.lower() == "int8": - trt_dtype = trt.DataType.INT8 - raise ValueError("Currently we do not supprot the int8 conversion & calibration") - elif dtype.lower() == "fp16": - trt_dtype = trt.DataType.HALF + import tensorrt as trt - elif dtype.lower() == "fp32": - trt_dtype = trt.DataType.FLOAT + if not self.model_path.endswith("onnx"): + exported_path = self.export2onnx(input_size, batch_size=batch_size) else: - raise ValueError(f"Unsupported data type {dtype}") - - if int(trt.__version__[0]) > 8: - raise RuntimeError( - f"Incompatible TensorRT version detected! The required version is 8 or higher, " - f"but your current version is {trt.__version__}. Please upgrade TensorRT to proceed." - ) - - net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - # TODO :: INT8 Support needed + exported_path = self.model_path if verbose: - trt_logger = trt.Logger(trt.LOGGER.Verbose) + trt_logger = trt.Logger(trt.Logger.Severity.VERBOSE) else: - trt_logger = trt.Logger() - - builder = trt.Builder(trt_logger) - network = builder.create_network(net_flags) + trt_logger = trt.Logger(trt.Logger.Severity.WARNING) + + builder = trt.Builder(trt_logger) + network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) parser = trt.OnnxParser(network, trt_logger) if not parser.parse_from_file(exported_path): raise RuntimeError(f"Failed to load ONNX file {exported_path}") - - inputs = [network.get_input(i) for i in range(network.num_inputs)] - outputs = [network.get_output(i) for i in range(network.num_outputs)] config = builder.create_builder_config() - if trt_dtype == trt.DataType.HALF: + if dtype.lower() == "fp16": config.flags |= 1 << int(trt.BuilderFlag.FP16) - - # TODO :- Implement INT8 - - engine = builder.build_serialized_network(network, config) - - if not engine: - _, _, tb = sys.exc_info() - traceback.print_tb(tb) - - tb_info = traceback.extract_tb(tb) - if tb_info: - _, line, _, text = tb_info[-1] - raise AssertionError( - f"Parsing failed on line {line} in statement: {text}" - ) - else: - raise AssertionError("Engine creation failed, no traceback available.") - + elif dtype.lower() == "int8": + config.flags |= 1 << int(trt.BuilderFlag.INT8) + raise NotImplementedError("INT8 calibration is not yet implemented.") + + try: + engine = builder.build_serialized_network(network, config) + if not engine: + raise RuntimeError("Failed to build TensorRT engine.") + except Exception as e: + raise RuntimeError(f"Engine serialization failed: {str(e)}") + + filename = Path(self.model_path).stem engine_f = f"{filename}_{str(dtype)}.engine" with open(engine_f, "wb") as f: f.write(engine) - - LOGGER.info(f"TRT Engine saved to file :{engine_f}") - return engine_f - - \ No newline at end of file + + LOGGER.info(f"TRT Engine saved to file: {engine_f}") + return engine_f \ No newline at end of file