|
@@ -329,6 +329,8 @@ class Trainer:
|
|
|
{'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
|
|
|
{'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
|
|
|
{'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
|
|
|
+ {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
|
|
|
+ {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
|
|
|
{'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
|
|
|
])
|
|
|
self.scaler = GradScaler()
|