| | |
| | | from funasr.models.transformer.scorers.ctc_prefix_score import CTCPrefixScoreTH |
| | | from funasr.models.transformer.scorers.scorer_interface import BatchPartialScorerInterface |
| | | |
| | | |
| | | class CTCPrefixScorer(BatchPartialScorerInterface): |
| | | """Decoder interface wrapper for CTCPrefixScore.""" |
| | | |
| | |
| | | """ |
| | | prev_score, state = state |
| | | presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state) |
| | | tscore = torch.as_tensor( |
| | | presub_score - prev_score, device=x.device, dtype=x.dtype |
| | | ) |
| | | tscore = torch.as_tensor(presub_score - prev_score, device=x.device, dtype=x.dtype) |
| | | return tscore, (presub_score, new_st) |
| | | |
| | | def batch_init_state(self, x: torch.Tensor): |