|
|
|
@ -12,27 +12,35 @@ |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import typing as T |
|
|
|
|
import re |
|
|
|
|
|
|
|
|
|
from ..mesonlib import version_compare |
|
|
|
|
from ..interpreter import CompilerHolder |
|
|
|
|
from ..compilers import CudaCompiler |
|
|
|
|
|
|
|
|
|
from . import ExtensionModule, ModuleReturnValue |
|
|
|
|
from . import ModuleObject |
|
|
|
|
|
|
|
|
|
from ..interpreterbase import ( |
|
|
|
|
flatten, permittedKwargs, noKwargs, |
|
|
|
|
InvalidArguments, FeatureNew |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
class CudaModule(ExtensionModule): |
|
|
|
|
class CudaModule(ModuleObject): |
|
|
|
|
|
|
|
|
|
@FeatureNew('CUDA module', '0.50.0') |
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
|
self.methods.update({ |
|
|
|
|
"min_driver_version": self.min_driver_version, |
|
|
|
|
"nvcc_arch_flags": self.nvcc_arch_flags, |
|
|
|
|
"nvcc_arch_readable": self.nvcc_arch_readable, |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
@noKwargs |
|
|
|
|
def min_driver_version(self, state, args, kwargs): |
|
|
|
|
def min_driver_version(self, state: 'ModuleState', |
|
|
|
|
args: T.Tuple[str], |
|
|
|
|
kwargs: T.Dict[str, T.Any]) -> str: |
|
|
|
|
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' + |
|
|
|
|
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + |
|
|
|
|
'the CUDA Toolkit\'s components (including NVCC) are versioned ' + |
|
|
|
@ -69,19 +77,23 @@ class CudaModule(ExtensionModule): |
|
|
|
|
driver_version = d.get(state.host_machine.system, d['linux']) |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
return ModuleReturnValue(driver_version, [driver_version]) |
|
|
|
|
return driver_version |
|
|
|
|
|
|
|
|
|
@permittedKwargs(['detected']) |
|
|
|
|
def nvcc_arch_flags(self, state, args, kwargs): |
|
|
|
|
nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs) |
|
|
|
|
def nvcc_arch_flags(self, state: 'ModuleState', |
|
|
|
|
args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]], |
|
|
|
|
kwargs: T.Dict[str, T.Any]) -> T.List[str]: |
|
|
|
|
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) |
|
|
|
|
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] |
|
|
|
|
return ModuleReturnValue(ret, [ret]) |
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
@permittedKwargs(['detected']) |
|
|
|
|
def nvcc_arch_readable(self, state, args, kwargs): |
|
|
|
|
nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs) |
|
|
|
|
def nvcc_arch_readable(self, state: 'ModuleState', |
|
|
|
|
args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]], |
|
|
|
|
kwargs: T.Dict[str, T.Any]) -> T.List[str]: |
|
|
|
|
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) |
|
|
|
|
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] |
|
|
|
|
return ModuleReturnValue(ret, [ret]) |
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _break_arch_string(s): |
|
|
|
@ -107,7 +119,7 @@ class CudaModule(ExtensionModule): |
|
|
|
|
return c |
|
|
|
|
return 'unknown' |
|
|
|
|
|
|
|
|
|
def _validate_nvcc_arch_args(self, state, args, kwargs): |
|
|
|
|
def _validate_nvcc_arch_args(self, args, kwargs): |
|
|
|
|
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') |
|
|
|
|
|
|
|
|
|
if len(args) < 1: |
|
|
|
|