own
Bobholamovic 2 years ago
parent 7cccfe93c3
commit a373e11835
  1. 2
      docs/intro/transforms.md
  2. 15
      examples/rs_research/custom_trainer.py
  3. 5
      examples/rs_research/train_cd.py
  4. 4
      paddlers/tasks/change_detector.py

@ -12,7 +12,7 @@ PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为
| RandomResizeByShort | 随机调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... |
| ResizeByLong | 调整输入影像大小,保持纵横比不变(根据长边计算缩放系数)。 | 所有任务 | ... |
| RandomHorizontalFlip | 随机水平翻转输入影像。 | 所有任务 | ... |
| RandomVerticalFlip | 随机直翻转输入影像。 | 所有任务 | ... |
| RandomVerticalFlip | 随机直翻转输入影像。 | 所有任务 | ... |
| Normalize | 对输入影像应用标准化。 | 所有任务 | ... |
| CenterCrop | 对输入影像进行中心裁剪。 | 所有任务 | ... |
| RandomCrop | 对输入影像进行随机中心裁剪。 | 所有任务 | ... |

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import paddle
import paddlers
from paddlers.tasks.change_detector import BaseChangeDetector
@ -27,12 +29,21 @@ def make_trainer(net_type, *args, **kwargs):
use_mixed_loss=False,
losses=None,
**params):
super().__init__(
sig = inspect.signature(net_type.__init__)
net_params = {
k: p.default
for k, p in sig.parameters.items() if not p.default is p.empty
}
net_params.pop('self', None)
net_params.pop('num_classes', None)
net_params.update(params)
super(trainer_type, self).__init__(
model_name=net_type.__name__,
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
**params)
**net_params)
if not issubclass(net_type, paddle.nn.Layer):
raise TypeError("net must be a subclass of paddle.nn.Layer")

@ -76,8 +76,7 @@ test_dataset = pdrs.datasets.CDDataset(
# 构建自定义模型CustomModel并为其自动生成训练器
# make_trainer()的首个参数为模型类型,剩余参数为模型构造所需参数
# 这里使用默认参数构造
model = make_trainer(CustomModel)
model = make_trainer(CustomModel, in_channels=3)
# 构建学习率调度器
# 使用定步长学习率衰减策略
@ -86,7 +85,7 @@ lr_scheduler = paddle.optimizer.lr.StepDecay(
# 构建优化器
optimizer = paddle.optimizer.Adam(
model.net.parameters(), learning_rate=lr_scheduler)
parameters=model.net.parameters(), learning_rate=lr_scheduler)
# 执行模型训练
model.train(

@ -52,9 +52,7 @@ class BaseChangeDetector(BaseModel):
if 'with_net' in self.init_params:
del self.init_params['with_net']
super(BaseChangeDetector, self).__init__('change_detector')
if model_name not in __all__:
raise ValueError("ERROR: There is no model named {}.".format(
model_name))
self.model_name = model_name
self.num_classes = num_classes
self.use_mixed_loss = use_mixed_loss

Loading…
Cancel
Save