shixian.shi
2023-04-27 3c0a9fb7c1bd642f3370d406fca81ff50ae9bc82
fix name
2个文件已修改
18 ■■■■ 已修改文件
funasr/models/e2e_asr_contextual_paraformer.py 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_contextual_paraformer.py
@@ -1,4 +1,3 @@
from json import decoder
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
@@ -7,35 +6,24 @@
from typing import Optional
from typing import Tuple
from typing import Union
import random
from unicodedata import bidirectional
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.predictor.cif import mae_loss
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.e2e_asr_paraformer import Paraformer
from funasr.modules.layer_norm import LayerNorm
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -47,7 +35,7 @@
        yield
class AdvancedContextualParaformer(Paraformer):
class NeatContextualParaformer(Paraformer):
    def __init__(
        self,
        vocab_size: int,
funasr/tasks/asr.py
@@ -42,7 +42,7 @@
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import AdvancedContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
@@ -129,7 +129,7 @@
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        acontextual_paraformer=AdvancedContextualParaformer,
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),