From abb33d6b2097e5b0643326bc1b376a63cdc2f967 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 17:06:21 +0800
Subject: [PATCH] Dev gzf deepspeed (#1844)

---
 funasr/datasets/sense_voice_datasets/datasets.py             |  199 +++++++++++++++
 funasr/models/sense_voice/model.py                           |  282 ---------------------
 funasr/tokenizer/sentencepiece_tokenizer.py                  |    6 
 funasr/train_utils/trainer.py                                |   28 +-
 examples/industrial_data_pretraining/sense_voice/demo_ctc.py |    4 
 funasr/datasets/audio_datasets/index_ds.py                   |    6 
 funasr/models/sanm/encoder.py                                |  220 ----------------
 7 files changed, 233 insertions(+), 512 deletions(-)

diff --git a/examples/industrial_data_pretraining/sense_voice/demo_ctc.py b/examples/industrial_data_pretraining/sense_voice/demo_ctc.py
index 064d1e9..a8ba7f9 100644
--- a/examples/industrial_data_pretraining/sense_voice/demo_ctc.py
+++ b/examples/industrial_data_pretraining/sense_voice/demo_ctc.py
@@ -18,8 +18,8 @@
 res = model.generate(
     input=input_file,
     cache={},
-    language="zh",
-    text_norm="wotextnorm",
+    language="auto",
+    text_norm="woitn",
 )
 
 print(res)
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 385218a..39ef409 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -118,6 +118,12 @@
                         text_language = data.get("text_language", None)
                         if text_language is not None:
                             contents_i["text_language"] = text_language
+                        if "emo_target" in data:
+                            contents_i["emo_target"] = data["emo_target"]
+                        if "event_target" in data:
+                            contents_i["event_target"] = data["event_target"]
+                        if "with_or_wo_itn" in data:
+                            contents_i["with_or_wo_itn"] = data["with_or_wo_itn"]
                         # audio_language = data.get("audio_language", None)
                         # if audio_language is not None:
                         #     contents_i["audio_language"] = audio_language
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index d4e14f2..6b57a9f 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -229,3 +229,202 @@
             outputs["target_mask"] = outputs["target_mask"][:, :target_mask_lengths_max]
 
         return outputs
+
+
+@tables.register("dataset_classes", "SenseVoiceCTCDataset")
+class SenseVoiceCTCDataset(torch.utils.data.Dataset):
+    """
+    SenseVoiceCTCDataset
+    """
+
+    def __init__(
+        self,
+        path,
+        index_ds: str = None,
+        frontend=None,
+        tokenizer=None,
+        int_pad_value: int = -1,
+        float_pad_value: float = 0.0,
+        **kwargs,
+    ):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(
+                **kwargs.get("preprocessor_speech_conf")
+            )
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+        self.preprocessor_text = preprocessor_text
+
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+
+        self.int_pad_value = int_pad_value
+        self.float_pad_value = float_pad_value
+        self.sos = kwargs.get("sos", "<|startoftranscript|>")
+        self.eos = kwargs.get("eos", "<|endoftext|>")
+        self.batch_size = kwargs.get("batch_size")
+        self.batch_type = kwargs.get("batch_type")
+        self.prompt_ids_len = 0
+        self.retry = kwargs.get("retry", 5)
+
+        self.permute = False
+        from funasr.frontends.whisper_frontend import WhisperFrontend
+
+        if isinstance(self.frontend, WhisperFrontend):
+            self.permute = True
+
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+
+    def __len__(self):
+        return len(self.index_ds)
+
+    def __getitem__(self, index):
+
+        output = None
+        for idx in range(self.retry):
+            if idx == 0:
+                index_cur = index
+            else:
+                index_cur = torch.randint(0, len(self.index_ds), ()).item()
+
+            item = self.index_ds[index_cur]
+
+            source = item["source"]
+            try:
+                data_src = load_audio_text_image_video(source, fs=self.fs)
+            except Exception as e:
+                logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
+                continue
+
+            if self.preprocessor_speech:
+                data_src = self.preprocessor_speech(data_src, fs=self.fs)
+            speech, speech_lengths = extract_fbank(
+                data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
+            )  # speech: [b, T, d]
+
+            if speech_lengths > self.batch_size:
+                continue
+            if self.permute:
+                speech = speech.permute(0, 2, 1)
+            asr_target = item["target"]
+            if self.preprocessor_text:
+                asr_target = self.preprocessor_text(asr_target)
+            emo_target = item["emo_target"]
+            event_target = item["event_target"]
+            text_language = item.get("text_language", "<|zh|>")
+            punc_itn_bottom = item.get("with_or_wo_itn", "<|SPECIAL_TOKEN_13|>")
+
+            target_ids = self.tokenizer.encode(asr_target, allowed_special="all")
+            target_ids_len = len(target_ids)  # [text]
+            if target_ids_len > 200:
+                continue
+
+            lid_ids = self.tokenizer.encode(text_language, allowed_special="all")
+            emo_ids = self.tokenizer.encode(emo_target, allowed_special="all")
+            event_ids = self.tokenizer.encode(event_target, allowed_special="all")
+            punc_itn_bottom_ids = self.tokenizer.encode(punc_itn_bottom, allowed_special="all")
+
+            ids = lid_ids + emo_ids + event_ids + punc_itn_bottom_ids + target_ids # [lid, emo, lid, itn, text]
+            ids_lengths = len(ids)
+
+            text = torch.tensor(ids, dtype=torch.int64)
+            text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
+
+            output = {
+                "speech": speech[0, :, :],
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+            }
+            break
+
+        return output
+
+    def collator(self, samples: list = None):
+        outputs = {}
+        for sample in samples:
+            if sample is None:
+                continue
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
+        if len(outputs) < 1:
+            logging.error(f"ERROR: data is empty!")
+            outputs = {
+                "speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
+                "speech_lengths": torch.tensor(
+                    [
+                        10,
+                    ],
+                    dtype=torch.int32,
+                )[:, None],
+                "text": torch.tensor(
+                    [
+                        58836,
+                    ],
+                    dtype=torch.int32,
+                )[None, :],
+                "text_lengths": torch.tensor(
+                    [
+                        1,
+                    ],
+                    dtype=torch.int32,
+                )[:, None],
+            }
+            return outputs
+
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(
+                    data_list, batch_first=True, padding_value=pad_value
+                )
+
+        if self.batch_type != "example":
+            for i in range(10):
+                outputs = self._filter_badcase(outputs, i=i)
+
+        return outputs
+
+    def _filter_badcase(self, outputs, i=0):
+        b, t, _ = outputs["speech"].shape
+
+        if b * t > self.batch_size * 1.25:
+            beg = torch.randint(0, 2, ()).item()
+            if b < 2:
+                beg = 0
+            logging.info(
+                f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
+            )
+            for key, data_list in outputs.items():
+                outputs[key] = outputs[key][beg : beg + b : 2]
+
+            speech_lengths_max = outputs["speech_lengths"].max().item()
+            outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
+            text_lengths_max = outputs["text_lengths"].max().item()
+            outputs["text"] = outputs["text"][:, :text_lengths_max]
+
+        return outputs
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index b2a442b..dc30a94 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -484,226 +484,6 @@
         return xs_pad, ilens, None
 
 
-@tables.register("encoder_classes", "SANMTPEncoder")
-class SANMTPEncoder(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-    """
-    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            tp_blocks: int = 0,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            stochastic_depth_rate: float = 0.0,
-            input_layer: Optional[str] = "conv2d",
-            pos_enc_class=SinusoidalPositionEncoder,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 1,
-            padding_idx: int = -1,
-            kernel_size: int = 11,
-            sanm_shfit: int = 0,
-            selfattention_layer_type: str = "sanm",
-    ):
-        super().__init__()
-        self._output_size = output_size
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                eval(pos_enc_class)(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "linear_no_pos":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                eval(pos_enc_class)(output_size, positional_dropout_rate, use_pos=False),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                eval(pos_enc_class)(output_size, positional_dropout_rate),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        elif input_layer == "pe_online":
-            self.embed = StreamSinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-        elif selfattention_layer_type == "sanm":
-            encoder_selfattn_layer = MultiHeadedAttentionSANM
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        self.encoders = repeat(
-            num_blocks - 1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-                stochastic_depth_rate,
-            ),
-        )
-        self.tp_encoders = repeat(
-            tp_blocks,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-                stochastic_depth_rate,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-        self.tp_blocks = tp_blocks
-        if self.tp_blocks > 0:
-            self.tp_norm = LayerNorm(output_size)
-    def output_size(self) -> int:
-        return self._output_size
-    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad *= self.output_size() ** 0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-        # forward encoder1
-        mask_shfit_chunk, mask_att_chunk_encoder = None, None
-        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-        # forward encoder2
-        olens = masks.squeeze(1).sum(1)
-        mask_shfit_chunk2, mask_att_chunk_encoder2 = None, None
-        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
-            encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk2, mask_att_chunk_encoder2)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        if self.tp_blocks > 0:
-            xs_pad = self.tp_norm(xs_pad)
-        return xs_pad, olens
-
-
 class EncoderLayerSANMExport(nn.Module):
     def __init__(
         self,
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index a9b2149..9db6539 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -10,7 +10,7 @@
 from torch import Tensor
 from torch import nn
 from torch.cuda.amp import autocast
-from funasr.metrics.compute_acc import compute_accuracy
+from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.train_utils.device_funcs import force_gatherable
 from . import whisper_lib as whisper
@@ -662,9 +662,11 @@
         else:
             encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
-        )
+        with autocast(False):
+            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+            )
+
         loss = loss_att
         stats = {}
         stats["acc"] = acc_att
@@ -1390,275 +1392,3 @@
 
 from funasr.models.paraformer.search import Hypothesis
 from funasr.utils import postprocess_utils
-
-
-@tables.register("model_classes", "SenseVoiceSANMCTC")
-class SenseVoiceSANMCTC(nn.Module):
-    """CTC-attention hybrid Encoder-Decoder model"""
-
-    def __init__(
-        self,
-        specaug: str = None,
-        specaug_conf: dict = None,
-        normalize: str = None,
-        normalize_conf: dict = None,
-        encoder: str = None,
-        encoder_conf: dict = None,
-        ctc_conf: dict = None,
-        input_size: int = 80,
-        vocab_size: int = -1,
-        ignore_id: int = -1,
-        blank_id: int = 0,
-        sos: int = 1,
-        eos: int = 2,
-        length_normalized_loss: bool = False,
-        **kwargs,
-    ):
-
-        super().__init__()
-
-        if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug)
-            specaug = specaug_class(**specaug_conf)
-        if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize)
-            normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder)
-        encoder = encoder_class(input_size=input_size, **encoder_conf)
-        encoder_output_size = encoder.output_size()
-
-        if ctc_conf is None:
-            ctc_conf = {}
-        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
-
-        self.blank_id = blank_id
-        self.sos = sos if sos is not None else vocab_size - 1
-        self.eos = eos if eos is not None else vocab_size - 1
-        self.vocab_size = vocab_size
-        self.ignore_id = ignore_id
-        self.specaug = specaug
-        self.normalize = normalize
-        self.encoder = encoder
-        self.error_calculator = None
-
-        self.ctc = ctc
-
-        self.length_normalized_loss = length_normalized_loss
-        self.encoder_output_size = encoder_output_size
-
-        self.lid_dict = {"zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
-        self.textnorm_dict = {"withtextnorm": 14, "wotextnorm": 15}
-        self.embed = torch.nn.Embedding(8 + len(self.lid_dict) + len(self.textnorm_dict), 560)
-
-    def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        **kwargs,
-    ):
-        """Encoder + Decoder + Calc loss
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-                text: (Batch, Length)
-                text_lengths: (Batch,)
-        """
-        # import pdb;
-        # pdb.set_trace()
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
-
-        batch_size = speech.shape[0]
-
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-        loss_ctc, cer_ctc = None, None
-        stats = dict()
-
-        loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens, text, text_lengths)
-
-        loss = loss_ctc
-
-        # Collect total loss stats
-        stats["loss"] = torch.clone(loss.detach())
-
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        if self.length_normalized_loss:
-            batch_size = int((text_lengths + 1).sum())
-        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-        return loss, stats, weight
-
-    def encode(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        **kwargs,
-    ):
-        """Frontend + Encoder. Note that this method is used by asr_inference.py
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-                ind: int
-        """
-
-        # Data augmentation
-        if self.specaug is not None and self.training:
-            speech, speech_lengths = self.specaug(speech, speech_lengths)
-
-        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-        if self.normalize is not None:
-            speech, speech_lengths = self.normalize(speech, speech_lengths)
-
-        # Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
-
-        return encoder_out, encoder_out_lens
-
-    def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
-
-    def inference(
-        self,
-        data_in,
-        data_lengths=None,
-        key: list = None,
-        tokenizer=None,
-        frontend=None,
-        **kwargs,
-    ):
-
-        if kwargs.get("batch_size", 1) > 1:
-            raise NotImplementedError("batch decoding is not implemented")
-
-        meta_data = {}
-        if (
-            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
-        ):  # fbank
-            speech, speech_lengths = data_in, data_lengths
-            if len(speech.shape) < 3:
-                speech = speech[None, :, :]
-            if speech_lengths is None:
-                speech_lengths = speech.shape[1]
-        else:
-            # extract fbank feats
-            time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(
-                data_in,
-                fs=frontend.fs,
-                audio_fs=kwargs.get("fs", 16000),
-                data_type=kwargs.get("data_type", "sound"),
-                tokenizer=tokenizer,
-            )
-            time2 = time.perf_counter()
-            meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(
-                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
-            )
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            meta_data["batch_data_time"] = (
-                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-            )
-
-        speech = speech.to(device=kwargs["device"])
-        speech_lengths = speech_lengths.to(device=kwargs["device"])
-
-        language = kwargs.get("language", None)
-        if language is not None:
-            language_query = self.embed(
-                torch.LongTensor(
-                    [[self.lid_dict[language] if language in self.lid_dict else 0]]
-                ).to(speech.device)
-            ).repeat(speech.size(0), 1, 1)
-        else:
-            language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(
-                speech.size(0), 1, 1
-            )
-        textnorm = kwargs.get("text_norm", "wotextnorm")
-        textnorm_query = self.embed(
-            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
-        ).repeat(speech.size(0), 1, 1)
-        speech = torch.cat((textnorm_query, speech), dim=1)
-        speech_lengths += 1
-
-        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
-            speech.size(0), 1, 1
-        )
-        input_query = torch.cat((language_query, event_emo_query), dim=1)
-        speech = torch.cat((input_query, speech), dim=1)
-        speech_lengths += 3
-
-        # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        if isinstance(encoder_out, tuple):
-            encoder_out = encoder_out[0]
-
-        # c. Passed the encoder result and the beam search
-        ctc_logits = self.ctc.log_softmax(encoder_out)
-
-        results = []
-        b, n, d = encoder_out.size()
-        if isinstance(key[0], (list, tuple)):
-            key = key[0]
-        if len(key) < b:
-            key = key * b
-        for i in range(b):
-            x = ctc_logits[i, : encoder_out_lens[i], :]
-            yseq = x.argmax(dim=-1)
-            yseq = torch.unique_consecutive(yseq, dim=-1)
-            yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
-            nbest_hyps = [Hypothesis(yseq=yseq)]
-
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-                ibest_writer = None
-                if kwargs.get("output_dir") is not None:
-                    if not hasattr(self, "writer"):
-                        self.writer = DatadirWriter(kwargs.get("output_dir"))
-                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(
-                    filter(
-                        lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
-                    )
-                )
-
-                # Change integer-ids to tokens
-                text = tokenizer.decode(token_int)
-
-                result_i = {"key": key[i], "text": text}
-                results.append(result_i)
-
-                if ibest_writer is not None:
-                    ibest_writer["token"][key[i]] = " ".join(token)
-                    ibest_writer["text"][key[i]] = text_postprocessed
-
-        return results, meta_data
diff --git a/funasr/tokenizer/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
index 1be1b81..0b47a9f 100644
--- a/funasr/tokenizer/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -49,3 +49,9 @@
 
     def get_vocab_size(self):
         return self.sp.GetPieceSize()
+
+    def ids2tokens(self, *args, **kwargs):
+        return self.decode(*args, **kwargs)
+
+    def tokens2ids(self, *args, **kwargs):
+        return self.encode(*args, **kwargs)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index afc632d..665a7af 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -362,10 +362,10 @@
         time_beg = time.perf_counter()
         time5 = time_beg
         for batch_idx, batch in enumerate(dataloader_train):
-            if self.use_ddp or self.use_fsdp:
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-                if iterator_stop > 0:
-                    break
+            # if self.use_ddp or self.use_fsdp:
+            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+            #     if iterator_stop > 0:
+            #         break
             self.batch_total += 1
             self.step_in_epoch += 1
             time1 = time.perf_counter()
@@ -381,11 +381,11 @@
                 with maybe_autocast(self.use_fp16):
                     retval = model(**batch)
 
-                    if (
-                        self.reset_gpu_cache
-                        and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
-                    ):
-                        torch.cuda.empty_cache()
+                    # if (
+                    #     self.reset_gpu_cache
+                    #     and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
+                    # ):
+                    #     torch.cuda.empty_cache()
 
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
@@ -516,14 +516,14 @@
                 )
 
             time_beg = time.perf_counter()
-        else:
-            if self.use_ddp or self.use_fsdp:
-                iterator_stop.fill_(1)
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        # else:
+        #     if self.use_ddp or self.use_fsdp:
+        #         iterator_stop.fill_(1)
+        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-            iterator_stop = torch.tensor(0).to(self.device)
+            # iterator_stop = torch.tensor(0).to(self.device)
 
     def validate_epoch(
         self,

--
Gitblit v1.9.1