diff --git a/models/cudart_api.py b/models/cudart_api.py
index a21a36d..65e0f67 100644
--- a/models/cudart_api.py
+++ b/models/cudart_api.py
@@ -1,6 +1,6 @@
 import os
 import warnings
-from collections import namedtuple
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
@@ -13,6 +13,15 @@ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
 warnings.filterwarnings(action='ignore', category=DeprecationWarning)
 
 
+@dataclass
+class Tensor:
+    name: str
+    dtype: np.dtype
+    shape: Tuple
+    cpu: ndarray
+    gpu: int
+
+
 class TRTEngine:
 
     def __init__(self, weight: Union[str, Path]) -> None:
@@ -51,7 +60,6 @@ class TRTEngine:
 
     def __init_bindings(self) -> None:
         dynamic = False
-        Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape', 'cpu', 'gpu'))
         inp_info = []
         out_info = []
         out_ptrs = []
diff --git a/models/pycuda_api.py b/models/pycuda_api.py
index e340da3..086a8d1 100644
--- a/models/pycuda_api.py
+++ b/models/pycuda_api.py
@@ -1,6 +1,6 @@
 import os
 import warnings
-from collections import namedtuple
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
@@ -14,6 +14,15 @@ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
 warnings.filterwarnings(action='ignore', category=DeprecationWarning)
 
 
+@dataclass
+class Tensor:
+    name: str
+    dtype: np.dtype
+    shape: Tuple
+    cpu: ndarray
+    gpu: int
+
+
 class TRTEngine:
 
     def __init__(self, weight: Union[str, Path]) -> None:
@@ -51,7 +60,6 @@ class TRTEngine:
 
     def __init_bindings(self) -> None:
         dynamic = False
-        Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape', 'cpu', 'gpu'))
         inp_info = []
         out_info = []
         out_ptrs = []
@@ -129,7 +137,7 @@ class TRTEngine:
                 shape = tuple(self.context.get_binding_shape(j))
                 dtype = self.out_info[i].dtype
                 cpu = np.empty(shape, dtype=dtype)
-                gpu = cuda.mem_alloc(contiguous_inputs[i].nbytes)
+                gpu = cuda.mem_alloc(cpu.nbytes)
                 cuda.memcpy_htod_async(gpu, cpu, self.stream)
             else:
                 cpu = self.out_info[i].cpu
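
For context, a minimal sketch (not part of the patch) of the two ideas behind the change: the module-level @dataclass replaces the namedtuple previously created inside __init_bindings, and a dynamically shaped output's device buffer is sized from its own host buffer (cpu.nbytes) rather than from an input buffer whose size may differ. make_output and fake_alloc below are hypothetical stand-ins for the binding code and pycuda's cuda.mem_alloc; only the Tensor fields come from the patch itself.

from dataclasses import dataclass
from typing import Callable, Tuple

import numpy as np
from numpy import ndarray


@dataclass
class Tensor:
    name: str
    dtype: np.dtype
    shape: Tuple
    cpu: ndarray  # host-side buffer
    gpu: int      # device pointer, kept as an integer handle


def make_output(name: str, shape: Tuple, dtype: np.dtype,
                alloc: Callable[[int], int]) -> Tensor:
    # The host buffer is built from the *output* shape, so the device
    # allocation must match cpu.nbytes, not the size of some input tensor.
    cpu = np.empty(shape, dtype=dtype)
    gpu = alloc(cpu.nbytes)
    return Tensor(name, dtype, shape, cpu, gpu)


def fake_alloc(nbytes: int) -> int:
    # Hypothetical allocator standing in for cuda.mem_alloc; returns a dummy pointer.
    return 0


out = make_output('output0', (1, 84, 8400), np.dtype(np.float32), fake_alloc)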