@ -3,27 +3,31 @@
from __future__ import annotations
import typing as T
import re
import typing as T
from . . mesonlib import version_compare
from . . mesonlib import listify , version_compare
from . . compilers . cuda import CudaCompiler
from . . interpreter . type_checking import NoneType
from . import NewExtensionModule , ModuleInfo
from . . interpreterbase import (
ContainerTypeInfo , InvalidArguments , KwargInfo , flatten , noKwargs , typed_kwargs , typed_pos_args ,
ContainerTypeInfo , InvalidArguments , KwargInfo , noKwargs , typed_kwargs , typed_pos_args ,
)
if T . TYPE_CHECKING :
from typing_extensions import TypedDict
from . import ModuleState
from . . interpreter import Interpreter
from . . interpreterbase import TYPE_var
class ArchFlagsKwargs ( TypedDict ) :
detected : T . Optional [ T . List [ str ] ]
AutoArch = T . Union [ str , T . List [ str ] ]
DETECTED_KW : KwargInfo [ T . Union [ None , T . List [ str ] ] ] = KwargInfo ( ' detected ' , ( ContainerTypeInfo ( list , str ) , NoneType ) , listify = True )
@ -31,7 +35,7 @@ class CudaModule(NewExtensionModule):
INFO = ModuleInfo ( ' CUDA ' , ' 0.50.0 ' , unstable = True )
def __init__ ( self , * args , * * kwargs ) :
def __init__ ( self , interp : Interpreter ) :
super ( ) . __init__ ( )
self . methods . update ( {
" min_driver_version " : self . min_driver_version ,
@ -41,7 +45,7 @@ class CudaModule(NewExtensionModule):
@noKwargs
def min_driver_version ( self , state : ' ModuleState ' ,
args : T . Tuple [ st r] ,
args : T . List [ TYPE_va r] ,
kwargs : T . Dict [ str , T . Any ] ) - > str :
argerror = InvalidArguments ( ' min_driver_version must have exactly one positional argument: ' +
' a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
@ -113,18 +117,18 @@ class CudaModule(NewExtensionModule):
return ret
@staticmethod
def _break_arch_string ( s ) :
def _break_arch_string ( s : str ) - > T . List [ str ] :
s = re . sub ( ' [ \t \r \n ,;]+ ' , ' ; ' , s )
s = s . strip ( ' ; ' ) . split ( ' ; ' )
return s
return s . strip ( ' ; ' ) . split ( ' ; ' )
@staticmethod
def _detected_cc_from_compiler ( c ) - > T . List [ str ] :
def _detected_cc_from_compiler ( c : T . Union [ str , CudaCompiler ] ) - > T . List [ str ] :
if isinstance ( c , CudaCompiler ) :
return [ c . detected_cc ]
return [ ]
def _validate_nvcc_arch_args ( self , args : T . Tuple [ T . Union [ str , CudaCompiler ] , T . List [ str ] ] , kwargs : ArchFlagsKwargs ) :
def _validate_nvcc_arch_args ( self , args : T . Tuple [ T . Union [ str , CudaCompiler ] , T . List [ str ] ] ,
kwargs : ArchFlagsKwargs ) - > T . Tuple [ str , AutoArch , T . List [ str ] ] :
compiler = args [ 0 ]
if isinstance ( compiler , CudaCompiler ) :
@ -132,22 +136,20 @@ class CudaModule(NewExtensionModule):
else :
cuda_version = compiler
arch_list = args [ 1 ]
arch_list = [ self . _break_arch_string ( a ) for a in arch_list ]
arch_list = flatten ( arch_list )
arch_list : AutoArch = args [ 1 ]
arch_list = listify ( [ self . _break_arch_string ( a ) for a in arch_list ] )
if len ( arch_list ) > 1 and not set ( arch_list ) . isdisjoint ( { ' All ' , ' Common ' , ' Auto ' } ) :
raise InvalidArguments ( ''' The special architectures ' All ' , ' Common ' and ' Auto ' must appear alone, as a positional argument! ''' )
arch_list = arch_list [ 0 ] if len ( arch_list ) == 1 else arch_list
detected = kwargs [ ' detected ' ] if kwargs [ ' detected ' ] is not None else self . _detected_cc_from_compiler ( compiler )
detected = [ self . _break_arch_string ( a ) for a in detected ]
detected = flatten ( detected )
detected = [ x for a in detected for x in self . _break_arch_string ( a ) ]
if not set ( detected ) . isdisjoint ( { ' All ' , ' Common ' , ' Auto ' } ) :
raise InvalidArguments ( ''' The special architectures ' All ' , ' Common ' and ' Auto ' must appear alone, as a positional argument! ''' )
return cuda_version , arch_list , detected
def _filter_cuda_arch_list ( self , cuda_arch_list , lo : str , hi : T . Optional [ str ] , saturate : str ) - > T . List [ str ] :
def _filter_cuda_arch_list ( self , cuda_arch_list : T . List [ str ] , lo : str , hi : T . Optional [ str ] , saturate : str ) - > T . List [ str ] :
"""
Filter CUDA arch list ( no codenames ) for > = low and < hi architecture
bounds , and deduplicate .
@ -165,7 +167,7 @@ class CudaModule(NewExtensionModule):
filtered_cuda_arch_list . append ( arch )
return filtered_cuda_arch_list
def _nvcc_arch_flags ( self , cuda_version , cuda_arch_list = ' Auto ' , detected = ' ' ) :
def _nvcc_arch_flags ( self , cuda_version : str , cuda_arch_list : AutoArch , detected : T . List [ str ] ) - > T . Tuple [ T . List [ str ] , T . List [ str ] ] :
"""
Using the CUDA Toolkit version and the target architectures , compute
the NVCC architecture flags .
@ -288,11 +290,11 @@ class CudaModule(NewExtensionModule):
cuda_arch_list = sorted ( x for x in set ( cuda_arch_list ) if x )
cuda_arch_bin = [ ]
cuda_arch_ptx = [ ]
cuda_arch_bin : T . List [ str ] = [ ]
cuda_arch_ptx : T . List [ str ] = [ ]
for arch_name in cuda_arch_list :
arch_bin = [ ]
arch_ptx = [ ]
arch_bin : T . Optional [ T . List [ str ] ]
arch_ptx : T . Optional [ T . List [ str ] ]
add_ptx = arch_name . endswith ( ' +PTX ' )
if add_ptx :
arch_name = arch_name [ : - len ( ' +PTX ' ) ]
@ -371,5 +373,5 @@ class CudaModule(NewExtensionModule):
return nvcc_flags , nvcc_archs_readable
def initialize ( * args , * * kwargs ) :
return CudaModule ( * args , * * kwargs )
def initialize ( interp : Interpreter ) - > CudaModule :
return CudaModule ( interp )