shixian.shi
2023-03-13 a1fe3c635f47e941c2bb2a545ce0aface87fe041
update tp inference
3个文件已修改
37 ■■■■■ 已修改文件
funasr/bin/tp_inference.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_tp.py 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_inference.py
@@ -18,7 +18,7 @@
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner_temp as ASRTask
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
funasr/models/e2e_tp.py
@@ -41,6 +41,7 @@
            encoder: AbsEncoder,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,
    ):
        assert check_argument_types()
@@ -54,6 +55,7 @@
        self.predictor = predictor
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss()
        self.token_list = token_list
    
    def forward(
            self,
@@ -152,3 +154,22 @@
                                                                                               encoder_out_mask,
                                                                                               token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
funasr/tasks/asr.py
@@ -125,7 +125,7 @@
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_predictor=TimestampPredictor,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
@@ -1278,8 +1278,6 @@
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
        # 1. frontend
        if args.input_size is None:
@@ -1316,6 +1314,7 @@
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )
@@ -1326,15 +1325,6 @@
        assert check_return_type(model)
        return model
    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval
class ASRTaskAligner_temp(ASRTaskParaformer):
    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False