import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlers
from paddlers.rs_models.cd import BIT
from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from paddlers.rs_models.cd.layers import ChannelAttention, SpatialAttention

from attach_tools import Attach

attach = Attach.to(paddlers.rs_models.cd)
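
# A note on the registration pattern (as used in the PaddleRS examples, where
# the `attach_tools` helper is defined): `Attach.to()` returns a decorator
# that binds each class marked with `@attach` into the
# `paddlers.rs_models.cd` namespace, so the custom models below can be
# referenced like the built-in change detection models.
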
@attach
class CustomModel(nn.Layer):
    def __init__(self,
                 in_channels,
                 num_classes,
                 att_types='cst',
                 use_dropout=False):
        super(CustomModel, self).__init__()

        C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256

        self.use_dropout = use_dropout

        # Encoder (4 stages, shared between the two temporal phases).
        self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
        self.do11 = self._make_dropout()
        self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
        self.do12 = self._make_dropout()
        self.pool1 = MaxPool2x2()

        self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
        self.do21 = self._make_dropout()
        self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
        self.do22 = self._make_dropout()
        self.pool2 = MaxPool2x2()

        self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
        self.do31 = self._make_dropout()
        self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
        self.do32 = self._make_dropout()
        self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
        self.do33 = self._make_dropout()
        self.pool3 = MaxPool2x2()

        self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
        self.do41 = self._make_dropout()
        self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
        self.do42 = self._make_dropout()
        self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
        self.do43 = self._make_dropout()
        self.pool4 = MaxPool2x2()

        # Decoder (each stage consumes the upsampled features concatenated
        # with the absolute difference of the two encoder feature maps).
        self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)

        self.conv43d = Conv3x3(C5, C4, norm=True, act=True)
        self.do43d = self._make_dropout()
        self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
        self.do42d = self._make_dropout()
        self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
        self.do41d = self._make_dropout()

        self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)

        self.conv33d = Conv3x3(C4, C3, norm=True, act=True)
        self.do33d = self._make_dropout()
        self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
        self.do32d = self._make_dropout()
        self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
        self.do31d = self._make_dropout()

        self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)

        self.conv22d = Conv3x3(C3, C2, norm=True, act=True)
        self.do22d = self._make_dropout()
        self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
        self.do21d = self._make_dropout()

        self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)

        self.conv12d = Conv3x3(C2, C1, norm=True, act=True)
        self.do12d = self._make_dropout()
        self.conv11d = Conv3x3(C1, num_classes)
        if 'c' in att_types:
            self.att_c = ChannelAttention(C4)
        else:
            self.att_c = Identity()
        if 's' in att_types:
            self.att_s = SpatialAttention()
        else:
            self.att_s = Identity()
        if 't' in att_types:
            self.att_t = ChannelAttention(2, ratio=1)
        else:
            self.att_t = Identity()
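        # `att_types` toggles the three attention blocks independently: the
        # default 'cst' enables all of them, while e.g. 'ct' would keep
        # channel and temporal attention but skip spatial attention.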

        self.init_weight()

    def forward(self, t1, t2):
        # Encode t1
        # Stage 1
        x11 = self.do11(self.conv11(t1))
        x12_1 = self.do12(self.conv12(x11))
        x1p = self.pool1(x12_1)
        # Stage 2
        x21 = self.do21(self.conv21(x1p))
        x22_1 = self.do22(self.conv22(x21))
        x2p = self.pool2(x22_1)
        # Stage 3
        x31 = self.do31(self.conv31(x2p))
        x32 = self.do32(self.conv32(x31))
        x33_1 = self.do33(self.conv33(x32))
        x3p = self.pool3(x33_1)
        # Stage 4
        x41 = self.do41(self.conv41(x3p))
        x42 = self.do42(self.conv42(x41))
        x43_1 = self.do43(self.conv43(x42))
        x4p = self.pool4(x43_1)

        # Encode t2
        # Stage 1
        x11 = self.do11(self.conv11(t2))
        x12_2 = self.do12(self.conv12(x11))
        x1p = self.pool1(x12_2)
        # Stage 2
        x21 = self.do21(self.conv21(x1p))
        x22_2 = self.do22(self.conv22(x21))
        x2p = self.pool2(x22_2)
        # Stage 3
        x31 = self.do31(self.conv31(x2p))
        x32 = self.do32(self.conv32(x31))
        x33_2 = self.do33(self.conv33(x32))
        x3p = self.pool3(x33_2)
        # Stage 4
        x41 = self.do41(self.conv41(x3p))
        x42 = self.do42(self.conv42(x41))
        x43_2 = self.do43(self.conv43(x42))
        x4p = self.pool4(x43_2)
        # Attend
        x43_1 = self.att_c(x43_1) * x43_1
        x43_1 = self.att_s(x43_1) * x43_1
        x43_2 = self.att_c(x43_2) * x43_2
        x43_2 = self.att_s(x43_2) * x43_2
        x43 = paddle.stack([x43_1, x43_2], axis=1)  # (b, 2, C4, h, w)
        x43 = paddle.transpose(x43, [0, 2, 1, 3, 4])  # (b, C4, 2, h, w)
        x43 = paddle.flatten(x43, stop_axis=1)  # (b*C4, 2, h, w)
        x43 = self.att_t(x43) * x43
        x43 = x43.reshape((x43_1.shape[0], -1, 2, *x43.shape[2:]))
        x43_1, x43_2 = x43[:, :, 0], x43[:, :, 1]
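        # The flatten above folds the feature channels into the batch, so
        # `att_t` attends over the two temporal phases rather than the C4
        # features; ratio=1 avoids shrinking the tiny 2-channel descriptor.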
        # Decode
        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
                x43_1.shape[2] - x4d.shape[2])
        x4d = F.pad(x4d, pad=pad4, mode='replicate')
        x4d = paddle.concat([x4d, paddle.abs(x43_1 - x43_2)], 1)
        x43d = self.do43d(self.conv43d(x4d))
        x42d = self.do42d(self.conv42d(x43d))
        x41d = self.do41d(self.conv41d(x42d))
        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
                x33_1.shape[2] - x3d.shape[2])
        x3d = F.pad(x3d, pad=pad3, mode='replicate')
        x3d = paddle.concat([x3d, paddle.abs(x33_1 - x33_2)], 1)
        x33d = self.do33d(self.conv33d(x3d))
        x32d = self.do32d(self.conv32d(x33d))
        x31d = self.do31d(self.conv31d(x32d))
        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
                x22_1.shape[2] - x2d.shape[2])
        x2d = F.pad(x2d, pad=pad2, mode='replicate')
        x2d = paddle.concat([x2d, paddle.abs(x22_1 - x22_2)], 1)
        x22d = self.do22d(self.conv22d(x2d))
        x21d = self.do21d(self.conv21d(x22d))
        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
                x12_1.shape[2] - x1d.shape[2])
        x1d = F.pad(x1d, pad=pad1, mode='replicate')
        x1d = paddle.concat([x1d, paddle.abs(x12_1 - x12_2)], 1)
        x12d = self.do12d(self.conv12d(x1d))
        x11d = self.conv11d(x12d)

        return [x11d]

    def init_weight(self):
        pass

    def _make_dropout(self):
        if self.use_dropout:
            return nn.Dropout2D(p=0.2)
        else:
            return Identity()


@attach
class IterativeBIT(BIT):
    def __init__(self,
                 num_iters=1,
                 feat_channels=32,
                 num_classes=2,
                 bit_kwargs=None):
        if num_iters <= 0:
            raise ValueError(
                f"`num_iters` should have positive value, but got {num_iters}.")
        self.num_iters = num_iters

        if bit_kwargs is None:
            bit_kwargs = dict()
        if 'num_classes' in bit_kwargs:
            raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
        bit_kwargs['num_classes'] = num_classes
        super().__init__(**bit_kwargs)

        self.conv_fuse = nn.Sequential(
            nn.Conv2D(feat_channels + 1, feat_channels, 1), nn.Sigmoid())

    def forward(self, t1, t2):
        # Extract features via shared backbone.
        x1 = self.backbone(t1)
        x2 = self.backbone(t2)

        # Tokenization
        if self.use_tokenizer:
            token1 = self._get_semantic_tokens(x1)
            token2 = self._get_semantic_tokens(x2)
        else:
            token1 = self._get_reshaped_tokens(x1)
            token2 = self._get_reshaped_tokens(x2)

        # Transformer encoder forward
        token = paddle.concat([token1, token2], axis=1)
        token = self.encode(token)
        token1, token2 = paddle.chunk(token, 2, axis=1)

        # Get initial rate map
        rate_map = self._init_rate_map(x1.shape)

        for it in range(self.num_iters):
            # Construct inputs
            x1_iter = self._constr_iter_input(x1, rate_map)
            x2_iter = self._constr_iter_input(x2, rate_map)
            # Transformer decoder forward
            y1 = self.decode(x1_iter, token1)
            y2 = self.decode(x2_iter, token2)
            # Feature differencing
            y = paddle.abs(y1 - y2)
            # Construct rate map for the next iteration
            rate_map = self._constr_rate_map(y)

        y = self.upsample(y)
        pred = self.conv_out(y)

        return [pred]

    def _init_rate_map(self, im_shape):
        b, _, h, w = im_shape
        return paddle.full((b, 1, h, w), 0.5)

    def _constr_iter_input(self, x, rate_map):
        return self.conv_fuse(paddle.concat([x, rate_map], axis=1))

    def _constr_rate_map(self, x):
        rate_map = x.mean(1, keepdim=True).detach()  # Cut off the gradient flow
        # Min-max normalization
        rate_map -= rate_map.min()
        rate_map /= rate_map.max()
        return rate_map
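

if __name__ == '__main__':
    # Minimal smoke test; a sketch only, not part of the original example.
    # It assumes 3-band 256x256 inputs, that BIT accepts `in_channels` as a
    # keyword argument, and that BIT's decoder feature width matches the
    # default feat_channels=32.
    t1 = paddle.rand([1, 3, 256, 256])
    t2 = paddle.rand([1, 3, 256, 256])

    custom = CustomModel(in_channels=3, num_classes=2)
    print(custom(t1, t2)[0].shape)  # Expected: [1, 2, 256, 256]

    iter_bit = IterativeBIT(
        num_iters=2, feat_channels=32, num_classes=2,
        bit_kwargs={'in_channels': 3})
    print(iter_bit(t1, t2)[0].shape)  # Expected: [1, 2, 256, 256]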