cuda module: fully type annotate

Special notes:
- _nvcc_arch_flags is always called with exact arguments, no need for
  default values
- min_driver_version has its args annotation loosened because it has to
  fit the constraints of the module interface?
pull/12848/head
Eli Schwartz 1 year ago
parent 5899daf25b
commit 65ee397f34
No known key found for this signature in database
GPG Key ID: CEB167EFB5722BD6
  1. 48
      mesonbuild/modules/cuda.py
  2. 1
      run_mypy.py

@ -3,27 +3,31 @@
from __future__ import annotations
import typing as T
import re
import typing as T
from ..mesonlib import version_compare
from ..mesonlib import listify, version_compare
from ..compilers.cuda import CudaCompiler
from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
ContainerTypeInfo, InvalidArguments, KwargInfo, noKwargs, typed_kwargs, typed_pos_args,
)
if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState
from ..interpreter import Interpreter
from ..interpreterbase import TYPE_var
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
AutoArch = T.Union[str, T.List[str]]
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
@ -31,7 +35,7 @@ class CudaModule(NewExtensionModule):
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
def __init__(self, *args, **kwargs):
def __init__(self, interp: Interpreter):
super().__init__()
self.methods.update({
"min_driver_version": self.min_driver_version,
@ -41,7 +45,7 @@ class CudaModule(NewExtensionModule):
@noKwargs
def min_driver_version(self, state: 'ModuleState',
args: T.Tuple[str],
args: T.List[TYPE_var],
kwargs: T.Dict[str, T.Any]) -> str:
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
@ -113,18 +117,18 @@ class CudaModule(NewExtensionModule):
return ret
@staticmethod
def _break_arch_string(s):
def _break_arch_string(s: str) -> T.List[str]:
s = re.sub('[ \t\r\n,;]+', ';', s)
s = s.strip(';').split(';')
return s
return s.strip(';').split(';')
@staticmethod
def _detected_cc_from_compiler(c) -> T.List[str]:
def _detected_cc_from_compiler(c: T.Union[str, CudaCompiler]) -> T.List[str]:
if isinstance(c, CudaCompiler):
return [c.detected_cc]
return []
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.Tuple[str, AutoArch, T.List[str]]:
compiler = args[0]
if isinstance(compiler, CudaCompiler):
@ -132,22 +136,20 @@ class CudaModule(NewExtensionModule):
else:
cuda_version = compiler
arch_list = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list]
arch_list = flatten(arch_list)
arch_list: AutoArch = args[1]
arch_list = listify([self._break_arch_string(a) for a in arch_list])
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
detected = [self._break_arch_string(a) for a in detected]
detected = flatten(detected)
detected = [x for a in detected for x in self._break_arch_string(a)]
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
return cuda_version, arch_list, detected
def _filter_cuda_arch_list(self, cuda_arch_list, lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
def _filter_cuda_arch_list(self, cuda_arch_list: T.List[str], lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
"""
Filter CUDA arch list (no codenames) for >= low and < hi architecture
bounds, and deduplicate.
@ -165,7 +167,7 @@ class CudaModule(NewExtensionModule):
filtered_cuda_arch_list.append(arch)
return filtered_cuda_arch_list
def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''):
def _nvcc_arch_flags(self, cuda_version: str, cuda_arch_list: AutoArch, detected: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]:
"""
Using the CUDA Toolkit version and the target architectures, compute
the NVCC architecture flags.
@ -288,11 +290,11 @@ class CudaModule(NewExtensionModule):
cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x)
cuda_arch_bin = []
cuda_arch_ptx = []
cuda_arch_bin: T.List[str] = []
cuda_arch_ptx: T.List[str] = []
for arch_name in cuda_arch_list:
arch_bin = []
arch_ptx = []
arch_bin: T.Optional[T.List[str]]
arch_ptx: T.Optional[T.List[str]]
add_ptx = arch_name.endswith('+PTX')
if add_ptx:
arch_name = arch_name[:-len('+PTX')]
@ -371,5 +373,5 @@ class CudaModule(NewExtensionModule):
return nvcc_flags, nvcc_archs_readable
def initialize(*args, **kwargs):
return CudaModule(*args, **kwargs)
def initialize(interp: Interpreter) -> CudaModule:
return CudaModule(interp)

@ -51,6 +51,7 @@ modules = [
'mesonbuild/mlog.py',
'mesonbuild/msubprojects.py',
'mesonbuild/modules/__init__.py',
'mesonbuild/modules/cuda.py',
'mesonbuild/modules/external_project.py',
'mesonbuild/modules/fs.py',
'mesonbuild/modules/gnome.py',

Loading…
Cancel
Save