cuda module: fully type annotate

Special notes:
- _nvcc_arch_flags is always called with every argument passed explicitly,
  so its default values are unnecessary and have been removed
- min_driver_version has its args annotation loosened (to T.List[TYPE_var]),
  apparently because it has to fit the constraints of the module interface
pull/12848/head
Eli Schwartz 1 year ago
parent 5899daf25b
commit 65ee397f34
No known key found for this signature in database
GPG Key ID: CEB167EFB5722BD6
  1. 48
      mesonbuild/modules/cuda.py
  2. 1
      run_mypy.py

@ -3,27 +3,31 @@
from __future__ import annotations from __future__ import annotations
import typing as T
import re import re
import typing as T
from ..mesonlib import version_compare from ..mesonlib import listify, version_compare
from ..compilers.cuda import CudaCompiler from ..compilers.cuda import CudaCompiler
from ..interpreter.type_checking import NoneType from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import ( from ..interpreterbase import (
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args, ContainerTypeInfo, InvalidArguments, KwargInfo, noKwargs, typed_kwargs, typed_pos_args,
) )
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
from typing_extensions import TypedDict from typing_extensions import TypedDict
from . import ModuleState from . import ModuleState
from ..interpreter import Interpreter
from ..interpreterbase import TYPE_var
class ArchFlagsKwargs(TypedDict): class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]] detected: T.Optional[T.List[str]]
AutoArch = T.Union[str, T.List[str]]
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True) DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
@ -31,7 +35,7 @@ class CudaModule(NewExtensionModule):
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True) INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
def __init__(self, *args, **kwargs): def __init__(self, interp: Interpreter):
super().__init__() super().__init__()
self.methods.update({ self.methods.update({
"min_driver_version": self.min_driver_version, "min_driver_version": self.min_driver_version,
@ -41,7 +45,7 @@ class CudaModule(NewExtensionModule):
@noKwargs @noKwargs
def min_driver_version(self, state: 'ModuleState', def min_driver_version(self, state: 'ModuleState',
args: T.Tuple[str], args: T.List[TYPE_var],
kwargs: T.Dict[str, T.Any]) -> str: kwargs: T.Dict[str, T.Any]) -> str:
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' + argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + 'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
@ -113,18 +117,18 @@ class CudaModule(NewExtensionModule):
return ret return ret
@staticmethod @staticmethod
def _break_arch_string(s): def _break_arch_string(s: str) -> T.List[str]:
s = re.sub('[ \t\r\n,;]+', ';', s) s = re.sub('[ \t\r\n,;]+', ';', s)
s = s.strip(';').split(';') return s.strip(';').split(';')
return s
@staticmethod @staticmethod
def _detected_cc_from_compiler(c) -> T.List[str]: def _detected_cc_from_compiler(c: T.Union[str, CudaCompiler]) -> T.List[str]:
if isinstance(c, CudaCompiler): if isinstance(c, CudaCompiler):
return [c.detected_cc] return [c.detected_cc]
return [] return []
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs): def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.Tuple[str, AutoArch, T.List[str]]:
compiler = args[0] compiler = args[0]
if isinstance(compiler, CudaCompiler): if isinstance(compiler, CudaCompiler):
@ -132,22 +136,20 @@ class CudaModule(NewExtensionModule):
else: else:
cuda_version = compiler cuda_version = compiler
arch_list = args[1] arch_list: AutoArch = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list] arch_list = listify([self._break_arch_string(a) for a in arch_list])
arch_list = flatten(arch_list)
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}): if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler) detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
detected = [self._break_arch_string(a) for a in detected] detected = [x for a in detected for x in self._break_arch_string(a)]
detected = flatten(detected)
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}): if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
return cuda_version, arch_list, detected return cuda_version, arch_list, detected
def _filter_cuda_arch_list(self, cuda_arch_list, lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]: def _filter_cuda_arch_list(self, cuda_arch_list: T.List[str], lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
""" """
Filter CUDA arch list (no codenames) for >= low and < hi architecture Filter CUDA arch list (no codenames) for >= low and < hi architecture
bounds, and deduplicate. bounds, and deduplicate.
@ -165,7 +167,7 @@ class CudaModule(NewExtensionModule):
filtered_cuda_arch_list.append(arch) filtered_cuda_arch_list.append(arch)
return filtered_cuda_arch_list return filtered_cuda_arch_list
def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''): def _nvcc_arch_flags(self, cuda_version: str, cuda_arch_list: AutoArch, detected: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]:
""" """
Using the CUDA Toolkit version and the target architectures, compute Using the CUDA Toolkit version and the target architectures, compute
the NVCC architecture flags. the NVCC architecture flags.
@ -288,11 +290,11 @@ class CudaModule(NewExtensionModule):
cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x) cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x)
cuda_arch_bin = [] cuda_arch_bin: T.List[str] = []
cuda_arch_ptx = [] cuda_arch_ptx: T.List[str] = []
for arch_name in cuda_arch_list: for arch_name in cuda_arch_list:
arch_bin = [] arch_bin: T.Optional[T.List[str]]
arch_ptx = [] arch_ptx: T.Optional[T.List[str]]
add_ptx = arch_name.endswith('+PTX') add_ptx = arch_name.endswith('+PTX')
if add_ptx: if add_ptx:
arch_name = arch_name[:-len('+PTX')] arch_name = arch_name[:-len('+PTX')]
@ -371,5 +373,5 @@ class CudaModule(NewExtensionModule):
return nvcc_flags, nvcc_archs_readable return nvcc_flags, nvcc_archs_readable
def initialize(*args, **kwargs): def initialize(interp: Interpreter) -> CudaModule:
return CudaModule(*args, **kwargs) return CudaModule(interp)

@ -51,6 +51,7 @@ modules = [
'mesonbuild/mlog.py', 'mesonbuild/mlog.py',
'mesonbuild/msubprojects.py', 'mesonbuild/msubprojects.py',
'mesonbuild/modules/__init__.py', 'mesonbuild/modules/__init__.py',
'mesonbuild/modules/cuda.py',
'mesonbuild/modules/external_project.py', 'mesonbuild/modules/external_project.py',
'mesonbuild/modules/fs.py', 'mesonbuild/modules/fs.py',
'mesonbuild/modules/gnome.py', 'mesonbuild/modules/gnome.py',

Loading…
Cancel
Save