| funasr/models/base_model.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_asr.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_asr_mfcca.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_asr_paraformer.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_diar_eend_ola.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_diar_sond.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_sv.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_tp.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/e2e_uni_asr.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/models/base_model.py
# New file @@ -0,0 +1,17 @@
import torch


class FunASRModel(torch.nn.Module):
    """Common base class for the models in this package.

    On top of ``torch.nn.Module`` it only adds an integer update counter,
    ``num_updates``, together with a setter and a getter.  The counter is
    stored as a plain Python attribute (it is not a registered buffer or
    parameter).
    """

    def __init__(self) -> None:
        """Initialise the module and reset the update counter to 0."""
        super().__init__()
        # NOTE(review): presumably the number of training/optimizer updates
        # performed so far -- confirm against the caller that sets it.
        self.num_updates = 0

    def set_num_updates(self, num_updates: int) -> None:
        """Overwrite the stored update counter.

        Args:
            num_updates: New counter value; stored as given, without
                validation.
        """
        self.num_updates = num_updates

    def get_num_updates(self) -> int:
        """Return the update counter currently stored on the model."""
        return self.num_updates


# funasr/models/e2e_asr.py
@@ -24,11 +24,11 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.e2e_asr_common import ErrorCalculator from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -39,7 +39,7 @@ yield class ESPnetASRModel(AbsESPnetModel): class ESPnetASRModel(FunASRModel): """CTC-attention hybrid Encoder-Decoder model""" def __init__( funasr/models/e2e_asr_mfcca.py
@@ -21,9 +21,10 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.base_model import FunASRModel from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -35,7 +36,7 @@ import pdb import random import math class MFCCA(AbsESPnetModel): class MFCCA(FunASRModel): """CTC-attention hybrid Encoder-Decoder model""" def __init__( funasr/models/e2e_asr_paraformer.py
@@ -25,11 +25,11 @@ from funasr.models.predictor.cif import mae_loss from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.models.predictor.cif import CifPredictorV3 @@ -42,7 +42,7 @@ yield class Paraformer(AbsESPnetModel): class Paraformer(FunASRModel): """ Author: Speech Lab, Alibaba Group, China Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition funasr/models/e2e_diar_eend_ola.py
@@ -15,8 +15,8 @@ from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.modules.eend_ola.utils.power import generate_mapping_dict from funasr.models.base_model import FunASRModel from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): pass @@ -34,7 +34,7 @@ return att class DiarEENDOLAModel(AbsESPnetModel): class DiarEENDOLAModel(FunASRModel): """EEND-OLA diarization model""" def __init__( funasr/models/e2e_diar_sond.py
@@ -20,9 +20,9 @@ from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.base_model import FunASRModel from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy from funasr.utils.misc import int2vec @@ -35,7 +35,7 @@ yield class DiarSondModel(AbsESPnetModel): class DiarSondModel(FunASRModel): """Speaker overlap-aware neural diarization model reference: https://arxiv.org/abs/2211.10243 """ funasr/models/e2e_sv.py
@@ -21,11 +21,11 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.e2e_asr_common import ErrorCalculator from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -36,7 +36,7 @@ yield class ESPnetSVModel(AbsESPnetModel): class ESPnetSVModel(FunASRModel): """CTC-attention hybrid Encoder-Decoder model""" def __init__( funasr/models/e2e_tp.py
@@ -14,10 +14,10 @@ from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.predictor.cif import mae_loss from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.models.predictor.cif import CifPredictorV3 @@ -30,7 +30,7 @@ yield class TimestampPredictor(AbsESPnetModel): class TimestampPredictor(FunASRModel): """ Author: Speech Lab, Alibaba Group, China """ funasr/models/e2e_uni_asr.py
@@ -23,6 +23,7 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.base_model import FunASRModel from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -38,7 +39,7 @@ yield class UniASR(AbsESPnetModel): class UniASR(FunASRModel): """ Author: Speech Lab, Alibaba Group, China """