From c8bae0ec85eee25d66de6b1e4502eff74d750b24 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 十二月 2023 13:29:37 +0800
Subject: [PATCH] funasr2

---
 funasr/bin/inference.py                                |    5 
 funasr/models/fsmn_vad/encoder.py                      |    6 
 /dev/null                                              |  130 --------
 funasr/models/fsmn_vad/model.py                        |  517 ++++++++++++++++++-------------
 funasr/bin/train.py                                    |    4 
 funasr/tokenizer/abs_tokenizer.py                      |    5 
 funasr/models/ct_transformer/model.py                  |  212 +++++++++++++
 examples/industrial_data_pretraining/fsmn-vad/infer.sh |    8 
 setup.py                                               |    5 
 funasr/models/ct_transformer/encoder.py                |    0 
 funasr/datasets/audio_datasets/datasets.py             |   16 +
 funasr/train_utils/load_pretrained_model.py            |    3 
 12 files changed, 552 insertions(+), 359 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn-vad/infer.sh b/examples/industrial_data_pretraining/fsmn-vad/infer.sh
new file mode 100644
index 0000000..9bfd8ba
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn-vad/infer.sh
@@ -0,0 +1,8 @@
+
+cmd="funasr/bin/inference.py"
+
+python $cmd \
++model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_vad" \
++device="cpu" \
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index fd884cd..50ea4d4 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -101,6 +101,7 @@
 			tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
 			tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
 			kwargs["tokenizer"] = tokenizer
+			kwargs["token_list"] = tokenizer.token_list
 		
 		# build frontend
 		frontend = kwargs.get("frontend", None)
@@ -112,11 +113,9 @@
 		
 		# build model
 		model_class = registry_tables.model_classes.get(kwargs["model"].lower())
-		model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+		model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
 		model.eval()
 		model.to(device)
-		
-		kwargs["token_list"] = tokenizer.token_list
 		
 		# init_param
 		init_param = kwargs.get("init_param", None)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8112002..1e06c50 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -145,7 +145,8 @@
 	# dataloader
 	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
 	batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
-	batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+	if batch_sampler is not None:
+		batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
 	dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
 	                                            collate_fn=dataset_tr.collator,
 	                                            batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@
 	                                            pin_memory=True)
 	
 
