diff --git a/paddlers/custom_models/gan/generators/rcan.py b/paddlers/custom_models/gan/generators/rcan.py index 6724c10..83c44e3 100644 --- a/paddlers/custom_models/gan/generators/rcan.py +++ b/paddlers/custom_models/gan/generators/rcan.py @@ -7,12 +7,14 @@ from .builder import GENERATORS def default_conv(in_channels, out_channels, kernel_size, bias=True): - return nn.Conv2D( - in_channels, - out_channels, - kernel_size, - padding=(kernel_size // 2), - bias_attr=bias) + weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.XavierUniform(), + need_clip =True) + return nn.Conv2D(in_channels, + out_channels, + kernel_size, + padding=(kernel_size // 2), + weight_attr=weight_attr, + bias_attr=bias) class MeanShift(nn.Conv2D):