From 65ee397f341688282291b0ef529a7c6aa4c2f9f8 Mon Sep 17 00:00:00 2001 From: Eli Schwartz Date: Sun, 15 Oct 2023 21:26:58 -0400 Subject: [PATCH] cuda module: fully type annotate Special notes: - _nvcc_arch_flags is always called with exact arguments, no need for default values - min_driver_version has its args annotation loosened because it has to fit the constraints of the module interface? --- mesonbuild/modules/cuda.py | 48 ++++++++++++++++++++------------------ run_mypy.py | 1 + 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py index 690053868..eb73a5770 100644 --- a/mesonbuild/modules/cuda.py +++ b/mesonbuild/modules/cuda.py @@ -3,27 +3,31 @@ from __future__ import annotations -import typing as T import re +import typing as T -from ..mesonlib import version_compare +from ..mesonlib import listify, version_compare from ..compilers.cuda import CudaCompiler from ..interpreter.type_checking import NoneType from . import NewExtensionModule, ModuleInfo from ..interpreterbase import ( - ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args, + ContainerTypeInfo, InvalidArguments, KwargInfo, noKwargs, typed_kwargs, typed_pos_args, ) if T.TYPE_CHECKING: from typing_extensions import TypedDict from . import ModuleState + from ..interpreter import Interpreter + from ..interpreterbase import TYPE_var class ArchFlagsKwargs(TypedDict): detected: T.Optional[T.List[str]] + AutoArch = T.Union[str, T.List[str]] + DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True) @@ -31,7 +35,7 @@ class CudaModule(NewExtensionModule): INFO = ModuleInfo('CUDA', '0.50.0', unstable=True) - def __init__(self, *args, **kwargs): + def __init__(self, interp: Interpreter): super().__init__() self.methods.update({ "min_driver_version": self.min_driver_version, @@ -41,7 +45,7 @@ class CudaModule(NewExtensionModule): @noKwargs def min_driver_version(self, state: 'ModuleState', - args: T.Tuple[str], + args: T.List[TYPE_var], kwargs: T.Dict[str, T.Any]) -> str: argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' + 'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + @@ -113,18 +117,18 @@ class CudaModule(NewExtensionModule): return ret @staticmethod - def _break_arch_string(s): + def _break_arch_string(s: str) -> T.List[str]: s = re.sub('[ \t\r\n,;]+', ';', s) - s = s.strip(';').split(';') - return s + return s.strip(';').split(';') @staticmethod - def _detected_cc_from_compiler(c) -> T.List[str]: + def _detected_cc_from_compiler(c: T.Union[str, CudaCompiler]) -> T.List[str]: if isinstance(c, CudaCompiler): return [c.detected_cc] return [] - def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs): + def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], + kwargs: ArchFlagsKwargs) -> T.Tuple[str, AutoArch, T.List[str]]: compiler = args[0] if isinstance(compiler, CudaCompiler): @@ -132,22 +136,20 @@ class CudaModule(NewExtensionModule): else: cuda_version = compiler - arch_list = args[1] - arch_list = [self._break_arch_string(a) for a in arch_list] - arch_list = flatten(arch_list) + arch_list: AutoArch = args[1] + arch_list = listify([self._break_arch_string(a) for a in arch_list]) if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}): raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') arch_list = arch_list[0] if len(arch_list) == 1 else arch_list detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler) - detected = [self._break_arch_string(a) for a in detected] - detected = flatten(detected) + detected = [x for a in detected for x in self._break_arch_string(a)] if not set(detected).isdisjoint({'All', 'Common', 'Auto'}): raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') return cuda_version, arch_list, detected - def _filter_cuda_arch_list(self, cuda_arch_list, lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]: + def _filter_cuda_arch_list(self, cuda_arch_list: T.List[str], lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]: """ Filter CUDA arch list (no codenames) for >= low and < hi architecture bounds, and deduplicate. @@ -165,7 +167,7 @@ class CudaModule(NewExtensionModule): filtered_cuda_arch_list.append(arch) return filtered_cuda_arch_list - def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''): + def _nvcc_arch_flags(self, cuda_version: str, cuda_arch_list: AutoArch, detected: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]: """ Using the CUDA Toolkit version and the target architectures, compute the NVCC architecture flags. @@ -288,11 +290,11 @@ class CudaModule(NewExtensionModule): cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x) - cuda_arch_bin = [] - cuda_arch_ptx = [] + cuda_arch_bin: T.List[str] = [] + cuda_arch_ptx: T.List[str] = [] for arch_name in cuda_arch_list: - arch_bin = [] - arch_ptx = [] + arch_bin: T.Optional[T.List[str]] + arch_ptx: T.Optional[T.List[str]] add_ptx = arch_name.endswith('+PTX') if add_ptx: arch_name = arch_name[:-len('+PTX')] @@ -371,5 +373,5 @@ class CudaModule(NewExtensionModule): return nvcc_flags, nvcc_archs_readable -def initialize(*args, **kwargs): - return CudaModule(*args, **kwargs) +def initialize(interp: Interpreter) -> CudaModule: + return CudaModule(interp) diff --git a/run_mypy.py b/run_mypy.py index a9b52d9ac..c57a75c12 100755 --- a/run_mypy.py +++ b/run_mypy.py @@ -51,6 +51,7 @@ modules = [ 'mesonbuild/mlog.py', 'mesonbuild/msubprojects.py', 'mesonbuild/modules/__init__.py', + 'mesonbuild/modules/cuda.py', 'mesonbuild/modules/external_project.py', 'mesonbuild/modules/fs.py', 'mesonbuild/modules/gnome.py',