@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType
from . import NewExtensionModule , ModuleInfo
from . . interpreterbase import (
ContainerTypeInfo , InvalidArguments , KwargInfo , flatten , noKwargs , typed_kwargs ,
ContainerTypeInfo , InvalidArguments , KwargInfo , flatten , noKwargs , typed_kwargs , typed_pos_args ,
)
if T . TYPE_CHECKING :
from typing_extensions import TypedDict
from . import ModuleState
from . . compilers import Compiler
class ArchFlagsKwargs ( TypedDict ) :
detected : T . Optional [ T . List [ str ] ]
@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule):
return driver_version
@typed_pos_args ( ' cuda.nvcc_arch_flags ' , ( str , CudaCompiler ) , varargs = str )
@typed_kwargs ( ' cuda.nvcc_arch_flags ' , DETECTED_KW )
def nvcc_arch_flags ( self , state : ' ModuleState ' ,
args : T . Tuple [ T . Union [ Compiler , CudaCompiler , str ] ] ,
args : T . Tuple [ T . Union [ CudaC ompiler , str ] , T . List [ str ] ] ,
kwargs : ArchFlagsKwargs ) - > T . List [ str ] :
nvcc_arch_args = self . _validate_nvcc_arch_args ( args , kwargs )
ret = self . _nvcc_arch_flags ( * nvcc_arch_args ) [ 0 ]
return ret
@typed_pos_args ( ' cuda.nvcc_arch_readable ' , ( str , CudaCompiler ) , varargs = str )
@typed_kwargs ( ' cuda.nvcc_arch_readable ' , DETECTED_KW )
def nvcc_arch_readable ( self , state : ' ModuleState ' ,
args : T . Tuple [ T . Union [ Compiler , CudaCompiler , str ] ] ,
args : T . Tuple [ T . Union [ CudaC ompiler , str ] , T . List [ str ] ] ,
kwargs : ArchFlagsKwargs ) - > T . List [ str ] :
nvcc_arch_args = self . _validate_nvcc_arch_args ( args , kwargs )
ret = self . _nvcc_arch_flags ( * nvcc_arch_args ) [ 1 ]
@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule):
return [ c . detected_cc ]
return [ ]
def _validate_nvcc_arch_args ( self , args , kwargs : ArchFlagsKwargs ) :
argerror = InvalidArguments ( ' The first argument must be an NVCC compiler object, or its version string! ' )
def _validate_nvcc_arch_args ( self , args : T . Tuple [ T . Union [ str , CudaCompiler ] , T . List [ str ] ] , kwargs : ArchFlagsKwargs ) :
if len ( args ) < 1 :
raise argerror
compiler = args [ 0 ]
if isinstance ( compiler , CudaCompiler ) :
cuda_version = compiler . version
else :
compiler = args [ 0 ]
if isinstance ( compiler , CudaCompiler ) :
cuda_version = compiler . version
elif isinstance ( compiler , str ) :
cuda_version = compiler
else :
raise argerror
cuda_version = compiler
arch_list = [ ] if len ( args ) < = 1 else flatten ( args [ 1 : ] )
arch_list = args [ 1 ]
arch_list = [ self . _break_arch_string ( a ) for a in arch_list ]
arch_list = flatten ( arch_list )
if len ( arch_list ) > 1 and not set ( arch_list ) . isdisjoint ( { ' All ' , ' Common ' , ' Auto ' } ) :