From 172e7ac986f299ad545cbd91a8cecc3ef967af36 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Dec 2023 10:17:22 +0800
Subject: [PATCH] Revert "Dev gzf funasr2" (#1164)
---
 funasr/modules/nets_utils.py                   |  2
 funasr/models/e2e_asr.py                       |  1
 funasr/datasets/small_datasets/preprocessor.py |  1
 funasr/models/e2e_asr_paraformer.py            |  7 -
 funasr/tokenizer/char_tokenizer.py             |  6 -
 funasr/bin/asr_trainer.py                      |  0
 funasr/datasets/data_sampler.py                | 16 ++--
 /dev/null                                      | 13 ---
 funasr/tokenizer/sentencepiece_tokenizer.py    |  2
 funasr/models/e2e_asr_contextual_paraformer.py |  1
 funasr/tokenizer/abs_tokenizer.py              | 74 ------------------
 funasr/schedulers/__init__.py                  | 23 -----
 funasr/tokenizer/phoneme_tokenizer.py          |  1
 funasr/models/e2e_uni_asr.py                   |  1
 funasr/optimizers/__init__.py                  | 17 ----
 funasr/tokenizer/build_tokenizer.py            | 19 ----
 funasr/datasets/dataloader_fn.py               |  5 +
 funasr/datasets/dataset_jsonl.py               | 18 +--
 funasr/tokenizer/word_tokenizer.py             |  1
 19 files changed, 27 insertions(+), 181 deletions(-)
diff --git a/funasr/cli/__init__.py b/funasr/bin/asr_trainer.py
similarity index 100%
rename from funasr/cli/__init__.py
rename to funasr/bin/asr_trainer.py
diff --git a/funasr/cli/model_class_factory.py b/funasr/cli/model_class_factory.py
deleted file mode 100644
index b329492..0000000
--- a/funasr/cli/model_class_factory.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import argparse
-import logging
-import os
-from pathlib import Path
-from typing import Callable
-from typing import Collection
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-import yaml
-
-from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
-from funasr.models.decoder.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolutionTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
-from funasr.models.e2e_asr import ASRModel
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.models.e2e_tp import TimestampPredictor
-from funasr.models.e2e_asr_mfcca import MFCCA
-from funasr.models.e2e_sa_asr import SAASRModel
-from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
-from funasr.models.e2e_asr_bat import BATModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.default import MultiChannelFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
- HuggingFaceTransformersPostEncoder, # noqa: H301
-)
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.modules.subsampling import Conv1dSubsampling
-from funasr.tasks.abs_task import AbsTask
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
-from funasr.torch_utils.initialize import initialize
-from funasr.models.base_model import FunASRModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-# from funasr.models.paraformer import Paraformer
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- s3prl=S3prlFrontend,
- fused=FusedFrontends,
- wav_frontend=WavFrontend,
- multichannelfrontend=MultiChannelFrontend,
- ),
- type_check=AbsFrontend,
- default="default",
-)
-specaug_choices = ClassChoices(
- name="specaug",
- classes=dict(
- specaug=SpecAug,
- specaug_lfr=SpecAugLFR,
- ),
- type_check=AbsSpecAug,
- default=None,
- optional=True,
-)
-# specaug_choices = {"specaug":SpecAug}
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- type_check=AbsNormalize,
- default=None,
- optional=True,
-)
-# model_choices = ClassChoices(
-# "model",
-# classes=dict(
-# asr=ASRModel,
-# uniasr=UniASR,
-# paraformer=Paraformer,
-# paraformer_online=ParaformerOnline,
-# paraformer_bert=ParaformerBert,
-# bicif_paraformer=BiCifParaformer,
-# contextual_paraformer=ContextualParaformer,
-# neatcontextual_paraformer=NeatContextualParaformer,
-# mfcca=MFCCA,
-# timestamp_prediction=TimestampPredictor,
-# rnnt=TransducerModel,
-# rnnt_unified=UnifiedTransducerModel,
-# bat=BATModel,
-# sa_asr=SAASRModel,
-# ),
-# type_check=None,
-# default="asr",
-# )
-preencoder_choices = ClassChoices(
- name="preencoder",
- classes=dict(
- sinc=LightweightSincConvs,
- linear=LinearProjection,
- ),
- type_check=AbsPreEncoder,
- default=None,
- optional=True,
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- mfcca_enc=MFCCAEncoder,
- chunk_conformer=ConformerChunkEncoder,
- ),
- type_check=AbsEncoder,
- default="rnn",
-)
-encoder_choices2 = ClassChoices(
- "encoder2",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- ),
- type_check=AbsEncoder,
- default="rnn",
-)
-asr_encoder_choices = ClassChoices(
- "asr_encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- mfcca_enc=MFCCAEncoder,
- ),
- type_check=AbsEncoder,
- default="rnn",
-)
-spk_encoder_choices = ClassChoices(
- "spk_encoder",
- classes=dict(
- resnet34_diar=ResNet34Diar,
- ),
- default="resnet34_diar",
-)
-postencoder_choices = ClassChoices(
- name="postencoder",
- classes=dict(
- hugging_face_transformers=HuggingFaceTransformersPostEncoder,
- ),
- type_check=AbsPostEncoder,
- default=None,
- optional=True,
-)
-decoder_choices = ClassChoices(
- "decoder",
- classes=dict(
- transformer=TransformerDecoder,
- lightweight_conv=LightweightConvolutionTransformerDecoder,
- lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
- dynamic_conv=DynamicConvolutionTransformerDecoder,
- dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
- rnn=RNNDecoder,
- fsmn_scama_opt=FsmnDecoderSCAMAOpt,
- paraformer_decoder_sanm=ParaformerSANMDecoder,
- paraformer_decoder_san=ParaformerDecoderSAN,
- contextual_paraformer_decoder=ContextualParaformerDecoder,
- sa_decoder=SAAsrTransformerDecoder,
- ),
- type_check=AbsDecoder,
- default="rnn",
-)
-decoder_choices2 = ClassChoices(
- "decoder2",
- classes=dict(
- transformer=TransformerDecoder,
- lightweight_conv=LightweightConvolutionTransformerDecoder,
- lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
- dynamic_conv=DynamicConvolutionTransformerDecoder,
- dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
- rnn=RNNDecoder,
- fsmn_scama_opt=FsmnDecoderSCAMAOpt,
- paraformer_decoder_sanm=ParaformerSANMDecoder,
- ),
- type_check=AbsDecoder,
- default="rnn",
-)
-
-rnnt_decoder_choices = ClassChoices(
- "rnnt_decoder",
- classes=dict(
- rnnt=RNNTDecoder,
- ),
- type_check=RNNTDecoder,
- default="rnnt",
-)
-
-joint_network_choices = ClassChoices(
- name="joint_network",
- classes=dict(
- joint_network=JointNetwork,
- ),
- default="joint_network",
- optional=True,
-)
-
-predictor_choices = ClassChoices(
- name="predictor",
- classes=dict(
- cif_predictor=CifPredictor,
- ctc_predictor=None,
- cif_predictor_v2=CifPredictorV2,
- cif_predictor_v3=CifPredictorV3,
- bat_predictor=BATPredictor,
- ),
- type_check=None,
- default="cif_predictor",
- optional=True,
-)
-predictor_choices2 = ClassChoices(
- name="predictor2",
- classes=dict(
- cif_predictor=CifPredictor,
- ctc_predictor=None,
- cif_predictor_v2=CifPredictorV2,
- ),
- type_check=None,
- default="cif_predictor",
- optional=True,
-)
-stride_conv_choices = ClassChoices(
- name="stride_conv",
- classes=dict(
- stride_conv1d=Conv1dSubsampling
- ),
- type_check=None,
- default="stride_conv1d",
- optional=True,
-)
\ No newline at end of file
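
The registries above are instances of funasr's ClassChoices pattern: each one maps a
config string (e.g. encoder: "sanm") to a class, which is then instantiated from the
matching *_conf dict. A minimal sketch of the lookup, using a simplified stand-in
rather than funasr's actual ClassChoices implementation:

    # Simplified stand-in for funasr.train.class_choices.ClassChoices (illustrative only).
    from typing import Dict, Optional

    class ClassChoices:
        def __init__(self, name: str, classes: Dict[str, Optional[type]],
                     type_check: Optional[type] = None, default: Optional[str] = None,
                     optional: bool = False):
            self.name, self.classes = name, classes
            self.type_check, self.default, self.optional = type_check, default, optional

        def get_class(self, key: Optional[str]) -> Optional[type]:
            # Map a config string to its registered class, e.g. "conformer" -> ConformerEncoder.
            if key is None and self.optional:
                return None
            if key not in self.classes:
                raise ValueError(f"unknown {self.name}: {key!r}, expected one of {sorted(self.classes)}")
            cls = self.classes[key]
            if cls is not None and self.type_check is not None:
                assert issubclass(cls, self.type_check), f"{key} must subclass {self.type_check}"
            return cls

With this pattern, a YAML entry such as encoder: conformer plus encoder_conf becomes
encoder_choices.get_class("conformer")(input_size=input_size, **encoder_conf), which is
exactly how Paraformer.__init__ in funasr/cli/models/paraformer.py below builds its
submodules.
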
diff --git a/funasr/cli/models/__init__.py b/funasr/cli/models/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/cli/models/__init__.py
+++ /dev/null
diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py
deleted file mode 100644
index ee8c0b4..0000000
--- a/funasr/cli/models/paraformer.py
+++ /dev/null
@@ -1,652 +0,0 @@
-import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.nets_utils import make_pad_mask, pad_list
-from funasr.modules.nets_utils import th_accuracy
-from funasr.torch_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-
-from funasr.cli.model_class_factory import *
-
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-
-
-class Paraformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- # token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[str] = None,
- frontend_conf: Optional[Dict] = None,
- specaug: Optional[str] = None,
- specaug_conf: Optional[Dict] = None,
- normalize: str = None,
- normalize_conf: Optional[Dict] = None,
- encoder: str = None,
- encoder_conf: Optional[Dict] = None,
- decoder: str = None,
- decoder_conf: Optional[Dict] = None,
- ctc: str = None,
- ctc_conf: Optional[Dict] = None,
- predictor: str = None,
- predictor_conf: Optional[Dict] = None,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- # report_cer: bool = True,
- # report_wer: bool = True,
- # sym_space: str = "<space>",
- # sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- # predictor=None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- use_1st_decoder_loss: bool = False,
- **kwargs,
- ):
- assert 0.0 <= ctc_weight <= 1.0, ctc_weight
- assert 0.0 <= interctc_weight < 1.0, interctc_weight
-
- super().__init__()
-
- # import pdb;
- # pdb.set_trace()
-
- if frontend is not None:
- frontend_class = frontend_choices.get_class(frontend)
- frontend = frontend_class(**frontend_conf)
- if specaug is not None:
- specaug_class = specaug_choices.get_class(specaug)
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = normalize_choices.get_class(normalize)
- normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_choices.get_class(encoder)
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
- if decoder is not None:
- decoder_class = decoder_choices.get_class(decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
- if ctc_weight > 0.0:
-
- if ctc_conf is None:
- ctc_conf = {}
-
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
- )
- if predictor is not None:
- predictor_class = predictor_choices.get_class(predictor)
- predictor = predictor_class(**predictor_conf)
-
- # note that eos is the same as sos (equivalent ID)
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.ctc_weight = ctc_weight
- self.interctc_weight = interctc_weight
- # self.token_list = token_list.copy()
- #
- self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
- # self.preencoder = preencoder
- # self.postencoder = postencoder
- self.encoder = encoder
- #
- # if not hasattr(self.encoder, "interctc_use_conditioning"):
- # self.encoder.interctc_use_conditioning = False
- # if self.encoder.interctc_use_conditioning:
- # self.encoder.conditioning_layer = torch.nn.Linear(
- # vocab_size, self.encoder.output_size()
- # )
- #
-        self.error_calculator = None  # keep defined; read later in _calc_att_loss and _calc_ctc_loss
- #
- if ctc_weight == 1.0:
- self.decoder = None
- else:
- self.decoder = decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
- if ctc_weight == 0.0:
- self.ctc = None
- else:
- self.ctc = ctc
- #
- # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
- self.predictor = predictor
- self.predictor_weight = predictor_weight
- self.predictor_bias = predictor_bias
- self.sampling_ratio = sampling_ratio
- self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
- # self.step_cur = 0
- #
- self.share_embedding = share_embedding
- if self.share_embedding:
- self.decoder.embed = None
-
- self.use_1st_decoder_loss = use_1st_decoder_loss
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- decoding_ind: int
- """
-        decoding_ind = kwargs.get("decoding_ind", None)
- # import pdb;
- # pdb.set_trace()
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # # for data-parallel
- # text = text[:, : text_lengths.max()]
- # speech = speech[:, :speech_lengths.max()]
-
- # 1. Encoder
- if hasattr(self.encoder, "overlap_chunk_cls"):
- ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
- loss_ctc, cer_ctc = None, None
- loss_pre = None
- stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
-                # Collect intermediate CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # 2b. Attention decoder branch
- if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
- if self.use_1st_decoder_loss and pre_loss_att is not None:
- loss = loss + (1 - self.ctc_weight) * pre_loss_att
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- ) -> Dict[str, torch.Tensor]:
-        # extract_feats_in_collect_stats is never set in __init__ above, so guard the lookup
-        if getattr(self, "extract_feats_in_collect_stats", True):
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
- else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
- feats, feats_lengths = speech, speech_lengths
- return {"feats": feats, "feats_lengths": feats_lengths}
-
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
-        with autocast(False):
-            # # 1. Extract feats
-            # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-            # feats are assumed to be pre-extracted here, so start from the raw input
-            feats, feats_lengths = speech, speech_lengths
-
-            # 2. Data augmentation
-            if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # # Pre-encoder, e.g. used for raw input data
- # if self.preencoder is not None:
- # feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
- # 4. Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
-            if getattr(self.encoder, "interctc_use_conditioning", False):
- if hasattr(self.encoder, "overlap_chunk_cls"):
- encoder_out, encoder_out_lens, _ = self.encoder(
- feats, feats_lengths, ctc=self.ctc, ind=ind
- )
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(
- feats, feats_lengths, ctc=self.ctc
- )
- else:
- if hasattr(self.encoder, "overlap_chunk_cls"):
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- # # Post-encoder, e.g. NLU
- # if self.postencoder is not None:
- # encoder_out, encoder_out_lens = self.postencoder(
- # encoder_out, encoder_out_lens
- # )
-
- assert encoder_out.size(0) == speech.size(0), (
- encoder_out.size(),
- speech.size(0),
- )
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
- encoder_out.size(),
- encoder_out_lens.max(),
- )
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
-
- return encoder_out, encoder_out_lens
-
- def calc_predictor(self, encoder_out, encoder_out_lens):
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
-
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
- if self.frontend is not None:
- # Frontend
- # e.g. STFT and Feature extract
- # data_loader may send time-domain signal in this case
- # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- # No frontend and no feature extract
- feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
-
- def nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ) -> torch.Tensor:
- """Compute negative log likelihood(nll) from transformer-decoder
- Normally, this function is called in batchify_nll.
- Args:
- encoder_out: (Batch, Length, Dim)
- encoder_out_lens: (Batch,)
- ys_pad: (Batch, Length)
- ys_pad_lens: (Batch,)
- """
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- ) # [batch, seqlen, dim]
- batch_size = decoder_out.size(0)
- decoder_num_class = decoder_out.size(2)
- # nll: negative log-likelihood
- nll = torch.nn.functional.cross_entropy(
- decoder_out.view(-1, decoder_num_class),
- ys_out_pad.view(-1),
- ignore_index=self.ignore_id,
- reduction="none",
- )
- nll = nll.view(batch_size, -1)
- nll = nll.sum(dim=1)
- assert nll.size(0) == batch_size
- return nll
-
- def batchify_nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- batch_size: int = 100,
- ):
- """Compute negative log likelihood(nll) from transformer-decoder
- To avoid OOM, this fuction seperate the input into batches.
- Then call nll for each batch and combine and return results.
- Args:
- encoder_out: (Batch, Length, Dim)
- encoder_out_lens: (Batch,)
- ys_pad: (Batch, Length)
- ys_pad_lens: (Batch,)
-            batch_size: int, number of samples per batch when computing nll;
-                        lower it to avoid OOM, or raise it to make fuller
-                        use of GPU memory
- """
- total_num = encoder_out.size(0)
- if total_num <= batch_size:
- nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
- else:
- nll = []
- start_idx = 0
- while True:
- end_idx = min(start_idx + batch_size, total_num)
- batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
- batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
- batch_ys_pad = ys_pad[start_idx:end_idx, :]
- batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
- batch_nll = self.nll(
- batch_encoder_out,
- batch_encoder_out_lens,
- batch_ys_pad,
- batch_ys_pad_lens,
- )
- nll.append(batch_nll)
- start_idx = end_idx
- if start_idx == total_num:
- break
- nll = torch.cat(nll)
- assert nll.size(0) == total_num
- return nll
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
-
- # 0. sampler
- decoder_out_1st = None
- pre_loss_att = None
- if self.sampling_ratio > 0.0:
-
-
- if self.use_1st_decoder_loss:
- sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
-            # step_cur is never set in __init__ above, so default to 0 for this log guard
-            if getattr(self, "step_cur", 0) < 2:
-                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
-
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
- def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
- )
- pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
-
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- # Calc CTC loss
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
- # Calc CER using CTC
- cer_ctc = None
- if not self.training and self.error_calculator is not None:
- ys_hat = self.ctc.argmax(encoder_out).data
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
- return loss_ctc, cer_ctc
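
The three-way loss combination in forward() above reduces to a simple interpolation.
A tiny numeric sketch with dummy scalar losses and illustrative weights (not values
taken from any funasr recipe):

    import torch

    ctc_weight, predictor_weight = 0.3, 1.0   # illustrative weights
    loss_ctc = torch.tensor(2.0)              # CTC branch
    loss_att = torch.tensor(1.5)              # attention-decoder branch
    loss_pre = torch.tensor(0.1)              # CIF predictor MAE on token counts

    # same three-way rule as Paraformer.forward
    if ctc_weight == 0.0:
        loss = loss_att + loss_pre * predictor_weight
    elif ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att + loss_pre * predictor_weight

    print(loss)  # tensor(1.7500) = 0.3*2.0 + 0.7*1.5 + 1.0*0.1
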
diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
deleted file mode 100644
index 54cd2e8..0000000
--- a/funasr/cli/train_cli.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import argparse
-import logging
-import os
-import sys
-from io import BytesIO
-from collections.abc import Sequence
-import torch
-import hydra
-from omegaconf import DictConfig, OmegaConf
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-# from funasr.model_class_factory1 import model_choices
-from funasr.modules.lora.utils import mark_only_lora_as_trainable
-from funasr.optimizers import optim_choices
-from funasr.schedulers import scheduler_choices
-from funasr.torch_utils.load_pretrained_model import load_pretrained_model
-from funasr.torch_utils.initialize import initialize
-from funasr.datasets.data_sampler import BatchSampler
-# from funasr.tokenizer.build_tokenizer import build_tokenizer
-# from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.dataset_jsonl import AudioDataset
-from funasr.cli.trainer import Trainer
-# from funasr.utils.load_fr_py import load_class_from_path
-from funasr.utils.dynamic_import import dynamic_import
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-
-def preprocess_config(cfg: DictConfig):
- for key, value in cfg.items():
- if value == 'None':
- cfg[key] = None
-
-
-
-@hydra.main()
-def main(kwargs: DictConfig):
- # preprocess_config(kwargs)
- # import pdb; pdb.set_trace()
- # set random seed
- set_all_random_seed(kwargs.get("seed", 0))
- torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
- torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
- torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-
- local_rank = int(os.environ.get('LOCAL_RANK', 0))
- # Check if we are using DDP or FSDP
- use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
- use_fsdp = kwargs.get("use_fsdp", None)
- if use_ddp or use_fsdp:
- dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
- torch.cuda.set_device(local_rank)
-
-
- # build_tokenizer
- tokenizer = build_tokenizer(
- token_type=kwargs.get("token_type", "char"),
- bpemodel=kwargs.get("bpemodel", None),
- delimiter=kwargs.get("delimiter", None),
- space_symbol=kwargs.get("space_symbol", "<space>"),
- non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
- g2p_type=kwargs.get("g2p_type", None),
- token_list=kwargs.get("token_list", None),
- unk_symbol=kwargs.get("unk_symbol", "<unk>"),
- )
-
- # import pdb;
- # pdb.set_trace()
- # build model
- # model_class = model_choices.get_class(kwargs.get("model", "asr"))
- # model_class = load_class_from_path(kwargs.get("model").split(":"))
- model_class = dynamic_import(kwargs.get("model"))
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- frontend = model.frontend
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- init_param = eval(init_param)
-        if not isinstance(init_param, (list, tuple)):
-            # a single path: wrap it so the loop below iterates over paths, not characters
-            init_param = (init_param,)
-        logging.info("init_param is not None: %s", init_param)
- for p in init_param:
- logging.info(f"Loading pretrained params from {p}")
- load_pretrained_model(
- model=model,
- init_param=p,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
- oss_bucket=kwargs.get("oss_bucket", None),
- )
- else:
- initialize(model, kwargs.get("init", "kaiming_normal"))
-
- # import pdb;
- # pdb.set_trace()
- # freeze_param
- freeze_param = kwargs.get("freeze_param", None)
- if freeze_param is not None:
- freeze_param = eval(freeze_param)
-        if not isinstance(freeze_param, (list, tuple)):
-            # a single name: wrap it so the loop below iterates over names, not characters
-            freeze_param = (freeze_param,)
-        logging.info("freeze_param is not None: %s", freeze_param)
- for t in freeze_param:
- for k, p in model.named_parameters():
- if k.startswith(t + ".") or k == t:
- logging.info(f"Setting {k}.requires_grad = False")
- p.requires_grad = False
-
-
- if use_ddp:
- model = model.cuda(local_rank)
- model = DDP(model, device_ids=[local_rank],
- find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
- elif use_fsdp:
- model = FSDP(model).cuda(local_rank)
- else:
- model = model.to(device=kwargs.get("device", "cuda"))
-
-
- # optim
- optim = kwargs.get("optim", "adam")
- assert optim in optim_choices
- optim_class = optim_choices.get(optim)
- optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-
- # scheduler
- scheduler = kwargs.get("scheduler", "warmuplr")
- assert scheduler in scheduler_choices
- scheduler_class = scheduler_choices.get(scheduler)
- scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
-
-
- # dataset
- dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
-
- # dataloader
- batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
- dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
- collate_fn=dataset_tr.collator,
- batch_sampler=batch_sampler,
- num_workers=kwargs.get("num_workers", 0),
- pin_memory=True)
-
- trainer = Trainer(
- model=model,
- optim=optim,
- scheduler=scheduler,
- dataloader_train=dataloader_tr,
- dataloader_val=None,
- local_rank=local_rank,
- use_ddp=use_ddp,
- use_fsdp=use_fsdp,
- **kwargs.get("train_conf"),
- )
- trainer.run()
-
- if use_ddp or use_fsdp:
- torch.distributed.destroy_process_group()
-
-
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
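
The distributed bootstrapping in main() above follows the standard torchrun contract
(torchrun exports LOCAL_RANK and WORLD_SIZE, plus MASTER_ADDR/MASTER_PORT for the
env:// rendezvous). A condensed sketch of just that setup, assuming an NCCL backend:

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_for_distributed(model: torch.nn.Module) -> torch.nn.Module:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        use_ddp = int(os.environ.get("WORLD_SIZE", "1")) > 1
        if use_ddp:
            # env:// reads MASTER_ADDR/MASTER_PORT set by torchrun
            dist.init_process_group(backend="nccl", init_method="env://")
            torch.cuda.set_device(local_rank)
            model = DDP(model.cuda(local_rank), device_ids=[local_rank])
        return model
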
diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
deleted file mode 100644
index 28a843b..0000000
--- a/funasr/cli/trainer.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import torch
-import os
-from funasr.torch_utils.device_funcs import to_device
-import logging
-from tqdm import tqdm
-from contextlib import nullcontext
-import torch.distributed as dist
-from funasr.torch_utils.recursive_op import recursive_average
-
-class Trainer:
- """
- A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
- and optionally resuming from a saved checkpoint.
-
- Attributes:
- max_epoch (int): Maximum number of epochs for training.
- model (torch.nn.Module): The model to be trained.
- optim (torch.optim.Optimizer): The optimizer to use for training.
- scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
- dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
- dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
- output_dir (str): Directory where model checkpoints will be saved.
- resume (str, optional): Path to a checkpoint to resume training from.
- """
-
- def __init__(self, model,
- optim,
- scheduler,
- dataloader_train,
- dataloader_val,
- local_rank,
- use_ddp=False,
- use_fsdp=False,
- **kwargs):
- """
- Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
-
- Args:
- model (torch.nn.Module): The model to be trained.
- optim (torch.optim.Optimizer): The optimizer to use for training.
- scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
- dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
- dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
- **kwargs: Additional keyword arguments:
- max_epoch (int): The maximum number of epochs for training.
- output_dir (str): The directory where model checkpoints will be saved. Default is './'.
- resume (str, optional): The file path to a checkpoint to resume training from.
- """
-
- self.model = model
- self.optim = optim
- self.scheduler = scheduler
- self.dataloader_train = dataloader_train
- self.dataloader_val = dataloader_val
- self.output_dir = kwargs.get('output_dir', './')
- self.resume = kwargs.get('resume', None)
- self.start_epoch = 1
- self.max_epoch = kwargs.get('max_epoch', 100)
- self.local_rank = local_rank
-        self.rank = dist.get_rank() if (use_ddp or use_fsdp) else 0
-        self.world_size = dist.get_world_size() if (use_ddp or use_fsdp) else 1
- self.use_ddp = use_ddp
- self.use_fsdp = use_fsdp
- self.device = torch.device("cuda", local_rank)
- self.kwargs = kwargs
-
- if self.resume:
- self._resume_checkpoint(self.resume)
-
- def _save_checkpoint(self, epoch):
- """
- Saves a checkpoint containing the model's state, the optimizer's state,
- and the scheduler's state at the end of the given epoch. This method is
- intended to be called at the end of each epoch to save the training progress.
-
- Args:
- epoch (int): The epoch number at which the checkpoint is being saved.
- """
- state = {
- 'epoch': epoch,
- 'state_dict': self.model.state_dict(),
- 'optimizer': self.optim.state_dict(),
- 'scheduler': self.scheduler.state_dict(),
- }
- # Create output directory if it does not exist
- os.makedirs(self.output_dir, exist_ok=True)
- filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
- torch.save(state, filename)
- print(f'Checkpoint saved to {filename}')
-
- def _resume_checkpoint(self, resume_path):
- """
- Resumes training from a checkpoint at the given file path.
- Loads the model's state, the optimizer's state, and the scheduler's state.
-
- Args:
- resume_path (str): The file path to the checkpoint to resume from.
- """
- if os.path.isfile(resume_path):
- checkpoint = torch.load(resume_path)
- self.start_epoch = checkpoint['epoch'] + 1
- self.model.load_state_dict(checkpoint['state_dict'])
- self.optim.load_state_dict(checkpoint['optimizer'])
- self.scheduler.load_state_dict(checkpoint['scheduler'])
- print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
- else:
- print(f"No checkpoint found at '{resume_path}', starting from scratch")
-
- def run(self):
- """
- Starts the training process, iterating over epochs, training the model,
- and saving checkpoints at the end of each epoch.
- """
- for epoch in range(self.start_epoch, self.max_epoch + 1):
- self._train_epoch(epoch)
- # self._validate_epoch(epoch)
-            if self.rank == 0:
- self._save_checkpoint(epoch)
- self.scheduler.step()
-
- def _train_epoch(self, epoch):
- """
- Defines the training process for a single epoch with gradient accumulation.
- Args:
- epoch (int): The current epoch number.
- """
- self.model.train()
- pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
- dynamic_ncols=True)
-
- # Set the number of steps for gradient accumulation
- accum_grad = self.kwargs.get("accum_grad", 1)
- # Initialize the gradient accumulation
- self.optim.zero_grad()
-
- for batch_idx, batch in enumerate(self.dataloader_train):
- batch = to_device(batch, self.device)
-
-            # skip gradient sync except on the batches where optim.step() runs below
-            my_context = self.model.no_sync if (self.use_ddp or self.use_fsdp) and (batch_idx + 1) % accum_grad != 0 else nullcontext
- with my_context():
- retval = self.model(**batch)
- loss, stats, weight = retval
- stats = {k: v for k, v in stats.items() if v is not None}
- if self.use_ddp or self.use_fsdp:
- # Apply weighted averaging for loss and stats
- loss = (loss * weight.type(loss.dtype)).sum()
- # if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed=True)
- # Now weight is summation over all workers
- loss /= weight
- # Multiply world_size because DistributedDataParallel
- # automatically normalizes the gradient by world_size.
- loss *= self.world_size
- # Scale the loss since we're not updating for every mini-batch
- loss = loss / accum_grad
- loss.backward()
-
- # Perform an optimizer step only after accumulating enough gradients
- if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
- # Perform gradient clipping if it is set
- if self.kwargs.get("grad_clip", None) is not None:
- grad_norm = torch.nn.utils.clip_grad_norm_(
- self.model.parameters(),
- max_norm=self.kwargs.get("grad_clip", 10.0),
- norm_type=self.kwargs.get("grad_clip_type", 2.0),
- )
- if not torch.isfinite(grad_norm):
- logging.warning(
- f"The grad norm is {grad_norm}. Skipping updating the model."
- )
- self.optim.zero_grad() # Reset gradients
- continue
-
- # Execute an optimization step (update model parameters)
- self.optim.step()
- self.scheduler.step()
- # Clear gradients for the next accumulation stage
- self.optim.zero_grad()
-
- pbar.update(1)
- if self.local_rank == 0:
- pbar.set_description(
- f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float():.3f}, {[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]})")
-
- pbar.close()
-
- def _validate_epoch(self, epoch):
- """
- Defines the validation process for a single epoch.
- Should be implemented with the actual model validation steps.
-
- Args:
- epoch (int): The current epoch number.
- """
- self.model.eval()
- with torch.no_grad():
- for data, target in self.dataloader_val:
- # Implement the model validation steps here
- pass
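
The core of _train_epoch above is a standard gradient-accumulation loop. A condensed
sketch of that loop alone, dropping the DDP loss weighting, clipping, and the progress
bar (model is assumed to return the same (loss, stats, weight) triple):

    def train_epoch(model, optim, scheduler, dataloader, accum_grad: int = 4):
        model.train()
        optim.zero_grad()
        for batch_idx, batch in enumerate(dataloader):
            loss, stats, weight = model(**batch)
            # scale so gradients summed over accum_grad batches match one large batch
            (loss / accum_grad).backward()
            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(dataloader):
                optim.step()        # update once per accum_grad mini-batches
                scheduler.step()
                optim.zero_grad()
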
diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py
index 3a19a17..c8e7b0d 100644
--- a/funasr/datasets/data_sampler.py
+++ b/funasr/datasets/data_sampler.py
@@ -4,17 +4,17 @@
class BatchSampler(torch.utils.data.BatchSampler):
- def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
+ def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
- # self.batch_type = args.batch_type
+ # self.batch_size_type = args.batch_size_type
# self.batch_size = args.batch_size
# self.sort_size = args.sort_size
# self.max_length_token = args.max_length_token
- self.batch_type = batch_type
+ self.batch_size_type = batch_size_type
self.batch_size = batch_size
self.sort_size = sort_size
self.max_length_token = kwargs.get("max_length_token", 5000)
@@ -26,7 +26,7 @@
return self.total_samples
def __iter__(self):
- # print("in sampler")
+ print("in sampler")
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
@@ -36,7 +36,7 @@
num_sample = 0
iter_num = (self.total_samples-1) // self.sort_size + 1
- # print("iter_num: ", iter_num)
+ print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
for i in range(self.sort_size):
@@ -46,8 +46,8 @@
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \
- self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map])
+ sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
+ self.dataset.indexed_dataset[idx_map]["target_len"]
datalen_with_index.append([idx, sample_len_cur])
@@ -59,7 +59,7 @@
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
- if self.batch_type == 'token':
+ if self.batch_size_type == 'token':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
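
The rule in __iter__ above is length-aware dynamic batching: with
batch_size_type == "token", a sample joins the current batch only while
(len(batch) + 1) * longest_sample stays within batch_size, i.e. while the
zero-padded batch still fits the token budget. A standalone sketch of that rule
(batch_by_token_budget is a hypothetical helper, not part of funasr):

    def batch_by_token_budget(lengths, budget):
        batches, batch, max_len = [], [], 0
        for idx, n in enumerate(lengths):
            max_len_cur = max(max_len, n)
            if (len(batch) + 1) * max_len_cur <= budget:  # padded size if idx joins
                batch.append(idx)
                max_len = max_len_cur
            else:
                if batch:
                    batches.append(batch)
                batch, max_len = [idx], n
        if batch:
            batches.append(batch)
        return batches

    print(batch_by_token_budget([10, 12, 30, 5, 5], budget=60))  # [[0, 1], [2, 3], [4]]
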
diff --git a/funasr/datasets/dataloader_fn.py b/funasr/datasets/dataloader_fn.py
index b0ecf4f..a43c947 100644
--- a/funasr/datasets/dataloader_fn.py
+++ b/funasr/datasets/dataloader_fn.py
@@ -38,13 +38,16 @@
batch_sampler = BatchSampler(dataset)
+def collator(samples: list = None):
+ return samples
+
if __name__ == "__main__":
dataloader_tr = torch.utils.data.DataLoader(dataset,
collate_fn=dataset.collator,
batch_sampler=batch_sampler,
shuffle=False,
- num_workers=0,
+ num_workers=8,
pin_memory=True)
print(len(dataset))
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
index eef67c5..543b60e 100644
--- a/funasr/datasets/dataset_jsonl.py
+++ b/funasr/datasets/dataset_jsonl.py
@@ -78,26 +78,21 @@
def __getitem__(self, index):
return self.contents[index]
-
- def get_source_len(self, data_dict):
- return data_dict["source_len"]
-
- def get_target_len(self, data_dict):
-
- return data_dict["target_len"] if "target_len" in data_dict else 0
class AudioDataset(torch.utils.data.Dataset):
- def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs):
+ def __init__(self, path, frontend=None, tokenizer=None, token_id_converter=None):
+
super().__init__()
self.indexed_dataset = IndexedDatasetJsonl(path)
    self.frontend = None if frontend is None else frontend.forward
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
self.tokenizer = tokenizer
+ self.token_id_converter = token_id_converter
- self.int_pad_value = int_pad_value
- self.float_pad_value = float_pad_value
+ self.int_pad_value = -1
+ self.float_pad_value = 0.0
@@ -113,7 +108,8 @@
data_src = load_audio(source, fs=self.fs)
speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
target = item["target"]
- ids = self.tokenizer.encode(target)
+ text = self.tokenizer.text2tokens(target)
+ ids = self.token_id_converter.tokens2ids(text)
ids_lengths = len(ids)
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index 62beaab..01a8c6f 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -361,7 +361,6 @@
tokens = seg_tokenize(tokens, self.seg_dict)
else:
tokens = self.tokenizer.text2tokens(text)
-
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
return data
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index c1eb003..050847e 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -223,7 +223,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py
index 598d5ac..b474dbc 100644
--- a/funasr/models/e2e_asr_contextual_paraformer.py
+++ b/funasr/models/e2e_asr_contextual_paraformer.py
@@ -234,7 +234,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 6b1d824..0e0b95b 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -256,7 +256,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -869,7 +868,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -1497,7 +1495,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -1769,7 +1766,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -1972,7 +1968,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -2267,4 +2262,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
\ No newline at end of file
+ return var_dict_torch_update
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index 07ebd81..14fb7f3 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -443,7 +443,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
-
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 0beb083..b1879fa 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -347,7 +347,7 @@
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
- pad_targets (LongTensor): Target label tensors (B, Lmax).
+ pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
diff --git a/funasr/optimizers/__init__.py b/funasr/optimizers/__init__.py
index b4dfe5d..e69de29 100644
--- a/funasr/optimizers/__init__.py
+++ b/funasr/optimizers/__init__.py
@@ -1,17 +0,0 @@
-import torch
-from funasr.optimizers.fairseq_adam import FairseqAdam
-from funasr.optimizers.sgd import SGD
-
-optim_choices = dict(
- adam=torch.optim.Adam,
- fairseq_adam=FairseqAdam,
- adamw=torch.optim.AdamW,
- sgd=SGD,
- adadelta=torch.optim.Adadelta,
- adagrad=torch.optim.Adagrad,
- adamax=torch.optim.Adamax,
- asgd=torch.optim.ASGD,
- lbfgs=torch.optim.LBFGS,
- rmsprop=torch.optim.RMSprop,
- rprop=torch.optim.Rprop,
-)
\ No newline at end of file
diff --git a/funasr/schedulers/__init__.py b/funasr/schedulers/__init__.py
index 7bb8118..e69de29 100644
--- a/funasr/schedulers/__init__.py
+++ b/funasr/schedulers/__init__.py
@@ -1,23 +0,0 @@
-import torch
-import torch.multiprocessing
-import torch.nn
-import torch.optim
-
-from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.tri_stage_scheduler import TriStageLR
-from funasr.schedulers.warmup_lr import WarmupLR
-
-scheduler_choices = dict(
- ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
- lambdalr=torch.optim.lr_scheduler.LambdaLR,
- steplr=torch.optim.lr_scheduler.StepLR,
- multisteplr=torch.optim.lr_scheduler.MultiStepLR,
- exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
- CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
- noamlr=NoamLR,
- warmuplr=WarmupLR,
- tri_stage=TriStageLR,
- cycliclr=torch.optim.lr_scheduler.CyclicLR,
- onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
- CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
-)
\ No newline at end of file
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d2fc3f0..fc2ccb3 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -2,87 +2,13 @@
from abc import abstractmethod
from typing import Iterable
from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
-import numpy as np
class AbsTokenizer(ABC):
@abstractmethod
def text2tokens(self, line: str) -> List[str]:
raise NotImplementedError
- @abstractmethod
- def tokens2text(self, tokens: Iterable[str]) -> str:
- raise NotImplementedError
-
-
-class BaseTokenizer(ABC):
- def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
- unk_symbol: str = "<unk>",
- **kwargs,
- ):
-
- if token_list is not None:
- if isinstance(token_list, (Path, str)):
- token_list = Path(token_list)
- self.token_list_repr = str(token_list)
- self.token_list: List[str] = []
-
- with token_list.open("r", encoding="utf-8") as f:
- for idx, line in enumerate(f):
- line = line.rstrip()
- self.token_list.append(line)
-
- else:
- self.token_list: List[str] = list(token_list)
- self.token_list_repr = ""
- for i, t in enumerate(self.token_list):
- if i == 3:
- break
- self.token_list_repr += f"{t}, "
- self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-
- self.token2id: Dict[str, int] = {}
- for i, t in enumerate(self.token_list):
- if t in self.token2id:
- raise RuntimeError(f'Symbol "{t}" is duplicated')
- self.token2id[t] = i
-
- self.unk_symbol = unk_symbol
- if self.unk_symbol not in self.token2id:
- raise RuntimeError(
- f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
- )
- self.unk_id = self.token2id[self.unk_symbol]
-
- def encode(self, text):
- tokens = self.text2tokens(text)
- text_ints = self.tokens2ids(tokens)
-
- return text_ints
-
- def decode(self, text_ints):
- return self.ids2tokens(text_ints)
-
- def get_num_vocabulary_size(self) -> int:
- return len(self.token_list)
-
- def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
- if isinstance(integers, np.ndarray) and integers.ndim != 1:
- raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
- return [self.token_list[i] for i in integers]
-
- def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
- return [self.token2id.get(i, self.unk_id) for i in tokens]
-
- @abstractmethod
- def text2tokens(self, line: str) -> List[str]:
- raise NotImplementedError
-
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
raise NotImplementedError
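[After this hunk, `AbsTokenizer` is again a pure interface: `text2tokens` and `tokens2text` only, with the id-mapping `BaseTokenizer` (encode/decode, `token2id`, unk fallback) gone. A minimal concrete subclass to illustrate the surviving contract; whitespace splitting is an arbitrary choice for the example:]

    from typing import Iterable, List

    from funasr.tokenizer.abs_tokenizer import AbsTokenizer


    class SpaceTokenizer(AbsTokenizer):
        """Illustrative subclass: split on whitespace, join with spaces."""

        def text2tokens(self, line: str) -> List[str]:
            return line.split()

        def tokens2text(self, tokens: Iterable[str]) -> str:
            return " ".join(tokens)


    tok = SpaceTokenizer()
    assert tok.tokens2text(tok.text2tokens("hello world")) == "hello world"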
diff --git a/funasr/tokenizer/build_tokenizer.py b/funasr/tokenizer/build_tokenizer.py
index 05db6a6..9d1cdc3 100644
--- a/funasr/tokenizer/build_tokenizer.py
+++ b/funasr/tokenizer/build_tokenizer.py
@@ -1,17 +1,7 @@
from pathlib import Path
from typing import Iterable
from typing import Union
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
-import numpy as np
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
@@ -28,8 +18,7 @@
space_symbol: str = "<space>",
delimiter: str = None,
g2p_type: str = None,
- **kwargs,
-):
+) -> AbsTokenizer:
"""A helper function to instantiate Tokenizer"""
if token_type == "bpe":
if bpemodel is None:
@@ -39,7 +28,7 @@
raise RuntimeError(
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
)
- return SentencepiecesTokenizer(bpemodel, **kwargs)
+ return SentencepiecesTokenizer(bpemodel)
elif token_type == "word":
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
@@ -49,14 +38,13 @@
remove_non_linguistic_symbols=True,
)
else:
- return WordTokenizer(delimiter=delimiter, **kwargs)
+ return WordTokenizer(delimiter=delimiter)
elif token_type == "char":
return CharTokenizer(
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
- **kwargs
)
elif token_type == "phn":
@@ -65,7 +53,6 @@
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
- **kwargs
)
else:
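[With `**kwargs` stripped, `build_tokenizer` is back to a fixed-signature factory annotated as returning `AbsTokenizer`. A usage sketch against the post-revert signature; every keyword below exists in the restored definition, and the printed output is indicative only:]

    from funasr.tokenizer.build_tokenizer import build_tokenizer

    tokenizer = build_tokenizer(
        token_type="char",
        space_symbol="<space>",
        remove_non_linguistic_symbols=False,
    )
    print(tokenizer.text2tokens("ab c"))  # e.g. ['a', 'b', '<space>', 'c']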
diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
index 80528a2..6c9a5a5 100644
--- a/funasr/tokenizer/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -6,17 +6,15 @@
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
-from funasr.tokenizer.abs_tokenizer import BaseTokenizer
-class CharTokenizer(BaseTokenizer):
+
+class CharTokenizer(AbsTokenizer):
def __init__(
self,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
- **kwargs,
):
- super().__init__(**kwargs)
self.space_symbol = space_symbol
if non_linguistic_symbols is None:
self.non_linguistic_symbols = set()
diff --git a/funasr/tokenizer/funtoken.py b/funasr/tokenizer/funtoken.py
deleted file mode 100644
index 7187d85..0000000
--- a/funasr/tokenizer/funtoken.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from pathlib import Path
-from typing import Iterable
-from typing import Union
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
-
-import numpy as np
-
-from funasr.tokenizer.abs_tokenizer import AbsTokenizer
-from funasr.tokenizer.char_tokenizer import CharTokenizer
-from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
-from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
-from funasr.tokenizer.word_tokenizer import WordTokenizer
-
-def build_tokenizer(
- token_type: str,
- bpemodel: Union[Path, str, Iterable[str]] = None,
- non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
- remove_non_linguistic_symbols: bool = False,
- space_symbol: str = "<space>",
- delimiter: str = None,
- g2p_type: str = None,
- **kwargs,
-):
- """A helper function to instantiate Tokenizer"""
- # import pdb;
- # pdb.set_trace()
- if token_type == "bpe":
- if bpemodel is None:
- raise ValueError('bpemodel is required if token_type = "bpe"')
-
- if remove_non_linguistic_symbols:
- raise RuntimeError(
- "remove_non_linguistic_symbols is not implemented for token_type=bpe"
- )
- return SentencepiecesTokenizer(bpemodel, **kwargs)
-
- elif token_type == "word":
- if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
- return WordTokenizer(
- delimiter=delimiter,
- non_linguistic_symbols=non_linguistic_symbols,
- remove_non_linguistic_symbols=True,
- )
- else:
- return WordTokenizer(delimiter=delimiter, **kwargs)
-
- elif token_type == "char":
- return CharTokenizer(
- non_linguistic_symbols=non_linguistic_symbols,
- space_symbol=space_symbol,
- remove_non_linguistic_symbols=remove_non_linguistic_symbols,
- **kwargs
- )
-
- elif token_type == "phn":
- return PhonemeTokenizer(
- g2p_type=g2p_type,
- non_linguistic_symbols=non_linguistic_symbols,
- space_symbol=space_symbol,
- remove_non_linguistic_symbols=remove_non_linguistic_symbols,
- **kwargs
- )
-
- else:
- raise ValueError(
- f"token_mode must be one of bpe, word, char or phn: " f"{token_type}"
- )
diff --git a/funasr/tokenizer/phoneme_tokenizer.py b/funasr/tokenizer/phoneme_tokenizer.py
index 04b423b..0117c6a 100644
--- a/funasr/tokenizer/phoneme_tokenizer.py
+++ b/funasr/tokenizer/phoneme_tokenizer.py
@@ -363,7 +363,6 @@
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
- **kwargs,
):
if g2p_type is None:
self.g2p = split_by_space
diff --git a/funasr/tokenizer/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
index df98c2c..9a65920 100644
--- a/funasr/tokenizer/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -9,7 +9,7 @@
class SentencepiecesTokenizer(AbsTokenizer):
- def __init__(self, model: Union[Path, str], **kwargs):
+ def __init__(self, model: Union[Path, str]):
self.model = str(model)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
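[The `NOTE(kamo)` context above points at a lazy-construction pattern: the processor is not built in `__init__`, which keeps the tokenizer cheaply picklable for multiprocessing data loaders. A sketch of that pattern; the calls inside `_build` are from the `sentencepiece` package, while the class itself is illustrative rather than FunASR code:]

    from pathlib import Path
    from typing import Union

    import sentencepiece as spm


    class LazySpTokenizer:
        def __init__(self, model: Union[Path, str]):
            self.model = str(model)
            self.sp = None  # deliberately not built here: keeps pickling cheap

        def _build(self):
            if self.sp is None:
                self.sp = spm.SentencePieceProcessor()
                self.sp.load(self.model)

        def text2tokens(self, line: str):
            self._build()
            return self.sp.EncodeAsPieces(line)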
diff --git a/funasr/tokenizer/word_tokenizer.py b/funasr/tokenizer/word_tokenizer.py
index d7bbaf9..cbd0673 100644
--- a/funasr/tokenizer/word_tokenizer.py
+++ b/funasr/tokenizer/word_tokenizer.py
@@ -14,7 +14,6 @@
delimiter: str = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
- **kwargs,
):
self.delimiter = delimiter
diff --git a/funasr/utils/dynamic_import.py b/funasr/utils/dynamic_import.py
deleted file mode 100644
index 2830cb2..0000000
--- a/funasr/utils/dynamic_import.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import importlib
-
-
-def dynamic_import(import_path):
- """dynamic import module and class
-
- :param str import_path: syntax 'module_name:class_name'
- :return: imported class
- """
-
- module_name, objname = import_path.split(":")
- m = importlib.import_module(module_name)
- return getattr(m, objname)
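[The deleted `dynamic_import` resolved a `'module_name:class_name'` string via `importlib`, per its docstring. A self-contained usage sketch of the contract it implemented; the target class is just an example:]

    import importlib

    def dynamic_import(import_path):
        # 'module_name:class_name' -> imported attribute
        module_name, objname = import_path.split(":")
        return getattr(importlib.import_module(module_name), objname)

    OrderedDict = dynamic_import("collections:OrderedDict")
    print(OrderedDict([("a", 1)]))  # OrderedDict([('a', 1)])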
diff --git a/funasr/utils/load_fr_py.py b/funasr/utils/load_fr_py.py
deleted file mode 100644
index 6697e04..0000000
--- a/funasr/utils/load_fr_py.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import importlib.util
-import sys
-
-def load_class_from_path(model_path):
- path, class_name = model_path
- # import pdb;
- # pdb.set_trace()
- spec = importlib.util.spec_from_file_location("module.name", path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[spec.name] = module
- spec.loader.exec_module(module)
- return getattr(module, class_name)
-
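[The second deleted helper, `load_class_from_path`, executed an arbitrary `.py` file as a module and returned one attribute from it. A sketch of the same contract for reference; the plugin file and class name are hypothetical:]

    import importlib.util
    import sys

    def load_class_from_path(model_path):
        # Same contract as the deleted helper: (file_path, class_name) -> class.
        path, class_name = model_path
        spec = importlib.util.spec_from_file_location("module.name", path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)
        return getattr(module, class_name)

    # Hypothetical: plugin.py on disk defines `class PluginModel: ...`
    PluginModel = load_class_from_path(("plugin.py", "PluginModel"))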
--
Gitblit v1.9.1