import logging
from typing import Dict, Tuple

import torch
from typeguard import check_argument_types

# (import paths below are assumed from the funasr 0.x layout)
from funasr.train.class_choices import ClassChoices
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.predictor.cif import CifPredictorV3, mae_loss
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse


class TimestampPredictor(AbsESPnetModel):
    # class/def headers inferred (assumption): this matches the
    # timestamp_predictor entry in the model registry below
    def __init__(
        self,
        encoder: AbsEncoder,
        predictor: CifPredictorV3,
        predictor_bias: int = 0,
        token_list=None,
    ):
        assert check_argument_types()
        super().__init__()

        self.encoder = encoder
        self.predictor = predictor
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss()  # MAE loss over predicted token counts
        self.token_list = token_list
    def forward(
        self,
        encoder_out: torch.Tensor,
        encoder_out_mask: torch.Tensor,
        token_num: torch.Tensor,
    ):
        # Assumed call (reconstructed from the return values): CifPredictorV3's
        # timestamp pass, yielding fire alphas and CIF peaks at the downsampled
        # (ds_) and upsampled (us_) frame rates.
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(
            encoder_out, encoder_out_mask, token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
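
    # Shape contract (assumed, typical for CIF predictors): encoder_out is
    # (B, T, D), encoder_out_mask masks padded frames, token_num is (B,) with
    # the target token count per utterance; peaks in *_cif_peak mark token
    # boundaries from which per-token timestamps are derived.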

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}


# Model registry (ESPnet-style ClassChoices). The opening of this call and the
# earlier registry entries were truncated; the variable name is assumed.
model_choices = ClassChoices(
    "model",
    classes=dict(
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_predictor=TimestampPredictor,
        timestamp_prediction=TimestampPredictor,  # alias for the same class
    ),
    type_check=AbsESPnetModel,
    default="asr",
)


class ASRTaskAligner(ASRTaskParaformer):
    @classmethod
    def build_model(cls, args):  # headers inferred; fragment began mid-method
        if isinstance(args.token_list, str):
            # Load the vocabulary from file, one token per line
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # frontend construction elided in the original; an ESPnet-style
            # task builds the frontend from args.frontend_conf here
            ...

        # ... encoder and predictor construction elided ...

        model_class = model_choices.get_class(args.model)  # assumed resolver
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )
        return model

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval


class ASRTaskAligner_temp(ASRTaskParaformer):
    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval
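

# Hedged usage sketch: both aligner task variants declare the same required
# inputs, so a dataset for either must supply "speech" and "text" entries.
if __name__ == "__main__":
    assert ASRTaskAligner.required_data_names(train=True) == ("speech", "text")
    assert ASRTaskAligner_temp.required_data_names(inference=True) == ("speech", "text")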