cuda module: use typed_pos_args for most methods

The min_driver_version function has an extensive, informative custom
error message, so leave that in place.

The other two functions didn't have much information there, and it's
fairly evident that the cuda compiler itself is the best thing to have
here. Moreover, there was some fairly gnarly code to validate the
allowed values, which we can greatly simplify by uplifting the
typechecking parts to the dedicated decorators that are both really good
at it, and have nicely formatted error messages complete with reference
to the problematic functions.
pull/12848/head
Eli Schwartz 11 months ago
parent 1b15176168
commit 5899daf25b
No known key found for this signature in database
GPG Key ID: CEB167EFB5722BD6
  1. 27
      mesonbuild/modules/cuda.py

@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
)
if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState
from ..compilers import Compiler
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule):
return driver_version
@typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret
@typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule):
return [c.detected_cc]
return []
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
if len(args) < 1:
raise argerror
compiler = args[0]
if isinstance(compiler, CudaCompiler):
cuda_version = compiler.version
else:
compiler = args[0]
if isinstance(compiler, CudaCompiler):
cuda_version = compiler.version
elif isinstance(compiler, str):
cuda_version = compiler
else:
raise argerror
cuda_version = compiler
arch_list = [] if len(args) <= 1 else flatten(args[1:])
arch_list = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list]
arch_list = flatten(arch_list)
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):

Loading…
Cancel
Save