| | |
| | | is_pad_mask = kwargs.get("is_pad_mask", False) |
| | | is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False) |
| | | |
| | | fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None |
| | | fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None |
| | | # if fsmn_cache is not None: |
| | | # x = x[:, -1:] |
| | | att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache) |
| | |
| | | if len(kwargs.get("data_type", [])) > 1: |
| | | audio_sample_list, text_token_int_list = audio_sample_list |
| | | text_token_int = text_token_int_list[0] |
| | | text_token_int = tokenizer.encode(text_token_int) |
| | | else: |
| | | text_token_int = None |
| | | |
| | |
| | | ) |
| | | |
| | | if text_token_int is not None: |
| | | i = 1 |
| | | i = 0 |
| | | results = [] |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | |
| | | ibest_writer = self.writer[f"1best_recog"] |
| | | |
| | | # 1. Forward decoder |
| | | ys_pad = torch.tensor(text_token_int, dtype=torch.int64).to(kwargs["device"])[None, :] |
| | | ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[ |
| | | None, : |
| | | ] |
| | | ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to( |
| | | kwargs["device"] |
| | | )[None, :] |
| | |
| | | |
| | | else: |
| | | x = x.to(mel.device) |
| | | # FIX(funasr): sense voice |
| | | # logits = model.logits(x[:, :-1], mel)[:, -1] |
| | | logits = model.logits(x[:, :], mel)[:, -1] |
| | | |
| | | logits = model.logits(x[:, :-1], mel)[:, -1] |
| | | # collect detected languages; suppress all non-language tokens |
| | | mask = torch.ones(logits.shape[-1], dtype=torch.bool) |
| | | mask[list(tokenizer.all_language_tokens)] = False |