Gcf (#1605)
* 修复无法预测 nospeech 标签的问题
* 修复传入的 prompt 张量未移动到输入特征(mel)所在设备的问题
---------
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
| | |
| | | from .audio import CHUNK_LENGTH |
| | | from .tokenizer import Tokenizer, get_tokenizer |
| | | from .utils import compression_ratio |
| | | from funasr.models.transformer.utils.nets_utils import to_device |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | from .model import Whisper |
| | |
| | | # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] |
| | | if x is None: |
| | | x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1] |
| | | |
| | | else: |
| | | x = x.to(mel.device) |
| | | |
| | | logits = model.logits(x[:,:-1], mel)[:, -1] |
| | | # collect detected languages; suppress all non-language tokens |
| | | mask = torch.ones(logits.shape[-1], dtype=torch.bool) |