|
|
|
@ -8,18 +8,26 @@ import re |
|
|
|
|
|
|
|
|
|
from ..mesonlib import version_compare |
|
|
|
|
from ..compilers.cuda import CudaCompiler |
|
|
|
|
from ..interpreter.type_checking import NoneType |
|
|
|
|
|
|
|
|
|
from . import NewExtensionModule, ModuleInfo |
|
|
|
|
|
|
|
|
|
from ..interpreterbase import ( |
|
|
|
|
flatten, permittedKwargs, noKwargs, |
|
|
|
|
InvalidArguments |
|
|
|
|
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if T.TYPE_CHECKING: |
|
|
|
|
from typing_extensions import TypedDict |
|
|
|
|
|
|
|
|
|
from . import ModuleState |
|
|
|
|
from ..compilers import Compiler |
|
|
|
|
|
|
|
|
|
class ArchFlagsKwargs(TypedDict): |
|
|
|
|
detected: T.Optional[T.List[str]] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True) |
|
|
|
|
|
|
|
|
|
class CudaModule(NewExtensionModule): |
|
|
|
|
|
|
|
|
|
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True) |
|
|
|
@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule): |
|
|
|
|
|
|
|
|
|
return driver_version |
|
|
|
|
|
|
|
|
|
@permittedKwargs(['detected']) |
|
|
|
|
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW) |
|
|
|
|
def nvcc_arch_flags(self, state: 'ModuleState', |
|
|
|
|
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], |
|
|
|
|
kwargs: T.Dict[str, T.Any]) -> T.List[str]: |
|
|
|
|
kwargs: ArchFlagsKwargs) -> T.List[str]: |
|
|
|
|
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) |
|
|
|
|
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] |
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
@permittedKwargs(['detected']) |
|
|
|
|
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW) |
|
|
|
|
def nvcc_arch_readable(self, state: 'ModuleState', |
|
|
|
|
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], |
|
|
|
|
kwargs: T.Dict[str, T.Any]) -> T.List[str]: |
|
|
|
|
kwargs: ArchFlagsKwargs) -> T.List[str]: |
|
|
|
|
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) |
|
|
|
|
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] |
|
|
|
|
return ret |
|
|
|
@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule): |
|
|
|
|
return s |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _detected_cc_from_compiler(c): |
|
|
|
|
def _detected_cc_from_compiler(c) -> T.List[str]: |
|
|
|
|
if isinstance(c, CudaCompiler): |
|
|
|
|
return c.detected_cc |
|
|
|
|
return '' |
|
|
|
|
return [c.detected_cc] |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _version_from_compiler(c): |
|
|
|
@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule): |
|
|
|
|
return c |
|
|
|
|
return 'unknown' |
|
|
|
|
|
|
|
|
|
def _validate_nvcc_arch_args(self, args, kwargs): |
|
|
|
|
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs): |
|
|
|
|
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') |
|
|
|
|
|
|
|
|
|
if len(args) < 1: |
|
|
|
@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule): |
|
|
|
|
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') |
|
|
|
|
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list |
|
|
|
|
|
|
|
|
|
detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler)) |
|
|
|
|
detected = flatten([detected]) |
|
|
|
|
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler) |
|
|
|
|
detected = [self._break_arch_string(a) for a in detected] |
|
|
|
|
detected = flatten(detected) |
|
|
|
|
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}): |
|
|
|
|