Merge pull request #59 from triple-Mu/triplemu/fix

Fix cuda-python and pycuda
pull/62/head
triple Mu 2 years ago committed by GitHub
commit 2b2ec8667d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 12
      models/cudart_api.py
  2. 14
      models/pycuda_api.py

@ -1,6 +1,6 @@
import os import os
import warnings import warnings
from collections import namedtuple from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@ -13,6 +13,15 @@ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
warnings.filterwarnings(action='ignore', category=DeprecationWarning) warnings.filterwarnings(action='ignore', category=DeprecationWarning)
@dataclass
class Tensor:
name: str
dtype: np.dtype
shape: Tuple
cpu: ndarray
gpu: int
class TRTEngine: class TRTEngine:
def __init__(self, weight: Union[str, Path]) -> None: def __init__(self, weight: Union[str, Path]) -> None:
@ -51,7 +60,6 @@ class TRTEngine:
def __init_bindings(self) -> None: def __init_bindings(self) -> None:
dynamic = False dynamic = False
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape', 'cpu', 'gpu'))
inp_info = [] inp_info = []
out_info = [] out_info = []
out_ptrs = [] out_ptrs = []

@ -1,6 +1,6 @@
import os import os
import warnings import warnings
from collections import namedtuple from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@ -14,6 +14,15 @@ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
warnings.filterwarnings(action='ignore', category=DeprecationWarning) warnings.filterwarnings(action='ignore', category=DeprecationWarning)
@dataclass
class Tensor:
name: str
dtype: np.dtype
shape: Tuple
cpu: ndarray
gpu: int
class TRTEngine: class TRTEngine:
def __init__(self, weight: Union[str, Path]) -> None: def __init__(self, weight: Union[str, Path]) -> None:
@ -51,7 +60,6 @@ class TRTEngine:
def __init_bindings(self) -> None: def __init_bindings(self) -> None:
dynamic = False dynamic = False
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape', 'cpu', 'gpu'))
inp_info = [] inp_info = []
out_info = [] out_info = []
out_ptrs = [] out_ptrs = []
@ -129,7 +137,7 @@ class TRTEngine:
shape = tuple(self.context.get_binding_shape(j)) shape = tuple(self.context.get_binding_shape(j))
dtype = self.out_info[i].dtype dtype = self.out_info[i].dtype
cpu = np.empty(shape, dtype=dtype) cpu = np.empty(shape, dtype=dtype)
gpu = cuda.mem_alloc(contiguous_inputs[i].nbytes) gpu = cuda.mem_alloc(cpu.nbytes)
cuda.memcpy_htod_async(gpu, cpu, self.stream) cuda.memcpy_htod_async(gpu, cpu, self.stream)
else: else:
cpu = self.out_info[i].cpu cpu = self.out_info[i].cpu

Loading…
Cancel
Save