|
|
@ -980,7 +980,7 @@ class PSA(nn.Module): |
|
|
|
self.cv1 = Conv(c1, 2 * self.c, 1, 1) |
|
|
|
self.cv1 = Conv(c1, 2 * self.c, 1, 1) |
|
|
|
self.cv2 = Conv(2 * self.c, c1, 1) |
|
|
|
self.cv2 = Conv(2 * self.c, c1, 1) |
|
|
|
|
|
|
|
|
|
|
|
self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 128) |
|
|
|
self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 32) |
|
|
|
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False)) |
|
|
|
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False)) |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
def forward(self, x): |
|
|
|