You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
58 lines
1.7 KiB
58 lines
1.7 KiB
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
import paddlers |
|
from paddlers.rs_models.cd import BIT |
|
from attach_tools import Attach |
|
|
|
# Decorator applied to the model class(es) defined below.
# NOTE(review): `Attach` comes from the local `attach_tools` module; presumably
# `Attach.to(module)` returns a decorator that registers the decorated class
# inside `paddlers.rs_models.cd` -- confirm against `attach_tools`.
attach = Attach.to(paddlers.rs_models.cd)
|
|
|
|
|
@attach
class IterativeBIT(nn.Layer):
    """Run a wrapped ``BIT`` model several times, feeding back a rate map.

    On every iteration a single-channel "rate map" is concatenated to each
    of the two inputs along axis 1 and the pair is passed to the wrapped
    ``BIT`` instance. The first returned logits tensor is turned into
    probabilities with a softmax over axis 1, and channel 1 of those
    probabilities, scaled by ``gamma``, becomes the rate map for the next
    iteration. The first iteration uses an all-zero rate map.

    Args:
        num_iters: Number of forward passes through the wrapped model.
            Must be positive.
        gamma: Scale factor applied to the channel-1 probability when
            building the next rate map.
        num_classes: Forwarded to ``BIT`` as its ``num_classes`` argument.
            ``_constr_rate_map`` only accepts 2-channel probability maps.
        bit_kwargs: Optional extra keyword arguments for ``BIT``. Must not
            contain ``'num_classes'``. The mapping is copied, never mutated.
            NOTE(review): the wrapped model receives inputs with one extra
            channel, so this presumably needs a matching ``in_channels``
            entry -- confirm against the ``BIT`` signature.

    Raises:
        ValueError: If ``num_iters`` is not positive.
        KeyError: If ``bit_kwargs`` already contains ``'num_classes'``.
    """

    def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
        super().__init__()

        if num_iters <= 0:
            raise ValueError(
                f"`num_iters` should have positive value, but got {num_iters}.")

        self.num_iters = num_iters
        self.gamma = gamma

        # Work on a private copy so the caller's dict is left untouched and
        # can safely be reused to build another instance.
        bit_kwargs = {} if bit_kwargs is None else dict(bit_kwargs)

        if 'num_classes' in bit_kwargs:
            raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
        bit_kwargs['num_classes'] = num_classes

        self.bit = BIT(**bit_kwargs)

    def forward(self, t1, t2):
        """Return the logits list produced by the last iteration.

        Args:
            t1: First input tensor; its shape must unpack into four
                dimensions (see ``_init_rate_map``).
            t2: Second input tensor, concatenated with the same rate map
                as ``t1``.

        Returns:
            Whatever ``self.bit`` returned on the final iteration; element 0
            is used as the logits tensor for the rate-map update.
        """
        rate_map = self._init_rate_map(t1.shape)

        # `num_iters` is validated to be positive in `__init__`, so the loop
        # body runs at least once and `logits_list` is always bound.
        for _ in range(self.num_iters):
            # Construct inputs
            x1 = self._constr_iter_input(t1, rate_map)
            x2 = self._constr_iter_input(t2, rate_map)
            # Get logits
            logits_list = self.bit(x1, x2)
            # Construct rate map
            prob_map = F.softmax(logits_list[0], axis=1)
            rate_map = self._constr_rate_map(prob_map)

        return logits_list

    def _constr_iter_input(self, im, rate_map):
        """Append ``rate_map`` to ``im`` along axis 1."""
        return paddle.concat([im, rate_map], axis=1)

    def _init_rate_map(self, im_shape):
        """Return an all-zero rate map of shape ``(b, 1, h, w)``.

        ``b``, ``h`` and ``w`` are taken from positions 0, 2 and 3 of
        ``im_shape``, which must have exactly four entries.
        """
        b, _, h, w = im_shape
        return paddle.zeros((b, 1, h, w))

    def _constr_rate_map(self, prob_map):
        """Build the next rate map from a 2-channel probability map.

        Slices channel 1 (keeping the channel axis) and scales it by
        ``self.gamma``.

        Raises:
            ValueError: If ``prob_map`` does not have exactly 2 channels on
                axis 1.
        """
        if prob_map.shape[1] != 2:
            raise ValueError(
                f"`prob_map.shape[1]` must be 2, but got {prob_map.shape[1]}.")
        return prob_map[:, 1:2] * self.gamma
|
|
|