From f591f33111453c674bb80b8a8fa9c0bff29477e1 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 03 六月 2024 15:15:52 +0800
Subject: [PATCH] update libtorch infer

---
 runtime/python/libtorch/funasr_torch/paraformer_bin.py |  202 +++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 191 insertions(+), 11 deletions(-)

diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index b7fb14b..e9642c7 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -1,21 +1,20 @@
 # -*- encoding: utf-8 -*-
+import json
+import copy
+import torch
 import os.path
+import librosa
+import numpy as np
 from pathlib import Path
 from typing import List, Union, Tuple
 
-import copy
-import librosa
-import numpy as np
-
-from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
-from .utils.postprocess_utils import sentence_postprocess
+from .utils.utils import pad_list
 from .utils.frontend import WavFrontend
 from .utils.timestamp_utils import time_stamp_lfr6_onnx
+from .utils.postprocess_utils import sentence_postprocess
+from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
 
 logging = get_logger()
-
-import torch
-import json
 
 
 class Paraformer:
@@ -32,7 +31,6 @@
         device_id: Union[str, int] = "-1",
         plot_timestamp_to: str = "",
         quantize: bool = False,
-        intra_op_num_threads: int = 4,
         cache_dir: str = None,
         **kwargs,
     ):
@@ -236,4 +234,186 @@
         token = self.converter.ids2tokens(token_int)
         token = token[: valid_token_num - self.pred_bias]
         # texts = sentence_postprocess(token)
-        return token
\ No newline at end of file
+        return token
+
+    
+class ContextualParaformer(Paraformer):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+        self,
+        model_dir: Union[str, Path] = None,
+        batch_size: int = 1,
+        device_id: Union[str, int] = "-1",
+        plot_timestamp_to: str = "",
+        quantize: bool = False,
+        cache_dir: str = None,
+        **kwargs,
+    ):
+
+        if not Path(model_dir).exists():
+            try:
+                from modelscope.hub.snapshot_download import snapshot_download
+            except:
+                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+            try:
+                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except:
+                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+                    model_dir
+                )
+
+        if quantize:
+            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
+            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
+        else:
+            model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
+            model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
+
+        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
+            print(".onnx is not exist, begin to export onnx")
+            try:
+                from funasr import AutoModel
+            except:
+                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
+
+        config_file = os.path.join(model_dir, "config.yaml")
+        cmvn_file = os.path.join(model_dir, "am.mvn")
+        config = read_yaml(config_file)
+        token_list = os.path.join(model_dir, "tokens.json")
+        with open(token_list, "r", encoding="utf-8") as f:
+            token_list = json.load(f)
+
+        # revert token_list into vocab dict
+        self.vocab = {}
+        for i, token in enumerate(token_list):
+            self.vocab[token] = i
+
+        self.converter = TokenIDConverter(token_list)
+        self.tokenizer = CharTokenizer()
+        self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
+        
+        self.ort_infer_bb = torch.jit.load(model_bb_file)
+        self.ort_infer_eb = torch.jit.load(model_eb_file)
+        self.device_id = device_id
+
+        self.batch_size = batch_size
+        self.plot_timestamp_to = plot_timestamp_to
+        if "predictor_bias" in config["model_conf"].keys():
+            self.pred_bias = config["model_conf"]["predictor_bias"]
+        else:
+            self.pred_bias = 0
+
+    def __call__(
+        self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
+    ) -> List:
+        # make hotword list
+        hotwords, hotwords_length = self.proc_hotword(hotwords)
+        # import pdb; pdb.set_trace()
+        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
+        # index from bias_embed
+        bias_embed = bias_embed.transpose(1, 0, 2)
+        _ind = np.arange(0, len(hotwords)).tolist()
+        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
+        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+        waveform_nums = len(waveform_list)
+        asr_res = []
+        for beg_idx in range(0, waveform_nums, self.batch_size):
+            end_idx = min(waveform_nums, beg_idx + self.batch_size)
+            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+            bias_embed = np.expand_dims(bias_embed, axis=0)
+            bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
+            try:
+                with torch.no_grad():
+                    if int(self.device_id) == -1:
+                        outputs = self.ort_infer(feats, feats_len)
+                        am_scores, valid_token_lens = outputs[0], outputs[1]
+                    else:
+                        outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
+                        am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
+            except:
+                # logging.warning(traceback.format_exc())
+                logging.warning("input wav is silence or noise")
+                preds = [""]
+            else:
+                preds = self.decode(am_scores, valid_token_lens)
+                for pred in preds:
+                    pred = sentence_postprocess(pred)
+                    asr_res.append({"preds": pred})
+        return asr_res
+
+    def proc_hotword(self, hotwords):
+        hotwords = hotwords.split(" ")
+        hotwords_length = [len(i) - 1 for i in hotwords]
+        hotwords_length.append(0)
+        hotwords_length = np.array(hotwords_length)
+
+        # hotwords.append('<s>')
+        def word_map(word):
+            hotwords = []
+            for c in word:
+                if c not in self.vocab.keys():
+                    hotwords.append(8403)
+                    logging.warning(
+                        "oov character {} found in hotword {}, replaced by <unk>".format(c, word)
+                    )
+                else:
+                    hotwords.append(self.vocab[c])
+            return np.array(hotwords)
+
+        hotword_int = [word_map(i) for i in hotwords]
+        hotword_int.append(np.array([1]))
+        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+        return hotwords, hotwords_length
+
+    def bb_infer(
+        self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
+        return outputs
+
+    def eb_infer(self, hotwords, hotwords_length):
+        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
+        return outputs
+
+    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
+        return [
+            self.decode_one(am_score, token_num)
+            for am_score, token_num in zip(am_scores, token_nums)
+        ]
+
+    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
+        yseq = am_score.argmax(axis=-1)
+        score = am_score.max(axis=-1)
+        score = np.sum(score, axis=-1)
+
+        # pad with mask tokens to ensure compatibility with sos/eos tokens
+        # asr_model.sos:1  asr_model.eos:2
+        yseq = np.array([1] + yseq.tolist() + [2])
+        hyp = Hypothesis(yseq=yseq, score=score)
+
+        # remove sos/eos and get results
+        last_pos = -1
+        token_int = hyp.yseq[1:last_pos].tolist()
+
+        # remove blank symbol id, which is assumed to be 0
+        token_int = list(filter(lambda x: x not in (0, 2), token_int))
+
+        # Change integer-ids to tokens
+        token = self.converter.ids2tokens(token_int)
+        token = token[: valid_token_num - self.pred_bias]
+        # texts = sentence_postprocess(token)
+        return token
+
+
+class SeacoParaformer(ContextualParaformer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # no difference with contextual_paraformer in method of calling onnx models

--
Gitblit v1.9.1