-	
 	trainer = Trainer(
 	    model=model,
 	    optim=optim,
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 353a3a0..d69d0b5 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -24,6 +24,17 @@
 		super().__init__()
 		index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
 		self.index_ds = index_ds_class(path)
+		preprocessor_speech = kwargs.get("preprocessor_speech", None)
+		if preprocessor_speech:
+			preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+			preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+		self.preprocessor_speech = preprocessor_speech
+		preprocessor_text = kwargs.get("preprocessor_text", None)
+		if preprocessor_text:
+			preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
+			preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+		self.preprocessor_text = preprocessor_text
+		
 		self.frontend = frontend
 		self.fs = 16000 if frontend is None else frontend.fs
 		self.data_type = "sound"
@@ -49,8 +60,13 @@
 		# pdb.set_trace()
 		source = item["source"]
 		data_src = load_audio(source, fs=self.fs)
+		if self.preprocessor_speech:
+			data_src = self.preprocessor_speech(data_src)
 		speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+
 		target = item["target"]
+		if self.preprocessor_text:
+			target = self.preprocessor_text(target)
 		ids = self.tokenizer.encode(target)
 		ids_lengths = len(ids)
 		text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
diff --git a/funasr/models/ct_transformer/sanm_encoder.py b/funasr/models/ct_transformer/encoder.py
similarity index 100%
rename from funasr/models/ct_transformer/sanm_encoder.py
rename to funasr/models/ct_transformer/encoder.py
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
new file mode 100644
index 0000000..31b2af2
--- /dev/null
+++ b/funasr/models/ct_transformer/model.py
@@ -0,0 +1,212 @@
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("model_classes", "CTTransformer")
+class CTTransformer(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+    https://arxiv.org/pdf/2003.01309.pdf
+    """
+    def __init__(
+        self,
+        encoder: str = None,
+        encoder_conf: str = None,
+        vocab_size: int = -1,
+        punc_list: list = None,
+        punc_weight: list = None,
+        embed_unit: int = 128,
+        att_unit: int = 256,
+        dropout_rate: float = 0.5,
+        ignore_id: int = -1,
+        sos: int = 1,
+        eos: int = 2,
+        **kwargs,
+    ):
+        super().__init__()
+
+        punc_size = len(punc_list)
+        if punc_weight is None:
+            punc_weight = [1] * punc_size
+        
+        
+        self.embed = nn.Embedding(vocab_size, embed_unit)
+        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+        encoder = encoder_class(**encoder_conf)
+
+        self.decoder = nn.Linear(att_unit, punc_size)
+        self.encoder = encoder
+        self.punc_list = punc_list
+        self.punc_weight = punc_weight
+        self.ignore_id = ignore_id
+        self.sos = sos
+        self.eos = eos
+        
+        
+
+    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(input)
+        # mask = self._target_mask(input)
+        h, _, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y, None
+
+    def with_vad(self):
+        return False
+
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
+        """Score new token.
+
+        Args:
+            y (torch.Tensor): 1D torch.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (torch.Tensor): encoder feature that generates ys.
+
+        Returns:
+            tuple[torch.Tensor, Any]: Tuple of
+                torch.float32 scores for next token (vocab_size)
+                and next state for ys
+
+        """
+        y = y.unsqueeze(0)
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1).squeeze(0)
+        return logp, cache
+
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, vocab_size)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.encoder.encoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+
+        # batch decoding
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
+
+    def nll(
+        self,
+        text: torch.Tensor,
+        punc: torch.Tensor,
+        text_lengths: torch.Tensor,
+        punc_lengths: torch.Tensor,
+        max_length: Optional[int] = None,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute negative log likelihood(nll)
+
+        Normally, this function is called in batchify_nll.
+        Args:
+            text: (Batch, Length)
+            punc: (Batch, Length)
+            text_lengths: (Batch,)
+            max_lengths: int
+        """
+        batch_size = text.size(0)
+        # For data parallel
+        if max_length is None:
+            text = text[:, :text_lengths.max()]
+            punc = punc[:, :text_lengths.max()]
+        else:
+            text = text[:, :max_length]
+            punc = punc[:, :max_length]
+    
+        if self.with_vad():
+            # Should be VadRealtimeTransformer
+            assert vad_indexes is not None
+            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
+        else:
+            # Should be TargetDelayTransformer,
+            y, _ = self.punc_forward(text, text_lengths)
+    
+        # Calc negative log likelihood
+        # nll: (BxL,)
+        if self.training == False:
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            from sklearn.metrics import f1_score
+            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
+                                indices.squeeze(-1).detach().cpu().numpy(),
+                                average='micro')
+            nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
+            return nll, text_lengths
+        else:
+            self.punc_weight = self.punc_weight.to(punc.device)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
+        # nll: (BxL,) -> (BxL,)
+        if max_length is None:
+            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
+        else:
+            nll.masked_fill_(
+                make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
+                0.0,
+            )
+        # nll: (BxL,) -> (B, L)
+        nll = nll.view(batch_size, -1)
+        return nll, text_lengths
+
+
+    def forward(
+        self,
+        text: torch.Tensor,
+        punc: torch.Tensor,
+        text_lengths: torch.Tensor,
+        punc_lengths: torch.Tensor,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
+        ntokens = y_lengths.sum()
+        loss = nll.sum() / ntokens
+        stats = dict(loss=loss.detach())
+    
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
+        return loss, stats, weight
+    
+    def generate(self,
+                  text: torch.Tensor,
+                  text_lengths: torch.Tensor,
+                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
+        if self.with_vad():
+            assert vad_indexes is not None
+            return self.punc_forward(text, text_lengths, vad_indexes)
+        else:
+            return self.punc_forward(text, text_lengths)
\ No newline at end of file
diff --git a/funasr/models/ct_transformer/target_delay_transformer.py b/funasr/models/ct_transformer/target_delay_transformer.py
deleted file mode 100644
index 59884a3..0000000
--- a/funasr/models/ct_transformer/target_delay_transformer.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.sanm.encoder import SANMEncoder as Encoder
-
-
-class TargetDelayTransformer(torch.nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
-    https://arxiv.org/pdf/2003.01309.pdf
-    """
-    def __init__(
-        self,
-        vocab_size: int,
-        punc_size: int,
-        pos_enc: str = None,
-        embed_unit: int = 128,
-        att_unit: int = 256,
-        head: int = 2,
-        unit: int = 1024,
-        layer: int = 4,
-        dropout_rate: float = 0.5,
-    ):
-        super().__init__()
-        if pos_enc == "sinusoidal":
-            #            pos_enc_class = PositionalEncoding
-            pos_enc_class = SinusoidalPositionEncoder
-        elif pos_enc is None:
-
-            def pos_enc_class(*args, **kwargs):
-                return nn.Sequential()  # indentity
-
-        else:
-            raise ValueError(f"unknown pos-enc option: {pos_enc}")
-
-        self.embed = nn.Embedding(vocab_size, embed_unit)
-        self.encoder = Encoder(
-            input_size=embed_unit,
-            output_size=att_unit,
-            attention_heads=head,
-            linear_units=unit,
-            num_blocks=layer,
-            dropout_rate=dropout_rate,
-            input_layer="pe",
-            # pos_enc_class=pos_enc_class,
-            padding_idx=0,
-        )
-        self.decoder = nn.Linear(att_unit, punc_size)
-
-
-#    def _target_mask(self, ys_in_pad):
-#        ys_mask = ys_in_pad != 0
-#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
-#        return ys_mask.unsqueeze(-2) & m
-
-    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-
-        """
-        x = self.embed(input)
-        # mask = self._target_mask(input)
-        h, _, _ = self.encoder(x, text_lengths)
-        y = self.decoder(h)
-        return y, None
-
-    def with_vad(self):
-        return False
-
-    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
-        """Score new token.
-
-        Args:
-            y (torch.Tensor): 1D torch.int64 prefix tokens.
-            state: Scorer state for prefix tokens
-            x (torch.Tensor): encoder feature that generates ys.
-
-        Returns:
-            tuple[torch.Tensor, Any]: Tuple of
-                torch.float32 scores for next token (vocab_size)
-                and next state for ys
-
-        """
-        y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1).squeeze(0)
-        return logp, cache
-
-    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
-        """Score new token batch.
-
-        Args:
-            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
-            states (List[Any]): Scorer states for prefix tokens.
-            xs (torch.Tensor):
-                The encoder feature that generates ys (n_batch, xlen, n_feat).
-
-        Returns:
-            tuple[torch.Tensor, List[Any]]: Tuple of
-                batchfied scores for next token with shape of `(n_batch, vocab_size)`
-                and next state list for ys.
-
-        """
-        # merge states
-        n_batch = len(ys)
-        n_layers = len(self.encoder.encoders)
-        if states[0] is None:
-            batch_state = None
-        else:
-            # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
-
-        # batch decoding
-        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1)
-
-        # transpose state of [layer, batch] into [batch, layer]
-        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
-        return logp, state_list
diff --git a/funasr/models/fsmn_vad/fsmn_encoder.py b/funasr/models/fsmn_vad/encoder.py
similarity index 98%
rename from funasr/models/fsmn_vad/fsmn_encoder.py
rename to funasr/models/fsmn_vad/encoder.py
index 38d164d..50e31fc 100755
--- a/funasr/models/fsmn_vad/fsmn_encoder.py
+++ b/funasr/models/fsmn_vad/encoder.py
@@ -6,6 +6,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from funasr.utils.register import register_class, registry_tables
+
 class LinearTransform(nn.Module):
 
     def __init__(self, input_dim, output_dim):
@@ -156,7 +158,7 @@
 fsmn_layers:            no. of sequential fsmn layers
 '''
 
-
+@register_class("encoder_classes", "FSMN")
 class FSMN(nn.Module):
     def __init__(
             self,
@@ -227,7 +229,7 @@
 rstride:                right stride
 '''
 
-
+@register_class("encoder_classes", "DFSMN")
 class DFSMN(nn.Module):
 
     def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index cc3c87e..16f21dc 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -1,33 +1,244 @@
 from enum import Enum
 from typing import List, Tuple, Dict, Any
-
+import logging
+import os
+import json
 import torch
 from torch import nn
 import math
 from typing import Optional
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.models.base_model import FunASRModel
-from funasr.models.model_class_factory import *
+import time
+from funasr.utils.register import register_class, registry_tables
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from torch.nn.utils.rnn import pad_sequence
+
+class VadStateMachine(Enum):
+    kVadInStateStartPointNotDetected = 1
+    kVadInStateInSpeechSegment = 2
+    kVadInStateEndPointDetected = 3
 
 
+class FrameState(Enum):
+    kFrameStateInvalid = -1
+    kFrameStateSpeech = 1
+    kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+    kChangeStateSpeech2Speech = 0
+    kChangeStateSpeech2Sil = 1
+    kChangeStateSil2Sil = 2
+    kChangeStateSil2Speech = 3
+    kChangeStateNoBegin = 4
+    kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+    kVadSingleUtteranceDetectMode = 0
+    kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(
+            self,
+            sample_rate: int = 16000,
+            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+            snr_mode: int = 0,
+            max_end_silence_time: int = 800,
+            max_start_silence_time: int = 3000,
+            do_start_point_detection: bool = True,
+            do_end_point_detection: bool = True,
+            window_size_ms: int = 200,
+            sil_to_speech_time_thres: int = 150,
+            speech_to_sil_time_thres: int = 150,
+            speech_2_noise_ratio: float = 1.0,
+            do_extend: int = 1,
+            lookback_time_start_point: int = 200,
+            lookahead_time_end_point: int = 100,
+            max_single_segment_time: int = 60000,
+            nn_eval_block_size: int = 8,
+            dcd_block_size: int = 4,
+            snr_thres: int = -100.0,
+            noise_frame_num_used_for_snr: int = 100,
+            decibel_thres: int = -100.0,
+            speech_noise_thres: float = 0.6,
+            fe_prior_thres: float = 1e-4,
+            silence_pdf_num: int = 1,
+            sil_pdf_ids: List[int] = [0],
+            speech_noise_thresh_low: float = -0.1,
+            speech_noise_thresh_high: float = 0.3,
+            output_frame_probs: bool = False,
+            frame_in_ms: int = 10,
+            frame_length_ms: int = 25,
+            **kwargs,
+    ):
+        self.sample_rate = sample_rate
+        self.detect_mode = detect_mode
+        self.snr_mode = snr_mode
+        self.max_end_silence_time = max_end_silence_time
+        self.max_start_silence_time = max_start_silence_time
+        self.do_start_point_detection = do_start_point_detection
+        self.do_end_point_detection = do_end_point_detection
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time_thres = sil_to_speech_time_thres
+        self.speech_to_sil_time_thres = speech_to_sil_time_thres
+        self.speech_2_noise_ratio = speech_2_noise_ratio
+        self.do_extend = do_extend
+        self.lookback_time_start_point = lookback_time_start_point
+        self.lookahead_time_end_point = lookahead_time_end_point
+        self.max_single_segment_time = max_single_segment_time
+        self.nn_eval_block_size = nn_eval_block_size
+        self.dcd_block_size = dcd_block_size
+        self.snr_thres = snr_thres
+        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+        self.decibel_thres = decibel_thres
+        self.speech_noise_thres = speech_noise_thres
+        self.fe_prior_thres = fe_prior_thres
+        self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids
+        self.speech_noise_thresh_low = speech_noise_thresh_low
+        self.speech_noise_thresh_high = speech_noise_thresh_high
+        self.output_frame_probs = output_frame_probs
+        self.frame_in_ms = frame_in_ms
+        self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+    def Reset(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+
+class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.noise_prob = 0.0
+        self.speech_prob = 0.0
+        self.score = 0.0
+        self.frame_id = 0
+        self.frm_state = 0
+
+
+class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+                 speech_to_sil_time: int, frame_size_ms: int):
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time = sil_to_speech_time
+        self.speech_to_sil_time = speech_to_sil_time
+        self.frame_size_ms = frame_size_ms
+
+        self.win_size_frame = int(window_size_ms / frame_size_ms)
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
+
+        self.cur_win_pos = 0
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def Reset(self) -> None:
+        self.cur_win_pos = 0
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def GetWinSize(self) -> int:
+        return int(self.win_size_frame)
+
+    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+        cur_frame_state = FrameState.kFrameStateSil
+        if frameState == FrameState.kFrameStateSpeech:
+            cur_frame_state = 1
+        elif frameState == FrameState.kFrameStateSil:
+            cur_frame_state = 0
+        else:
+            return AudioChangeState.kChangeStateInvalid
+        self.win_sum -= self.win_state[self.cur_win_pos]
+        self.win_sum += cur_frame_state
+        self.win_state[self.cur_win_pos] = cur_frame_state
+        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSpeech
+            return AudioChangeState.kChangeStateSil2Speech
+
+        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSil
+            return AudioChangeState.kChangeStateSpeech2Sil
+
+        if self.pre_frame_state == FrameState.kFrameStateSil:
+            return AudioChangeState.kChangeStateSil2Sil
+        if self.pre_frame_state == FrameState.kFrameStateSpeech:
+            return AudioChangeState.kChangeStateSpeech2Speech
+        return AudioChangeState.kChangeStateInvalid
+
+    def FrameSizeMs(self) -> int:
+        return int(self.frame_size_ms)
+
+
+@register_class("model_classes", "FsmnVAD")
 class FsmnVAD(nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
-    def __init__(self, encoder: str = None,
+    def __init__(self,
+                 encoder: str = None,
                  encoder_conf: Optional[Dict] = None,
                  vad_post_args: Dict[str, Any] = None,
-                 frontend=None):
+                 **kwargs,
+                 ):
         super().__init__()
-        self.vad_opts = VADXOptions(**vad_post_args)
+        self.vad_opts = VADXOptions(**kwargs)
         self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                                self.vad_opts.sil_to_speech_time_thres,
                                                self.vad_opts.speech_to_sil_time_thres,
                                                self.vad_opts.frame_in_ms)
         
-        encoder_class = encoder_classes.get_class(encoder)
+        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
         # init variables
@@ -57,7 +268,6 @@
         self.data_buf = None
         self.data_buf_all = None
         self.waveform = None
-        self.frontend = frontend
         self.last_drop_frames = 0
 
     def AllResetDetection(self):
@@ -239,7 +449,7 @@
             vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
         return vad_latency
 
-    def GetFrameState(self, t: int) -> FrameState:
+    def GetFrameState(self, t: int):
         frame_state = FrameState.kFrameStateInvalid
         cur_decibel = self.decibel[t]
         cur_snr = cur_decibel - self.noise_average_decibel
@@ -285,7 +495,7 @@
 
     def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
-                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+                ):
         if not in_cache:
             self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
@@ -312,6 +522,87 @@
             # reset class variables and clear the dict for the next query
             self.AllResetDetection()
         return segments, in_cache
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+
+
+        meta_data = {}
+        audio_sample_list = [data_in]
+        if isinstance(data_in, torch.Tensor):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data[
+                "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+        # b. Forward Encoder streaming
+        t_offset = 0
+        feats = speech
+        feats_len = speech_lengths.max().item()
+        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
+        in_cache = kwargs.get("in_cache", {})
+        batch_size = kwargs.get("batch_size", 1)
+        step = min(feats_len, 6000)
+        segments = [[]] * batch_size
+
+        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
+            if t_offset + step >= feats_len - 1:
+                step = feats_len - t_offset
+                is_final = True
+            else:
+                is_final = False
+            batch = {
+                "feats": feats[:, t_offset:t_offset + step, :],
+                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
+                "is_final": is_final,
+                "in_cache": in_cache
+            }
+
+
+            segments_part, in_cache = self.forward(**batch)
+            if segments_part:
+                for batch_num in range(0, batch_size):
+                    segments[batch_num] += segments_part[batch_num]
+
+        ibest_writer = None
+        if ibest_writer is None and kwargs.get("output_dir") is not None:
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+
+        results = []
+        for i in range(batch_size):
+            
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+                
+            if ibest_writer is not None:
+                ibest_writer["text"][key[i]] = segments[i]
+
+            result_i = {"key": key[i], "value": segments[i]}
+            results.append(result_i)
+ 
+        return results, meta_data
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                        is_final: bool = False, max_end_sil: int = 800
@@ -481,209 +772,5 @@
                 self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
             self.ResetDetection()
 
-
-
-class VadStateMachine(Enum):
-    kVadInStateStartPointNotDetected = 1
-    kVadInStateInSpeechSegment = 2
-    kVadInStateEndPointDetected = 3
-
-
-class FrameState(Enum):
-    kFrameStateInvalid = -1
-    kFrameStateSpeech = 1
-    kFrameStateSil = 0
-
-
-# final voice/unvoice state per frame
-class AudioChangeState(Enum):
-    kChangeStateSpeech2Speech = 0
-    kChangeStateSpeech2Sil = 1
-    kChangeStateSil2Sil = 2
-    kChangeStateSil2Speech = 3
-    kChangeStateNoBegin = 4
-    kChangeStateInvalid = 5
-
-
-class VadDetectMode(Enum):
-    kVadSingleUtteranceDetectMode = 0
-    kVadMutipleUtteranceDetectMode = 1
-
-
-class VADXOptions:
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(
-            self,
-            sample_rate: int = 16000,
-            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
-            snr_mode: int = 0,
-            max_end_silence_time: int = 800,
-            max_start_silence_time: int = 3000,
-            do_start_point_detection: bool = True,
-            do_end_point_detection: bool = True,
-            window_size_ms: int = 200,
-            sil_to_speech_time_thres: int = 150,
-            speech_to_sil_time_thres: int = 150,
-            speech_2_noise_ratio: float = 1.0,
-            do_extend: int = 1,
-            lookback_time_start_point: int = 200,
-            lookahead_time_end_point: int = 100,
-            max_single_segment_time: int = 60000,
-            nn_eval_block_size: int = 8,
-            dcd_block_size: int = 4,
-            snr_thres: int = -100.0,
-            noise_frame_num_used_for_snr: int = 100,
-            decibel_thres: int = -100.0,
-            speech_noise_thres: float = 0.6,
-            fe_prior_thres: float = 1e-4,
-            silence_pdf_num: int = 1,
-            sil_pdf_ids: List[int] = [0],
-            speech_noise_thresh_low: float = -0.1,
-            speech_noise_thresh_high: float = 0.3,
-            output_frame_probs: bool = False,
-            frame_in_ms: int = 10,
-            frame_length_ms: int = 25,
-    ):
-        self.sample_rate = sample_rate
-        self.detect_mode = detect_mode
-        self.snr_mode = snr_mode
-        self.max_end_silence_time = max_end_silence_time
-        self.max_start_silence_time = max_start_silence_time
-        self.do_start_point_detection = do_start_point_detection
-        self.do_end_point_detection = do_end_point_detection
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time_thres = sil_to_speech_time_thres
-        self.speech_to_sil_time_thres = speech_to_sil_time_thres
-        self.speech_2_noise_ratio = speech_2_noise_ratio
-        self.do_extend = do_extend
-        self.lookback_time_start_point = lookback_time_start_point
-        self.lookahead_time_end_point = lookahead_time_end_point
-        self.max_single_segment_time = max_single_segment_time
-        self.nn_eval_block_size = nn_eval_block_size
-        self.dcd_block_size = dcd_block_size
-        self.snr_thres = snr_thres
-        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
-        self.decibel_thres = decibel_thres
-        self.speech_noise_thres = speech_noise_thres
-        self.fe_prior_thres = fe_prior_thres
-        self.silence_pdf_num = silence_pdf_num
-        self.sil_pdf_ids = sil_pdf_ids
-        self.speech_noise_thresh_low = speech_noise_thresh_low
-        self.speech_noise_thresh_high = speech_noise_thresh_high
-        self.output_frame_probs = output_frame_probs
-        self.frame_in_ms = frame_in_ms
-        self.frame_length_ms = frame_length_ms
-
-
-class E2EVadSpeechBufWithDoa(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
-
-    def Reset(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
-
-
-class E2EVadFrameProb(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.noise_prob = 0.0
-        self.speech_prob = 0.0
-        self.score = 0.0
-        self.frame_id = 0
-        self.frm_state = 0
-
-
-class WindowDetector(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
-                 speech_to_sil_time: int, frame_size_ms: int):
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time = sil_to_speech_time
-        self.speech_to_sil_time = speech_to_sil_time
-        self.frame_size_ms = frame_size_ms
-
-        self.win_size_frame = int(window_size_ms / frame_size_ms)
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
-
-        self.cur_win_pos = 0
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
-        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
-
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def Reset(self) -> None:
-        self.cur_win_pos = 0
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def GetWinSize(self) -> int:
-        return int(self.win_size_frame)
-
-    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
-        cur_frame_state = FrameState.kFrameStateSil
-        if frameState == FrameState.kFrameStateSpeech:
-            cur_frame_state = 1
-        elif frameState == FrameState.kFrameStateSil:
-            cur_frame_state = 0
-        else:
-            return AudioChangeState.kChangeStateInvalid
-        self.win_sum -= self.win_state[self.cur_win_pos]
-        self.win_sum += cur_frame_state
-        self.win_state[self.cur_win_pos] = cur_frame_state
-        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
-
-        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSpeech
-            return AudioChangeState.kChangeStateSil2Speech
-
-        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSil
-            return AudioChangeState.kChangeStateSpeech2Sil
-
-        if self.pre_frame_state == FrameState.kFrameStateSil:
-            return AudioChangeState.kChangeStateSil2Sil
-        if self.pre_frame_state == FrameState.kFrameStateSpeech:
-            return AudioChangeState.kChangeStateSpeech2Speech
-        return AudioChangeState.kChangeStateInvalid
-
-    def FrameSizeMs(self) -> int:
-        return int(self.frame_size_ms)
 
 
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index 349ebc0..d43d7b2 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -42,8 +42,9 @@
                 self.token_list_repr = str(token_list)
                 self.token_list: List[str] = []
 
-                with open('data.json', 'r', encoding='utf-8') as f:
-                    self.token_list = json.loads(f.read())
+                with open(token_list, 'r', encoding='utf-8') as f:
+                    self.token_list = json.load(f)
+                    
 
             else:
                 self.token_list: List[str] = list(token_list)
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index b54f777..963d734 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -120,6 +120,7 @@
     if ignore_init_mismatch:
         src_state = filter_state_dict(dst_state, src_state)
 
-    # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
+    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
+    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
     dst_state.update(src_state)
     obj.load_state_dict(dst_state)
diff --git a/setup.py b/setup.py
index a1e47af..ecd3d3d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,14 +10,11 @@
 
 requirements = {
     "install": [
-        # "setuptools>=38.5.1",
-        "humanfriendly",
         "scipy>=1.4.1",
         "librosa",
         "jamo",  # For kss
         "PyYAML>=5.1.2",
         # "soundfile>=0.12.1",
-        # "h5py>=3.1.0",
         "kaldiio>=2.17.0",
         "torch_complex",
         # "nltk>=3.4.5",
@@ -32,7 +29,6 @@
         # ENH
         "pytorch_wpe",
         "editdistance>=0.5.2",
-        "tensorboard",
         # "g2p",
         # "nara_wpe",
         # PAI
@@ -44,6 +40,7 @@
         "hdbscan",
         "umap",
         "jaconv",
+        "hydra-core",
     ],
     # train: The modules invoked when training only.
     "train": [

--
Gitblit v1.9.1