diff --git a/mesonbuild/modules/unstable_cuda.py b/mesonbuild/modules/unstable_cuda.py index 0d693c3a0..919918ce5 100644 --- a/mesonbuild/modules/unstable_cuda.py +++ b/mesonbuild/modules/unstable_cuda.py @@ -34,16 +34,19 @@ class CudaModule(ExtensionModule): @noKwargs def min_driver_version(self, state, args, kwargs): argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' + - 'an NVCC compiler object, or its version string.') + 'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + + 'the CUDA Toolkit\'s components (including NVCC) are versioned ' + + 'independently from each other (and the CUDA Toolkit as a whole).') - if len(args) != 1: + if len(args) != 1 or not isinstance(args[0], str): raise argerror - else: - cuda_version = self._version_from_compiler(args[0]) - if cuda_version == 'unknown': - raise argerror + cuda_version = args[0] driver_version_table = [ + {'cuda_version': '>=11.1.0', 'windows': '456.38', 'linux': '455.23'}, + {'cuda_version': '>=11.0.3', 'windows': '451.82', 'linux': '450.51.06'}, + {'cuda_version': '>=11.0.2', 'windows': '451.48', 'linux': '450.51.05'}, + {'cuda_version': '>=11.0.1', 'windows': '451.22', 'linux': '450.36.06'}, {'cuda_version': '>=10.2.89', 'windows': '441.22', 'linux': '440.33'}, {'cuda_version': '>=10.1.105', 'windows': '418.96', 'linux': '418.39'}, {'cuda_version': '>=10.0.130', 'windows': '411.31', 'linux': '410.48'}, diff --git a/test cases/cuda/3 cudamodule/meson.build b/test cases/cuda/3 cudamodule/meson.build index f55632863..8410535d8 100644 --- a/test cases/cuda/3 cudamodule/meson.build +++ b/test cases/cuda/3 cudamodule/meson.build @@ -3,9 +3,9 @@ project('cudamodule', 'cuda', version : '1.0.0') nvcc = meson.get_compiler('cuda') cuda = import('unstable-cuda') -arch_flags = cuda.nvcc_arch_flags(nvcc, 'Auto', detected: ['6.0']) -arch_readable = cuda.nvcc_arch_readable(nvcc, 'Auto', detected: ['6.0']) -driver_version = cuda.min_driver_version(nvcc) +arch_flags = cuda.nvcc_arch_flags(nvcc.version(), 'Auto', detected: ['6.0']) +arch_readable = cuda.nvcc_arch_readable(nvcc.version(), 'Auto', detected: ['6.0']) +driver_version = cuda.min_driver_version(nvcc.version()) message('NVCC version: ' + nvcc.version()) message('NVCC flags: ' + ' '.join(arch_flags))