| | |
| | | self.batch_total = 0 |
| | | self.use_fp16 = use_fp16 |
| | | self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000) |
| | | self.validate_interval = kwargs.get("validate_interval", 5000) |
| | | self.validate_interval = kwargs.get("validate_interval", -1) |
| | | if self.validate_interval < 0: |
| | | self.validate_interval = self.save_checkpoint_interval |
| | | assert ( |
| | | self.save_checkpoint_interval == self.validate_interval |
| | | ), f"save_checkpoint_interval must equal to validate_interval" |
| | | self.keep_nbest_models = kwargs.get("keep_nbest_models", 500) |
| | | self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc") |
| | | self.avg_nbest_model = kwargs.get("avg_nbest_model", 10) |