fix paraformer online last chunk decoding strategy
| | |
| | | feats_len = speech_lengths |
| | | |
| | | if feats.shape[1] != 0: |
| | | if cache_en["is_final"]: |
| | | if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]: |
| | | cache_en["last_chunk"] = True |
| | | else: |
| | | # first chunk |
| | | feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :] |
| | | feats_len = torch.tensor([feats_chunk1.shape[1]]) |
| | | results_chunk1 = self.infer(feats_chunk1, feats_len, cache) |
| | | |
| | | # last chunk |
| | | cache_en["last_chunk"] = True |
| | | feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :] |
| | | feats_len = torch.tensor([feats_chunk2.shape[1]]) |
| | | results_chunk2 = self.infer(feats_chunk2, feats_len, cache) |
| | | |
| | | return [" ".join(results_chunk1 + results_chunk2)] |
| | | |
| | | results = self.infer(feats, feats_len, cache) |
| | | |
| | | return results |
| | |
| | | def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): |
| | | if len(cache) == 0: |
| | | return feats |
| | | # process last chunk |
| | | cache["feats"] = to_device(cache["feats"], device=feats.device) |
| | | overlap_feats = torch.cat((cache["feats"], feats), dim=1) |
| | | if cache["is_final"]: |
| | | cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :] |
| | | if not cache["last_chunk"]: |
| | | padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1] |
| | | overlap_feats = overlap_feats.transpose(1, 2) |
| | | overlap_feats = F.pad(overlap_feats, (0, padding_length)) |
| | | overlap_feats = overlap_feats.transpose(1, 2) |
| | | else: |
| | | cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] |
| | | cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] |
| | | return overlap_feats |
| | | |
| | | def forward_chunk(self, |
| | |
| | |
|
| | | if cache is not None and "chunk_size" in cache:
|
| | | alphas[:, :cache["chunk_size"][0]] = 0.0
|
| | | alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
| | | if "is_final" in cache and not cache["is_final"]:
|
| | | alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
| | | if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
| | | cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
|
| | | cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
|
| | | hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
|
| | | alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
|
| | | if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
|
| | | if cache is not None and "is_final" in cache and cache["is_final"]:
|
| | | tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
|
| | | tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
|
| | | tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
|