维石
2024-06-03 fd0992af3d1a2d2d098b1fab24f67f8c4cece39d
runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -275,7 +275,7 @@
            model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
            print(".onnx is not exist, begin to export onnx")
            print(".onnx does not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except:
@@ -316,8 +316,7 @@
    ) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        # import pdb; pdb.set_trace()
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        [bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
@@ -333,10 +332,10 @@
            try:
                with torch.no_grad():
                    if int(self.device_id) == -1:
                        outputs = self.ort_infer(feats, feats_len)
                        outputs = self.bb_infer(feats, feats_len)
                        am_scores, valid_token_lens = outputs[0], outputs[1]
                    else:
                        outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
                        am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
            except:
                # logging.warning(traceback.format_exc())
@@ -374,13 +373,13 @@
        return hotwords, hotwords_length
    def bb_infer(
        self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
        self, feats, feats_len, bias_embed
    ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
        outputs = self.ort_infer_eb([hotwords, hotwords_length])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: