You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

58 lines
1.7 KiB

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlers
from paddlers.rs_models.cd import BIT
from attach_tools import Attach
attach = Attach.to(paddlers.rs_models.cd)
@attach
class IterativeBIT(nn.Layer):
def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
super().__init__()
if num_iters <= 0:
raise ValueError(
f"`num_iters` should have positive value, but got {num_iters}.")
self.num_iters = num_iters
self.gamma = gamma
if bit_kwargs is None:
bit_kwargs = dict()
if 'num_classes' in bit_kwargs:
raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
bit_kwargs['num_classes'] = num_classes
self.bit = BIT(**bit_kwargs)
def forward(self, t1, t2):
rate_map = self._init_rate_map(t1.shape)
for it in range(self.num_iters):
# Construct inputs
x1 = self._constr_iter_input(t1, rate_map)
x2 = self._constr_iter_input(t2, rate_map)
# Get logits
logits_list = self.bit(x1, x2)
# Construct rate map
prob_map = F.softmax(logits_list[0], axis=1)
rate_map = self._constr_rate_map(prob_map)
return logits_list
def _constr_iter_input(self, im, rate_map):
return paddle.concat([im.rate_map], axis=1)
def _init_rate_map(self, im_shape):
b, _, h, w = im_shape
return paddle.zeros((b, 1, h, w))
def _constr_rate_map(self, prob_map):
if prob_map.shape[1] != 2:
raise ValueError(
f"`prob_map.shape[1]` must be 2, but got {prob_map.shape[1]}.")
return (prob_map[:, 1:2] * self.gamma)