shixian.shi
2023-09-12 9c622feb645ee8ab166cd6d5fc9d0b2130a0f5fd
update processing of out-of-vocabulary (OOV) characters in hotwords for ONNX inference
2个文件已修改
11 ■■■■ 已修改文件
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -5,7 +5,7 @@
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
# Space-separated hotwords. The trailing '仏' is presumably absent from the
# model vocabulary and is included to exercise the OOV -> <unk> replacement
# path during hotword encoding (confirm against the vocab file).
hotwords = '随机热词 各种热词 魔搭 阿里巴巴 仏'
result = model(wav_path, hotwords)
print(result)
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -314,7 +314,14 @@
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        # hotwords.append('<s>')
        def word_map(word):
            return torch.tensor([self.vocab[i] for i in word])
            hotwords = []
            for c in word:
                if c not in self.vocab.keys():
                    hotwords.append(8403)
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return torch.tensor(hotwords)
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))