|
|
@ -18,7 +18,7 @@ class PositionEncodingSine(nn.Module): |
|
|
|
pe = torch.zeros((d_model, *max_shape)) |
|
|
|
pe = torch.zeros((d_model, *max_shape)) |
|
|
|
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) |
|
|
|
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) |
|
|
|
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) |
|
|
|
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) |
|
|
|
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) |
|
|
|
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) |
|
|
|
div_term = div_term[:, None, None] # [C//4, 1, 1] |
|
|
|
div_term = div_term[:, None, None] # [C//4, 1, 1] |
|
|
|
pe[0::4, :, :] = torch.sin(x_position * div_term) |
|
|
|
pe[0::4, :, :] = torch.sin(x_position * div_term) |
|
|
|
pe[1::4, :, :] = torch.cos(x_position * div_term) |
|
|
|
pe[1::4, :, :] = torch.cos(x_position * div_term) |
|
|
|