Port CUDA module to new API.

pull/8478/head
Olexa Bilaniuk 4 years ago committed by Xavier Claessens
parent 504ae2dee8
commit c4e4363483
  1. 34
      mesonbuild/modules/unstable_cuda.py

@ -12,27 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as T
import re
from ..mesonlib import version_compare
from ..interpreter import CompilerHolder
from ..compilers import CudaCompiler
from . import ExtensionModule, ModuleReturnValue
from . import ModuleObject
from ..interpreterbase import (
flatten, permittedKwargs, noKwargs,
InvalidArguments, FeatureNew
)
class CudaModule(ExtensionModule):
class CudaModule(ModuleObject):
@FeatureNew('CUDA module', '0.50.0')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.methods.update({
"min_driver_version": self.min_driver_version,
"nvcc_arch_flags": self.nvcc_arch_flags,
"nvcc_arch_readable": self.nvcc_arch_readable,
})
@noKwargs
def min_driver_version(self, state, args, kwargs):
def min_driver_version(self, state: 'ModuleState',
args: T.Tuple[str],
kwargs: T.Dict[str, T.Any]) -> str:
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
'the CUDA Toolkit\'s components (including NVCC) are versioned ' +
@ -69,19 +77,23 @@ class CudaModule(ExtensionModule):
driver_version = d.get(state.host_machine.system, d['linux'])
break
return ModuleReturnValue(driver_version, [driver_version])
return driver_version
@permittedKwargs(['detected'])
def nvcc_arch_flags(self, state, args, kwargs):
nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs)
def nvcc_arch_flags(self, state: 'ModuleState',
args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]],
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ModuleReturnValue(ret, [ret])
return ret
@permittedKwargs(['detected'])
def nvcc_arch_readable(self, state, args, kwargs):
nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs)
def nvcc_arch_readable(self, state: 'ModuleState',
args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]],
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
return ModuleReturnValue(ret, [ret])
return ret
@staticmethod
def _break_arch_string(s):
@ -107,7 +119,7 @@ class CudaModule(ExtensionModule):
return c
return 'unknown'
def _validate_nvcc_arch_args(self, state, args, kwargs):
def _validate_nvcc_arch_args(self, args, kwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
if len(args) < 1:

Loading…
Cancel
Save