cuda module: use typed_kwargs

This officially only ever accepted string or array of strings.
pull/12848/head
Eli Schwartz 1 year ago
parent 6f7e745052
commit cf35d9b4ce
No known key found for this signature in database
GPG Key ID: CEB167EFB5722BD6
  1. 31
      mesonbuild/modules/cuda.py

@ -8,18 +8,26 @@ import re
from ..mesonlib import version_compare
from ..compilers.cuda import CudaCompiler
from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
flatten, permittedKwargs, noKwargs,
InvalidArguments
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
)
if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState
from ..compilers import Compiler
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
class CudaModule(NewExtensionModule):
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule):
return driver_version
@permittedKwargs(['detected'])
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret
@permittedKwargs(['detected'])
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
return ret
@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule):
return s
@staticmethod
def _detected_cc_from_compiler(c):
def _detected_cc_from_compiler(c) -> T.List[str]:
if isinstance(c, CudaCompiler):
return c.detected_cc
return ''
return [c.detected_cc]
return []
@staticmethod
def _version_from_compiler(c):
@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule):
return c
return 'unknown'
def _validate_nvcc_arch_args(self, args, kwargs):
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
if len(args) < 1:
@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler))
detected = flatten([detected])
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
detected = [self._break_arch_string(a) for a in detected]
detected = flatten(detected)
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):

Loading…
Cancel
Save