From 27f31cd42bb4e20dc19de0034fc0d80b449f1db1 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 06 十二月 2023 17:01:12 +0800
Subject: [PATCH] funasr2
---
funasr/modules/nets_utils.py | 2
funasr/datasets/small_datasets/preprocessor.py | 1
funasr/tokenizer/char_tokenizer.py | 6
funasr/utils/dynamic_import.py | 13
funasr/cli/model_class_factory.py | 298 ++++++++++
funasr/datasets/data_sampler.py | 12
funasr/cli/models/paraformer.py | 652 +++++++++++++++++++++++
funasr/utils/load_fr_py.py | 13
funasr/cli/trainer.py | 236 ++++++++
funasr/tokenizer/sentencepiece_tokenizer.py | 2
funasr/cli/__init__.py | 0
funasr/tokenizer/abs_tokenizer.py | 73 ++
funasr/schedulers/__init__.py | 23
funasr/tokenizer/phoneme_tokenizer.py | 1
funasr/cli/train_cli.py | 170 ++++++
funasr/tokenizer/funtoken.py | 75 ++
funasr/optimizers/__init__.py | 17
funasr/tokenizer/build_tokenizer.py | 17
funasr/cli/models/__init__.py | 0
funasr/datasets/dataset_jsonl.py | 11
funasr/tokenizer/word_tokenizer.py | 1
21 files changed, 1,604 insertions(+), 19 deletions(-)
diff --git a/funasr/bin/asr_trainer.py b/funasr/cli/__init__.py
similarity index 100%
rename from funasr/bin/asr_trainer.py
rename to funasr/cli/__init__.py
diff --git a/funasr/cli/model_class_factory.py b/funasr/cli/model_class_factory.py
new file mode 100644
index 0000000..b329492
--- /dev/null
+++ b/funasr/cli/model_class_factory.py
@@ -0,0 +1,298 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import yaml
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+ DynamicConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolutionTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
+from funasr.models.e2e_asr import ASRModel
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+from funasr.models.e2e_asr_bat import BATModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.tasks.abs_task import AbsTask
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.models.base_model import FunASRModel
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+# from funasr.models.paraformer import Paraformer
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ multichannelfrontend=MultiChannelFrontend,
+ ),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ specaug_lfr=SpecAugLFR,
+ ),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+# specaug_choices = {"specaug":SpecAug}
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+# model_choices = ClassChoices(
+# "model",
+# classes=dict(
+# asr=ASRModel,
+# uniasr=UniASR,
+# paraformer=Paraformer,
+# paraformer_online=ParaformerOnline,
+# paraformer_bert=ParaformerBert,
+# bicif_paraformer=BiCifParaformer,
+# contextual_paraformer=ContextualParaformer,
+# neatcontextual_paraformer=NeatContextualParaformer,
+# mfcca=MFCCA,
+# timestamp_prediction=TimestampPredictor,
+# rnnt=TransducerModel,
+# rnnt_unified=UnifiedTransducerModel,
+# bat=BATModel,
+# sa_asr=SAASRModel,
+# ),
+# type_check=None,
+# default="asr",
+# )
+preencoder_choices = ClassChoices(
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ linear=LinearProjection,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ chunk_conformer=ConformerChunkEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+encoder_choices2 = ClassChoices(
+ "encoder2",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
+)
+postencoder_choices = ClassChoices(
+ name="postencoder",
+ classes=dict(
+ hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+ ),
+ type_check=AbsPostEncoder,
+ default=None,
+ optional=True,
+)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="rnn",
+)
+decoder_choices2 = ClassChoices(
+ "decoder2",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="rnn",
+)
+
+rnnt_decoder_choices = ClassChoices(
+ "rnnt_decoder",
+ classes=dict(
+ rnnt=RNNTDecoder,
+ ),
+ type_check=RNNTDecoder,
+ default="rnnt",
+)
+
+joint_network_choices = ClassChoices(
+ name="joint_network",
+ classes=dict(
+ joint_network=JointNetwork,
+ ),
+ default="joint_network",
+ optional=True,
+)
+
+predictor_choices = ClassChoices(
+ name="predictor",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ cif_predictor_v3=CifPredictorV3,
+ bat_predictor=BATPredictor,
+ ),
+ type_check=None,
+ default="cif_predictor",
+ optional=True,
+)
+predictor_choices2 = ClassChoices(
+ name="predictor2",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ ),
+ type_check=None,
+ default="cif_predictor",
+ optional=True,
+)
+stride_conv_choices = ClassChoices(
+ name="stride_conv",
+ classes=dict(
+ stride_conv1d=Conv1dSubsampling
+ ),
+ type_check=None,
+ default="stride_conv1d",
+ optional=True,
+)
\ No newline at end of file
diff --git a/funasr/bin/asr_trainer.py b/funasr/cli/models/__init__.py
similarity index 100%
copy from funasr/bin/asr_trainer.py
copy to funasr/cli/models/__init__.py
diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py
new file mode 100644
index 0000000..ced1c23
--- /dev/null
+++ b/funasr/cli/models/paraformer.py
@@ -0,0 +1,652 @@
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+# from funasr.models.ctc import CTC
+# from funasr.models.decoder.abs_decoder import AbsDecoder
+# from funasr.models.e2e_asr_common import ErrorCalculator
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+# from funasr.models.frontend.abs_frontend import AbsFrontend
+# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.predictor.cif import mae_loss
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.nets_utils import make_pad_mask, pad_list
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.predictor.cif import CifPredictorV3
+
+from funasr.cli.model_class_factory import *
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class Paraformer(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ # token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[str] = None,
+ frontend_conf: Optional[Dict] = None,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ decoder: str = None,
+ decoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ predictor: str = None,
+ predictor_conf: Optional[Dict] = None,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ # report_cer: bool = True,
+ # report_wer: bool = True,
+ # sym_space: str = "<space>",
+ # sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ # predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ use_1st_decoder_loss: bool = False,
+ **kwargs,
+ ):
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+ super().__init__()
+
+ # import pdb;
+ # pdb.set_trace()
+
+ if frontend is not None:
+ frontend_class = frontend_choices.get_class(frontend)
+ frontend = frontend_class(**frontend_conf)
+ if specaug is not None:
+ specaug_class = specaug_choices.get_class(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = normalize_choices.get_class(normalize)
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = encoder_choices.get_class(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+ if decoder is not None:
+ decoder_class = decoder_choices.get_class(decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **decoder_conf,
+ )
+ if ctc_weight > 0.0:
+
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+ if predictor is not None:
+ predictor_class = predictor_choices.get_class(predictor)
+ predictor = predictor_class(**predictor_conf)
+
+ # note that eos is the same as sos (equivalent ID)
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ self.interctc_weight = interctc_weight
+ # self.token_list = token_list.copy()
+ #
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ # self.preencoder = preencoder
+ # self.postencoder = postencoder
+ self.encoder = encoder
+ #
+ # if not hasattr(self.encoder, "interctc_use_conditioning"):
+ # self.encoder.interctc_use_conditioning = False
+ # if self.encoder.interctc_use_conditioning:
+ # self.encoder.conditioning_layer = torch.nn.Linear(
+ # vocab_size, self.encoder.output_size()
+ # )
+ #
+ # self.error_calculator = None
+ #
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+ #
+ # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+ self.predictor = predictor
+ self.predictor_weight = predictor_weight
+ self.predictor_bias = predictor_bias
+ self.sampling_ratio = sampling_ratio
+ self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+ # self.step_cur = 0
+ #
+ self.share_embedding = share_embedding
+ if self.share_embedding:
+ self.decoder.embed = None
+
+ self.use_1st_decoder_loss = use_1st_decoder_loss
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ decoding_ind: int
+ """
+ decoding_ind = kwargs.get("kwargs", None)
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # # for data-parallel
+ # text = text[:, : text_lengths.max()]
+ # speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ else:
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+ # Collect Intermedaite CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+ # 2b. Attention decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+ if self.use_1st_decoder_loss and pre_loss_att is not None:
+ loss = loss + (1 - self.ctc_weight) * pre_loss_att
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Dict[str, torch.Tensor]:
+ if self.extract_feats_in_collect_stats:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ else:
+ # Generate dummy stats if extract_feats_in_collect_stats is False
+ logging.warning(
+ "Generating dummy stats for feats and feats_lengths, "
+ "because encoder_conf.extract_feats_in_collect_stats is "
+ f"{self.extract_feats_in_collect_stats}"
+ )
+ feats, feats_lengths = speech, speech_lengths
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+ # # 1. Extract feats
+ # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(speech, speech_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # # Pre-encoder, e.g. used for raw input data
+ # if self.preencoder is not None:
+ # feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ feats, feats_lengths, ctc=self.ctc, ind=ind
+ )
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ feats, feats_lengths, ctc=self.ctc
+ )
+ else:
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ # # Post-encoder, e.g. NLU
+ # if self.postencoder is not None:
+ # encoder_out, encoder_out_lens = self.postencoder(
+ # encoder_out, encoder_out_lens
+ # )
+
+ assert encoder_out.size(0) == speech.size(0), (
+ encoder_out.size(),
+ speech.size(0),
+ )
+ assert encoder_out.size(1) <= encoder_out_lens.max(), (
+ encoder_out.size(),
+ encoder_out_lens.max(),
+ )
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
+ ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+ if self.frontend is not None:
+ # Frontend
+ # e.g. STFT and Feature extract
+ # data_loader may send time-domain signal in this case
+ # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ # No frontend and no feature extract
+ feats, feats_lengths = speech, speech_lengths
+ return feats, feats_lengths
+
+ def nll(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute negative log likelihood(nll) from transformer-decoder
+ Normally, this function is called in batchify_nll.
+ Args:
+ encoder_out: (Batch, Length, Dim)
+ encoder_out_lens: (Batch,)
+ ys_pad: (Batch, Length)
+ ys_pad_lens: (Batch,)
+ """
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ ) # [batch, seqlen, dim]
+ batch_size = decoder_out.size(0)
+ decoder_num_class = decoder_out.size(2)
+ # nll: negative log-likelihood
+ nll = torch.nn.functional.cross_entropy(
+ decoder_out.view(-1, decoder_num_class),
+ ys_out_pad.view(-1),
+ ignore_index=self.ignore_id,
+ reduction="none",
+ )
+ nll = nll.view(batch_size, -1)
+ nll = nll.sum(dim=1)
+ assert nll.size(0) == batch_size
+ return nll
+
+ def batchify_nll(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ batch_size: int = 100,
+ ):
+ """Compute negative log likelihood(nll) from transformer-decoder
+ To avoid OOM, this fuction seperate the input into batches.
+ Then call nll for each batch and combine and return results.
+ Args:
+ encoder_out: (Batch, Length, Dim)
+ encoder_out_lens: (Batch,)
+ ys_pad: (Batch, Length)
+ ys_pad_lens: (Batch,)
+ batch_size: int, samples each batch contain when computing nll,
+ you may change this to avoid OOM or increase
+ GPU memory usage
+ """
+ total_num = encoder_out.size(0)
+ if total_num <= batch_size:
+ nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+ else:
+ nll = []
+ start_idx = 0
+ while True:
+ end_idx = min(start_idx + batch_size, total_num)
+ batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
+ batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
+ batch_ys_pad = ys_pad[start_idx:end_idx, :]
+ batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
+ batch_nll = self.nll(
+ batch_encoder_out,
+ batch_encoder_out_lens,
+ batch_ys_pad,
+ batch_ys_pad_lens,
+ )
+ nll.append(batch_nll)
+ start_idx = end_idx
+ if start_idx == total_num:
+ break
+ nll = torch.cat(nll)
+ assert nll.size(0) == total_num
+ return nll
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ pre_loss_att = None
+ if self.sampling_ratio > 0.0:
+
+
+ if self.use_1st_decoder_loss:
+ sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+ )
+ pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num], value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
new file mode 100644
index 0000000..28e0e28
--- /dev/null
+++ b/funasr/cli/train_cli.py
@@ -0,0 +1,170 @@
+import argparse
+import logging
+import os
+import sys
+from io import BytesIO
+from collections.abc import Sequence
+import torch
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+# from funasr.model_class_factory1 import model_choices
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
+from funasr.optimizers import optim_choices
+from funasr.schedulers import scheduler_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+from funasr.torch_utils.initialize import initialize
+from funasr.datasets.data_sampler import BatchSampler
+# from funasr.tokenizer.build_tokenizer import build_tokenizer
+# from funasr.tokenizer.token_id_converter import TokenIDConverter
+from funasr.tokenizer.funtoken import build_tokenizer
+from funasr.datasets.dataset_jsonl import AudioDataset
+from funasr.cli.trainer import Trainer
+# from funasr.utils.load_fr_py import load_class_from_path
+from funasr.utils.dynamic_import import dynamic_import
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+
+def preprocess_config(cfg: DictConfig):
+ for key, value in cfg.items():
+ if value == 'None':
+ cfg[key] = None
+
+
+
+@hydra.main()
+def main(kwargs: DictConfig):
+ # preprocess_config(kwargs)
+ import pdb; pdb.set_trace()
+ # set random seed
+ set_all_random_seed(kwargs.get("seed", 0))
+ torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+ torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+ torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ # Check if we are using DDP or FSDP
+ use_ddp = 'WORLD_SIZE' in os.environ
+ use_fsdp = kwargs.get("use_fsdp", None)
+ if use_ddp or use_fsdp:
+ dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
+ device= torch.cuda.set_device(local_rank)
+
+
+ # build_tokenizer
+ tokenizer = build_tokenizer(
+ token_type=kwargs.get("token_type", "char"),
+ bpemodel=kwargs.get("bpemodel", None),
+ delimiter=kwargs.get("delimiter", None),
+ space_symbol=kwargs.get("space_symbol", "<space>"),
+ non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
+ g2p_type=kwargs.get("g2p_type", None),
+ token_list=kwargs.get("token_list", None),
+ unk_symbol=kwargs.get("unk_symbol", "<unk>"),
+ )
+
+ # import pdb;
+ # pdb.set_trace()
+ # build model
+ # model_class = model_choices.get_class(kwargs.get("model", "asr"))
+ # model_class = load_class_from_path(kwargs.get("model").split(":"))
+ model_class = dynamic_import(kwargs.get("model"))
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ # model = model.to(device=kwargs.get("device", "cpu"))
+
+
+
+ # import pdb;
+ # pdb.set_trace()
+ # freeze_param
+ freeze_param = kwargs.get("freeze_param", None)
+ if freeze_param is not None:
+ freeze_param = eval(freeze_param)
+ if isinstance(freeze_param, Sequence):
+ freeze_param = (freeze_param,)
+ logging.info("freeze_param is not None: ", freeze_param)
+ for t in freeze_param:
+ for k, p in model.named_parameters():
+ if k.startswith(t + ".") or k == t:
+ logging.info(f"Setting {k}.requires_grad = False")
+ p.requires_grad = False
+
+
+ if use_ddp:
+ model = model.cuda(local_rank)
+ model = DDP(model, device_ids=[local_rank])
+ elif use_fsdp:
+ model = FSDP(model).cuda(local_rank)
+
+
+ # optim
+ optim = kwargs.get("optim", "adam")
+ assert optim in optim_choices
+ optim_class = optim_choices.get(optim)
+ optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+
+ # scheduler
+ scheduler = kwargs.get("scheduler", "warmuplr")
+ assert scheduler in scheduler_choices
+ scheduler_class = scheduler_choices.get(scheduler)
+ scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ init_param = eval(init_param)
+ if isinstance(init_param, Sequence):
+ init_param = (init_param,)
+ logging.info("init_param is not None: ", freeze_param)
+ for p in init_param:
+ logging.info(f"Loading pretrained params from {p}")
+ load_pretrained_model(
+ model=model,
+ init_param=p,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ )
+ else:
+ initialize(model, kwargs.get("init", "kaiming_normal"))
+
+
+ # dataset
+ dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=model.frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+
+ # dataloader
+ batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
+ dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
+ collate_fn=dataset_tr.collator,
+ batch_sampler=batch_sampler,
+ num_workers=kwargs.get("num_workers", 0),
+ pin_memory=True)
+
+ trainer = Trainer(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ dataloader_train=dataloader_tr,
+ dataloader_val=None,
+ local_rank=local_rank,
+ use_ddp=use_ddp,
+ use_fsdp=use_fsdp,
+ **kwargs.get("train_conf"),
+ )
+ trainer.run()
+
+ if use_ddp or use_fsdp:
+ torch.distributed.destroy_process_group()
+
+
+
+def train(epoch, model, op):
+ pass
+
+def val():
+ pass
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
new file mode 100644
index 0000000..30e0419
--- /dev/null
+++ b/funasr/cli/trainer.py
@@ -0,0 +1,236 @@
+import torch
+import os
+from funasr.torch_utils.device_funcs import to_device
+import logging
+from tqdm import tqdm
+from contextlib import nullcontext
+
+class Trainer:
+ """
+ A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
+ and optionally resuming from a saved checkpoint.
+
+ Attributes:
+ max_epoch (int): Maximum number of epochs for training.
+ model (torch.nn.Module): The model to be trained.
+ optim (torch.optim.Optimizer): The optimizer to use for training.
+ scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+ dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
+ dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+ output_dir (str): Directory where model checkpoints will be saved.
+ resume (str, optional): Path to a checkpoint to resume training from.
+ """
+
+ def __init__(self, model,
+ optim,
+ scheduler,
+ dataloader_train,
+ dataloader_val,
+ local_rank,
+ use_ddp=False,
+ use_fsdp=False,
+ **kwargs):
+ """
+ Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
+
+ Args:
+ model (torch.nn.Module): The model to be trained.
+ optim (torch.optim.Optimizer): The optimizer to use for training.
+ scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+ dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
+ dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
+ **kwargs: Additional keyword arguments:
+ max_epoch (int): The maximum number of epochs for training.
+ output_dir (str): The directory where model checkpoints will be saved. Default is './'.
+ resume (str, optional): The file path to a checkpoint to resume training from.
+ """
+
+ self.model = model
+ self.optim = optim
+ self.scheduler = scheduler
+ self.dataloader_train = dataloader_train
+ self.dataloader_val = dataloader_val
+ self.output_dir = kwargs.get('output_dir', './')
+ self.resume = kwargs.get('resume', None)
+ self.start_epoch = 1
+ self.max_epoch = kwargs.get('max_epoch', 100)
+ self.local_rank = local_rank
+ self.use_ddp = use_ddp
+ self.use_fsdp = use_fsdp
+ self.device = torch.device("cuda", local_rank)
+ self.kwargs = kwargs
+
+ if self.resume:
+ self._resume_checkpoint(self.resume)
+
+ def _save_checkpoint(self, epoch):
+ """
+ Saves a checkpoint containing the model's state, the optimizer's state,
+ and the scheduler's state at the end of the given epoch. This method is
+ intended to be called at the end of each epoch to save the training progress.
+
+ Args:
+ epoch (int): The epoch number at which the checkpoint is being saved.
+ """
+ state = {
+ 'epoch': epoch,
+ 'state_dict': self.model.state_dict(),
+ 'optimizer': self.optim.state_dict(),
+ 'scheduler': self.scheduler.state_dict(),
+ }
+ # Create output directory if it does not exist
+ os.makedirs(self.output_dir, exist_ok=True)
+ filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
+ torch.save(state, filename)
+ print(f'Checkpoint saved to {filename}')
+
+ def _resume_checkpoint(self, resume_path):
+ """
+ Resumes training from a checkpoint at the given file path.
+ Loads the model's state, the optimizer's state, and the scheduler's state.
+
+ Args:
+ resume_path (str): The file path to the checkpoint to resume from.
+ """
+ if os.path.isfile(resume_path):
+ checkpoint = torch.load(resume_path)
+ self.start_epoch = checkpoint['epoch'] + 1
+ self.model.load_state_dict(checkpoint['state_dict'])
+ self.optim.load_state_dict(checkpoint['optimizer'])
+ self.scheduler.load_state_dict(checkpoint['scheduler'])
+ print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+ else:
+ print(f"No checkpoint found at '{resume_path}', starting from scratch")
+
+ def run(self):
+ """
+ Starts the training process, iterating over epochs, training the model,
+ and saving checkpoints at the end of each epoch.
+ """
+ for epoch in range(self.start_epoch, self.max_epoch + 1):
+ self._train_epoch(epoch)
+ # self._validate_epoch(epoch)
+ self._save_checkpoint(epoch)
+ self.scheduler.step()
+
+ def _train_epoch(self, epoch):
+ """
+ Defines the training process for a single epoch with gradient accumulation.
+ Args:
+ epoch (int): The current epoch number.
+ """
+ self.model.train()
+ pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
+ dynamic_ncols=True)
+
+ # Set the number of steps for gradient accumulation
+ accumulation_steps = self.kwargs.get("accumulation_steps", 1)
+ # Initialize the gradient accumulation
+ self.optim.zero_grad()
+
+ for batch_idx, batch in enumerate(self.dataloader_train):
+ batch = to_device(batch, self.device)
+
+ my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+ with my_context():
+ retval = self.model(**batch)
+ loss, stats, weight = retval
+
+ # Scale the loss since we're not updating for every mini-batch
+ loss = loss / accumulation_steps
+ loss.backward()
+
+ # Perform an optimizer step only after accumulating enough gradients
+ if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
+ # Perform gradient clipping if it is set
+ if self.kwargs.get("grad_clip", None) is not None:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(),
+ max_norm=self.kwargs.get("grad_clip", 10.0),
+ norm_type=self.kwargs.get("grad_clip_type", 2.0),
+ )
+ if not torch.isfinite(grad_norm):
+ logging.warning(
+ f"The grad norm is {grad_norm}. Skipping updating the model."
+ )
+ self.optim.zero_grad() # Reset gradients
+ continue
+
+ # Execute an optimization step (update model parameters)
+ self.optim.step()
+ self.scheduler.step()
+ # Clear gradients for the next accumulation stage
+ self.optim.zero_grad()
+
+ pbar.update(1)
+ pbar.set_description(
+ f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float()})")
+
+ pbar.close()
+
+ # def _train_epoch(self, epoch):
+ # """
+ # Defines the training process for a single epoch.
+ # Should be implemented with the actual model training steps.
+ #
+ # Args:
+ # epoch (int): The current epoch number.
+ # """
+ # self.model.train()
+ # pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
+ # for batch_idx, batch in enumerate(self.dataloader_train):
+ # batch = to_device(batch, "cpu")
+ # retval = self.model(**batch)
+ # loss, stats, weight = retval
+ # self.optim.zero_grad()
+ # loss.backward()
+ #
+ # # compute the gradient norm to check if it is normal or not
+ # grad_norm = torch.nn.utils.clip_grad_norm_(
+ # self.model.parameters(),
+ # max_norm=self.kwargs.get("grad_clip", 10.0),
+ # norm_type=self.kwargs.get("grad_clip_type", 2.0),
+ # )
+ # if not torch.isfinite(grad_norm):
+ # logging.warning(
+ # f"The grad norm is {grad_norm}. Skipping updating the model."
+ # )
+ # continue
+ # self.optim.step()
+ # self.scheduler.step()
+ # pbar.update(1)
+ # pbar.set_description(
+ # f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float()})")
+ #
+ # pbar.close()
+ #
+
+ def _validate_epoch(self, epoch):
+ """
+ Defines the validation process for a single epoch.
+ Should be implemented with the actual model validation steps.
+
+ Args:
+ epoch (int): The current epoch number.
+ """
+ self.model.eval()
+ with torch.no_grad():
+ for data, target in self.dataloader_val:
+ # Implement the model validation steps here
+ pass
+
+# # Example usage
+# if __name__ == "__main__":
+# # Assuming the following objects have already been correctly created and initialized:
+# # model, optim, scheduler, dataloader_train, and dataloader_val.
+# trainer = Trainer(
+# max_epoch=10,
+# model=model,
+# optim=optim,
+# scheduler=scheduler,
+# dataloader_train=dataloader_train,
+# dataloader_val=dataloader_val,
+# output_dir='path_to_save_model',
+# resume='path_to_checkpoint_if_any'
+# )
+# trainer.run()
\ No newline at end of file
diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py
index 60c7c84..3a19a17 100644
--- a/funasr/datasets/data_sampler.py
+++ b/funasr/datasets/data_sampler.py
@@ -4,17 +4,17 @@
class BatchSampler(torch.utils.data.BatchSampler):
- def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
+ def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
- # self.batch_size_type = args.batch_size_type
+ # self.batch_type = args.batch_type
# self.batch_size = args.batch_size
# self.sort_size = args.sort_size
# self.max_length_token = args.max_length_token
- self.batch_size_type = batch_size_type
+ self.batch_type = batch_type
self.batch_size = batch_size
self.sort_size = sort_size
self.max_length_token = kwargs.get("max_length_token", 5000)
@@ -26,7 +26,7 @@
return self.total_samples
def __iter__(self):
- print("in sampler")
+ # print("in sampler")
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
@@ -36,7 +36,7 @@
num_sample = 0
iter_num = (self.total_samples-1) // self.sort_size + 1
- print("iter_num: ", iter_num)
+ # print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
for i in range(self.sort_size):
@@ -59,7 +59,7 @@
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
- if self.batch_size_type == 'token':
+ if self.batch_type == 'token':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
index 3a548c8..eef67c5 100644
--- a/funasr/datasets/dataset_jsonl.py
+++ b/funasr/datasets/dataset_jsonl.py
@@ -88,18 +88,16 @@
class AudioDataset(torch.utils.data.Dataset):
- def __init__(self, path, frontend=None, tokenizer=None, token_id_converter=None):
-
+ def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs):
super().__init__()
self.indexed_dataset = IndexedDatasetJsonl(path)
self.frontend = frontend.forward
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
self.tokenizer = tokenizer
- self.token_id_converter = token_id_converter
- self.int_pad_value = -1
- self.float_pad_value = 0.0
+ self.int_pad_value = int_pad_value
+ self.float_pad_value = float_pad_value
@@ -115,8 +113,7 @@
data_src = load_audio(source, fs=self.fs)
speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
target = item["target"]
- text = self.tokenizer.text2tokens(target)
- ids = self.token_id_converter.tokens2ids(text)
+ ids = self.tokenizer.encode(target)
ids_lengths = len(ids)
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index 01a8c6f..62beaab 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -361,6 +361,7 @@
tokens = seg_tokenize(tokens, self.seg_dict)
else:
tokens = self.tokenizer.text2tokens(text)
+
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
return data
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index b1879fa..0beb083 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -347,7 +347,7 @@
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
- pad_targets (LongTensor): Target label tensors (B, Lmax, D).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
diff --git a/funasr/optimizers/__init__.py b/funasr/optimizers/__init__.py
index e69de29..b4dfe5d 100644
--- a/funasr/optimizers/__init__.py
+++ b/funasr/optimizers/__init__.py
@@ -0,0 +1,17 @@
+import torch
+from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
+
+optim_choices = dict(
+ adam=torch.optim.Adam,
+ fairseq_adam=FairseqAdam,
+ adamw=torch.optim.AdamW,
+ sgd=SGD,
+ adadelta=torch.optim.Adadelta,
+ adagrad=torch.optim.Adagrad,
+ adamax=torch.optim.Adamax,
+ asgd=torch.optim.ASGD,
+ lbfgs=torch.optim.LBFGS,
+ rmsprop=torch.optim.RMSprop,
+ rprop=torch.optim.Rprop,
+)
\ No newline at end of file
diff --git a/funasr/schedulers/__init__.py b/funasr/schedulers/__init__.py
index e69de29..7bb8118 100644
--- a/funasr/schedulers/__init__.py
+++ b/funasr/schedulers/__init__.py
@@ -0,0 +1,23 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+scheduler_choices = dict(
+ ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+ lambdalr=torch.optim.lr_scheduler.LambdaLR,
+ steplr=torch.optim.lr_scheduler.StepLR,
+ multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+ exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+ CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+ noamlr=NoamLR,
+ warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
+ cycliclr=torch.optim.lr_scheduler.CyclicLR,
+ onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+ CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+)
\ No newline at end of file
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index fc2ccb3..ffb6b76 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -2,7 +2,13 @@
from abc import abstractmethod
from typing import Iterable
from typing import List
+from pathlib import Path
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+import numpy as np
class AbsTokenizer(ABC):
@abstractmethod
@@ -12,3 +18,70 @@
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
raise NotImplementedError
+
+
+class BaseTokenizer(ABC):
+ def __init__(self, token_list: Union[Path, str, Iterable[str]],
+ unk_symbol: str = "<unk>",
+ **kwargs,
+ ):
+
+ if isinstance(token_list, (Path, str)):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with token_list.open("r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ line = line.rstrip()
+ self.token_list.append(line)
+
+ else:
+ self.token_list: List[str] = list(token_list)
+ self.token_list_repr = ""
+ for i, t in enumerate(self.token_list):
+ if i == 3:
+ break
+ self.token_list_repr += f"{t}, "
+ self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+
+ self.token2id: Dict[str, int] = {}
+ for i, t in enumerate(self.token_list):
+ if t in self.token2id:
+ raise RuntimeError(f'Symbol "{t}" is duplicated')
+ self.token2id[t] = i
+
+ self.unk_symbol = unk_symbol
+ if self.unk_symbol not in self.token2id:
+ raise RuntimeError(
+ f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+ )
+ self.unk_id = self.token2id[self.unk_symbol]
+
+ def encode(self, text):
+ tokens = self.text2tokens(text)
+ text_ints = self.tokens2ids(tokens)
+
+ return text_ints
+
+ def decode(self, text_ints):
+ return self.ids2tokens(text_ints)
+
+ def get_num_vocabulary_size(self) -> int:
+ return len(self.token_list)
+
+ def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+ if isinstance(integers, np.ndarray) and integers.ndim != 1:
+ raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+ return [self.token_list[i] for i in integers]
+
+ def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+ @abstractmethod
+ def text2tokens(self, line: str) -> List[str]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ raise NotImplementedError
diff --git a/funasr/tokenizer/build_tokenizer.py b/funasr/tokenizer/build_tokenizer.py
index 9d1cdc3..1dc17da 100644
--- a/funasr/tokenizer/build_tokenizer.py
+++ b/funasr/tokenizer/build_tokenizer.py
@@ -1,7 +1,17 @@
from pathlib import Path
from typing import Iterable
from typing import Union
+from abc import ABC
+from abc import abstractmethod
+from typing import Iterable
+from typing import List
+from pathlib import Path
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+import numpy as np
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
@@ -18,6 +28,7 @@
space_symbol: str = "<space>",
delimiter: str = None,
g2p_type: str = None,
+ **kwargs,
) -> AbsTokenizer:
"""A helper function to instantiate Tokenizer"""
if token_type == "bpe":
@@ -28,7 +39,7 @@
raise RuntimeError(
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
)
- return SentencepiecesTokenizer(bpemodel)
+ return SentencepiecesTokenizer(bpemodel, **kwargs)
elif token_type == "word":
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
@@ -38,13 +49,14 @@
remove_non_linguistic_symbols=True,
)
else:
- return WordTokenizer(delimiter=delimiter)
+ return WordTokenizer(delimiter=delimiter, **kwargs)
elif token_type == "char":
return CharTokenizer(
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+ **kwargs
)
elif token_type == "phn":
@@ -53,6 +65,7 @@
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+ **kwargs
)
else:
diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
index 6c9a5a5..80528a2 100644
--- a/funasr/tokenizer/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -6,15 +6,17 @@
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import BaseTokenizer
-
-class CharTokenizer(AbsTokenizer):
+class CharTokenizer(BaseTokenizer):
def __init__(
self,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
+ **kwargs,
):
+ super().__init__(**kwargs)
self.space_symbol = space_symbol
if non_linguistic_symbols is None:
self.non_linguistic_symbols = set()
diff --git a/funasr/tokenizer/funtoken.py b/funasr/tokenizer/funtoken.py
new file mode 100644
index 0000000..7187d85
--- /dev/null
+++ b/funasr/tokenizer/funtoken.py
@@ -0,0 +1,75 @@
+from pathlib import Path
+from typing import Iterable
+from typing import Union
+from abc import ABC
+from abc import abstractmethod
+from typing import Iterable
+from typing import List
+from pathlib import Path
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.char_tokenizer import CharTokenizer
+from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
+from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
+from funasr.tokenizer.word_tokenizer import WordTokenizer
+
+def build_tokenizer(
+ token_type: str,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ remove_non_linguistic_symbols: bool = False,
+ space_symbol: str = "<space>",
+ delimiter: str = None,
+ g2p_type: str = None,
+ **kwargs,
+):
+ """A helper function to instantiate Tokenizer"""
+ # import pdb;
+ # pdb.set_trace()
+ if token_type == "bpe":
+ if bpemodel is None:
+ raise ValueError('bpemodel is required if token_type = "bpe"')
+
+ if remove_non_linguistic_symbols:
+ raise RuntimeError(
+ "remove_non_linguistic_symbols is not implemented for token_type=bpe"
+ )
+ return SentencepiecesTokenizer(bpemodel, **kwargs)
+
+ elif token_type == "word":
+ if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
+ return WordTokenizer(
+ delimiter=delimiter,
+ non_linguistic_symbols=non_linguistic_symbols,
+ remove_non_linguistic_symbols=True,
+ )
+ else:
+ return WordTokenizer(delimiter=delimiter, **kwargs)
+
+ elif token_type == "char":
+ return CharTokenizer(
+ non_linguistic_symbols=non_linguistic_symbols,
+ space_symbol=space_symbol,
+ remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+ **kwargs
+ )
+
+ elif token_type == "phn":
+ return PhonemeTokenizer(
+ g2p_type=g2p_type,
+ non_linguistic_symbols=non_linguistic_symbols,
+ space_symbol=space_symbol,
+ remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+ **kwargs
+ )
+
+ else:
+ raise ValueError(
+ f"token_mode must be one of bpe, word, char or phn: " f"{token_type}"
+ )
diff --git a/funasr/tokenizer/phoneme_tokenizer.py b/funasr/tokenizer/phoneme_tokenizer.py
index 0117c6a..04b423b 100644
--- a/funasr/tokenizer/phoneme_tokenizer.py
+++ b/funasr/tokenizer/phoneme_tokenizer.py
@@ -363,6 +363,7 @@
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
+ **kwargs,
):
if g2p_type is None:
self.g2p = split_by_space
diff --git a/funasr/tokenizer/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
index 9a65920..df98c2c 100644
--- a/funasr/tokenizer/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -9,7 +9,7 @@
class SentencepiecesTokenizer(AbsTokenizer):
- def __init__(self, model: Union[Path, str]):
+ def __init__(self, model: Union[Path, str], **kwargs):
self.model = str(model)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
diff --git a/funasr/tokenizer/word_tokenizer.py b/funasr/tokenizer/word_tokenizer.py
index cbd0673..d7bbaf9 100644
--- a/funasr/tokenizer/word_tokenizer.py
+++ b/funasr/tokenizer/word_tokenizer.py
@@ -14,6 +14,7 @@
delimiter: str = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
+ **kwargs,
):
self.delimiter = delimiter
diff --git a/funasr/utils/dynamic_import.py b/funasr/utils/dynamic_import.py
new file mode 100644
index 0000000..2830cb2
--- /dev/null
+++ b/funasr/utils/dynamic_import.py
@@ -0,0 +1,13 @@
+import importlib
+
+
+def dynamic_import(import_path):
+ """dynamic import module and class
+
+ :param str import_path: syntax 'module_name:class_name'
+ :return: imported class
+ """
+
+ module_name, objname = import_path.split(":")
+ m = importlib.import_module(module_name)
+ return getattr(m, objname)
diff --git a/funasr/utils/load_fr_py.py b/funasr/utils/load_fr_py.py
new file mode 100644
index 0000000..6697e04
--- /dev/null
+++ b/funasr/utils/load_fr_py.py
@@ -0,0 +1,13 @@
+import importlib.util
+import sys
+
+def load_class_from_path(model_path):
+ path, class_name = model_path
+ # import pdb;
+ # pdb.set_trace()
+ spec = importlib.util.spec_from_file_location("module.name", path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+ return getattr(module, class_name)
+
--
Gitblit v1.9.1