diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py index 525010839..7cfd3a04f 100644 --- a/mesonbuild/modules/cuda.py +++ b/mesonbuild/modules/cuda.py @@ -8,18 +8,26 @@ import re from ..mesonlib import version_compare from ..compilers.cuda import CudaCompiler +from ..interpreter.type_checking import NoneType from . import NewExtensionModule, ModuleInfo from ..interpreterbase import ( - flatten, permittedKwargs, noKwargs, - InvalidArguments + ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, ) if T.TYPE_CHECKING: + from typing_extensions import TypedDict + from . import ModuleState from ..compilers import Compiler + class ArchFlagsKwargs(TypedDict): + detected: T.Optional[T.List[str]] + + +DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True) + class CudaModule(NewExtensionModule): INFO = ModuleInfo('CUDA', '0.50.0', unstable=True) @@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule): return driver_version - @permittedKwargs(['detected']) + @typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW) def nvcc_arch_flags(self, state: 'ModuleState', args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], - kwargs: T.Dict[str, T.Any]) -> T.List[str]: + kwargs: ArchFlagsKwargs) -> T.List[str]: nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] return ret - @permittedKwargs(['detected']) + @typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW) def nvcc_arch_readable(self, state: 'ModuleState', args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], - kwargs: T.Dict[str, T.Any]) -> T.List[str]: + kwargs: ArchFlagsKwargs) -> T.List[str]: nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] return ret @@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule): return s @staticmethod - def _detected_cc_from_compiler(c): + def _detected_cc_from_compiler(c) -> T.List[str]: if isinstance(c, CudaCompiler): - return c.detected_cc - return '' + return [c.detected_cc] + return [] @staticmethod def _version_from_compiler(c): @@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule): return c return 'unknown' - def _validate_nvcc_arch_args(self, args, kwargs): + def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs): argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') if len(args) < 1: @@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule): raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') arch_list = arch_list[0] if len(arch_list) == 1 else arch_list - detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler)) - detected = flatten([detected]) + detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler) detected = [self._break_arch_string(a) for a in detected] detected = flatten(detected) if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):