Shi Xian
2024-06-18 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -8,74 +8,78 @@
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.utils import (
    CharTokenizer,
    Hypothesis,
    ONNXRuntimeError,
    OrtInferSession,
    TokenIDConverter,
    get_logger,
    read_yaml,
)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
logging = get_logger()
class Paraformer():
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 chunk_size: List = [5, 10, 5],
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
class Paraformer:
    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        chunk_size: List = [5, 10, 5],
        device_id: Union[str, int] = "-1",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
                      "\npip3 install -U modelscope\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
        encoder_model_file = os.path.join(model_dir, 'model.onnx')
        decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir
                )
        encoder_model_file = os.path.join(model_dir, "model.onnx")
        decoder_model_file = os.path.join(model_dir, "decoder.onnx")
        if quantize:
            encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
            decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
            encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
            decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            print(".onnx is not exist, begin to export onnx")
            print(".onnx does not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
        token_list = os.path.join(model_dir, "tokens.json")
        with open(token_list, "r", encoding="utf-8") as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.ort_encoder_infer = OrtInferSession(
            encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.ort_decoder_infer = OrtInferSession(
            decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.encoder_output_size = config["encoder_conf"]["output_size"]
@@ -94,7 +98,9 @@
        cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
        cache["chunk_size"] = self.chunk_size
        cache["last_chunk"] = False
        cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
        cache["feats"] = np.zeros(
            (batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
        ).astype(np.float32)
        cache["decoder_fsmn"] = []
        for i in range(self.fsmn_layer):
            fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
@@ -107,31 +113,31 @@
        # process last chunk
        overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
        if cache["is_final"]:
            cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
            cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
            if not cache["last_chunk"]:
               padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
               overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
        else:
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]) :, :]
        return overlap_feats
    def __call__(self, audio_in: np.ndarray, **kwargs):
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        cache = param_dict.get('cache', dict())
        param_dict = kwargs.get("param_dict", dict())
        is_final = param_dict.get("is_final", False)
        cache = param_dict.get("cache", dict())
        asr_res = []
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res
        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            feats *= self.encoder_output_size ** 0.5
            feats *= self.encoder_output_size**0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final
@@ -144,16 +150,19 @@
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
                    feats_chunk1 = self.add_overlap_chunk(feats[:, : self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
                    feats_chunk2 = self.add_overlap_chunk(
                        feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]) :, :],
                        cache,
                    )
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    for pred in asr_res_chunk:
@@ -187,17 +196,18 @@
            dec_input.extend(cache["decoder_fsmn"])
            dec_output = self.ort_decoder_infer(dec_input)
            logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
            cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]
            cache["decoder_fsmn"] = [
                item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
            ]
            preds = self.decode(logits, acoustic_embeds_len)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({'preds': pred})
                asr_res.append({"preds": pred})
        return asr_res
    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform
@@ -211,12 +221,11 @@
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
    def extract_feat(
        self, waveforms: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]
@@ -225,12 +234,12 @@
        return feats.astype(np.float32), feats_len.astype(np.int32)
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]
        return [
            self.decode_one(am_score, token_num)
            for am_score, token_num in zip(am_scores, token_nums)
        ]
    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)
@@ -260,15 +269,15 @@
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        alphas[:, : self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]) :] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas =np.tile(tail_alphas, (batch_size, 1))
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)
@@ -316,5 +325,6 @@
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
            np.int32
        )