维石
2024-06-03 487189b949a2edad084e8275b35b72f324ba5218
bug fix
3个文件已修改
13 ■■■■■ 已修改文件
runtime/python/libtorch/demo_contextual_paraformer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/libtorch/demo_seaco_paraformer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/libtorch/funasr_torch/paraformer_bin.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/libtorch/demo_contextual_paraformer.py
@@ -7,7 +7,7 @@
model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id)  # gpu
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
-hotwords = "你的热词 魔哒"
+hotwords = "你的热词 魔搭"
result = model(wav_path, hotwords)
print(result)
runtime/python/libtorch/demo_seaco_paraformer.py
@@ -7,7 +7,7 @@
model = SeacoParaformer(model_dir, batch_size=1, device_id=device_id)  # gpu
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
-hotwords = "你的热词 魔哒"
+hotwords = "你的热词 魔搭"
result = model(wav_path, hotwords)
print(result)
runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -316,7 +316,10 @@
    ) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
-        bias_embed = self.eb_infer(torch.Tensor(hotwords))
+        if int(self.device_id) != -1:
+            bias_embed = self.eb_infer(hotwords.cuda())
+        else:
+            bias_embed = self.eb_infer(hotwords)
        # index from bias_embed
        bias_embed = torch.transpose(bias_embed, 0, 1)
        _ind = np.arange(0, len(hotwords)).tolist()
@@ -334,7 +337,7 @@
                        outputs = self.bb_infer(feats, feats_len, bias_embed)
                        am_scores, valid_token_lens = outputs[0], outputs[1]
                    else:
-                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed)
+                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
                        am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
            except:
                # logging.warning(traceback.format_exc())
@@ -369,7 +372,7 @@
        hotword_int = [word_map(i) for i in hotwords]
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
-        return hotwords, hotwords_length
+        return torch.tensor(hotwords), hotwords_length
    def bb_infer(
        self, feats, feats_len, bias_embed