Yabin Li
2024-06-25 b7060884fa4b8b85f79462644a5c99062d223da0
runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -46,11 +46,11 @@
                    model_dir
                )
        model_file = os.path.join(model_dir, "model.torchscripts")
        model_file = os.path.join(model_dir, "model.torchscript")
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.torchscripts")
            model_file = os.path.join(model_dir, "model_quant.torchscript")
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            print(".torchscripts does not exist, begin to export torchscript")
            try:
                from funasr import AutoModel
            except:
@@ -268,21 +268,21 @@
                )
        if quantize:
            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscript")
            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscript")
        else:
            model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
            model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
            model_bb_file = os.path.join(model_dir, "model_bb.torchscript")
            model_eb_file = os.path.join(model_dir, "model_eb.torchscript")
        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
            print(".onnx is not exist, begin to export onnx")
            print(".onnx does not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
            model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
@@ -316,10 +316,12 @@
    ) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        # import pdb; pdb.set_trace()
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        if int(self.device_id) != -1:
            bias_embed = self.eb_infer(hotwords.cuda())
        else:
            bias_embed = self.eb_infer(hotwords)
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        bias_embed = torch.transpose(bias_embed, 0, 1)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -328,15 +330,14 @@
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            bias_embed = np.expand_dims(bias_embed, axis=0)
            bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
            bias_embed = torch.unsqueeze(bias_embed, 0).repeat(feats.shape[0], 1, 1)
            try:
                with torch.no_grad():
                    if int(self.device_id) == -1:
                        outputs = self.ort_infer(feats, feats_len)
                        outputs = self.bb_infer(feats, feats_len, bias_embed)
                        am_scores, valid_token_lens = outputs[0], outputs[1]
                    else:
                        outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
                        outputs = self.bb_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
                        am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
            except:
                # logging.warning(traceback.format_exc())
@@ -371,16 +372,16 @@
        hotword_int = [word_map(i) for i in hotwords]
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        return hotwords, hotwords_length
        return torch.tensor(hotwords), hotwords_length
    def bb_infer(
        self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
    ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        self, feats, feats_len, bias_embed
    ):
        outputs = self.ort_infer_bb(feats, feats_len, bias_embed)
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
    def eb_infer(self, hotwords):
        outputs = self.ort_infer_eb(hotwords.long())
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: