gaochangfeng
2024-04-10 851e3e3ef83d0769d9bde172d8841f6b20e3e377
Gcf (#1605)

* 修复无法预测nospeech标签的问题

* 修复prompt存储的设备的问题

---------

Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
1个文件已修改
6 ■■■■■ 已修改文件
funasr/models/sense_voice/whisper_lib/decoding.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -10,6 +10,8 @@
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
from funasr.models.transformer.utils.nets_utils import to_device
if TYPE_CHECKING:
    from .model import Whisper
@@ -58,6 +60,10 @@
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
    else:
        x = x.to(mel.device)
    logits = model.logits(x[:,:-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)