游雁
2023-02-16 6023271be65ddf888e26cffc646631809bf500a8
amp pipeline
2个文件已修改
175 ■■■■■ 已修改文件
funasr/bin/asr_inference_paraformer.py 174 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py
@@ -41,16 +41,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
# ANSI escape sequences: '\033[95m' selects a bright-magenta foreground and
# '\033[0m' resets all terminal text attributes.
header_colors = '\033[95m'
end_colors = '\033[0m'
# Module-level language tag.
global_asr_language: str = 'zh-cn'
# Module-level sample-rate table keyed by 'audio_fs' and 'model_fs'
# (presumably the input-audio and model rates in Hz -- TODO confirm).
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
class Speech2Text:
@@ -346,6 +337,160 @@
        # assert check_return_type(results)
        return results
class Speech2TextExport(torch.nn.Module):
    """Speech-to-text wrapper that decodes through the export Paraformer model.

    The constructor loads a trained ASR model with
    ``ASRTask.build_model_from_file``, wraps it in ``Paraformer_export``
    (``onnx=False``) and builds the optional frontend, tokenizer and
    token-id converter needed to turn model output into text.

    Several constructor arguments (``lm_train_config``, ``lm_file``,
    ``maxlenratio``, ``minlenratio``, ``beam_size``, ``ctc_weight``,
    ``lm_weight``, ``ngram_weight``, ``penalty``, ``frontend_conf``,
    ``hotword_list_or_file`` and ``**kwargs``) are accepted but not used by
    this class; they are kept so it can be built from the same keyword
    arguments as ``Speech2Text``.
    """

    def __init__(
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            **kwargs,
    ):
        """Load the ASR model and build the text post-processing helpers.

        Args:
            asr_train_config: Training configuration passed to
                ``ASRTask.build_model_from_file``.
            asr_model_file: Model checkpoint passed to
                ``ASRTask.build_model_from_file``.
            cmvn_file: CMVN file passed to the model builder and to
                ``WavFrontend``.
            token_type: Tokenizer type; falls back to the value stored in the
                training arguments when ``None``.
            bpemodel: BPE model path; falls back to the value stored in the
                training arguments when ``None``.
            device: Device that features and batches are moved to.
            dtype: Name of the ``torch`` dtype the model is cast to.
            nbest: Stored on the instance; not used by ``forward``.
        """
        # torch.nn.Module must be initialised before any sub-module is
        # assigned as an attribute; otherwise the assignments below raise
        # "cannot assign module before Module.__init__() call".
        super().__init__()

        # 1. Build ASR model
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)

        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        token_list = asr_model.token_list
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 2. [Optional] Build text converter: e.g. bpe-sym -> text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            # A BPE tokenizer can only be built when a BPE model is available.
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
        # Decoding goes through the export wrapper, not the training model.
        self.asr_model = Paraformer_export(asr_model, onnx=False)

    @torch.no_grad()
    def forward(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Run greedy decoding on a batch of speech.

        Args:
            speech: Input speech data. Passed through the frontend when one
                was built; otherwise used directly as features.
            speech_lengths: Lengths matching ``speech``.

        Returns:
            A list with one tuple per batch item:
            ``(text, token, token_int, hyp, enc_len_batch_total, lfr_factor)``.
            ``text`` is ``None`` when no tokenizer is configured.
        """
        assert check_argument_types()

        # Accept numpy input for both arguments, as the annotations promise.
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(speech_lengths, np.ndarray):
            speech_lengths = torch.tensor(speech_lengths)

        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # NOTE(review): presumably stops the wrapped model from applying a
            # frontend of its own to already-extracted features -- confirm.
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths

        enc_len_batch_total = feats_len.sum()
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        decoder_outs = self.asr_model(**batch)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        results = []
        for i in range(decoder_out.size(0)):
            # Keep only the valid (unpadded) positions of this item.
            am_scores = decoder_out[i, :ys_pad_lens[i], :]
            best_scores, best_ids = am_scores.max(dim=-1)
            score = torch.sum(best_scores, dim=-1)

            # Rebuild the id sequence as a fresh tensor on the same device.
            yseq = torch.tensor(best_ids.tolist(), device=best_ids.device)
            hyp = Hypothesis(yseq=yseq, score=score)

            # NOTE(review): the first and last ids are dropped as if they were
            # sos/eos, but no sos/eos is added in this class -- confirm that
            # the export model's output really carries them.
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # Remove ids 0 and 2 (0 is assumed to be the blank symbol).
            token_int = [tid for tid in token_int if tid != 0 and tid != 2]

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None

            results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))

        return results
def inference(
        maxlenratio: float,
@@ -454,9 +599,11 @@
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    export_mode = False
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
    else:
        hotword_list_or_file = None
@@ -490,7 +637,10 @@
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2Text(**speech2text_kwargs)
    if export_mode:
        speech2text = Speech2TextExport(**speech2text_kwargs)
    else:
        speech2text = Speech2Text(**speech2text_kwargs)
    def _forward(
            data_path_and_name_and_type,
funasr/export/export_model.py
@@ -44,6 +44,7 @@
            model,
            self.export_config,
        )
        model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(model, verbose, export_dir)