游雁
2023-12-21 c8bae0ec85eee25d66de6b1e4502eff74d750b24
funasr2
7个文件已修改
2个文件已添加
2个文件已重命名
1个文件已删除
909 ■■■■■ 已修改文件
examples/industrial_data_pretraining/fsmn-vad/infer.sh 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/inference.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/datasets.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/encoder.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/model.py 212 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/target_delay_transformer.py 130 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad/encoder.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad/model.py 517 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/abs_tokenizer.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/load_pretrained_model.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn-vad/infer.sh
New file
@@ -0,0 +1,8 @@
# Example: run FSMN-VAD inference through funasr/bin/inference.py on one wav file.
# The "+key=value" arguments are handed to inference.py unchanged (presumably
# Hydra-style overrides - confirm against inference.py).
# NOTE: the model, input and output paths below are machine-specific examples;
# adjust them to your environment before running.
cmd="funasr/bin/inference.py"
python "$cmd" \
+model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
+input="/Users/zhifu/Downloads/asr_example.wav" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_vad" \
+device="cpu"
funasr/bin/inference.py
@@ -101,6 +101,7 @@
            tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
            kwargs["tokenizer"] = tokenizer
            kwargs["token_list"] = tokenizer.token_list
        
        # build frontend
        frontend = kwargs.get("frontend", None)
@@ -112,11 +113,9 @@
        
        # build model
        model_class = registry_tables.model_classes.get(kwargs["model"].lower())
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
        model.eval()
        model.to(device)
        kwargs["token_list"] = tokenizer.token_list
        
        # init_param
        init_param = kwargs.get("init_param", None)
funasr/bin/train.py
@@ -145,13 +145,13 @@
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
    if batch_sampler is not None:
    batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                pin_memory=True)
    
    trainer = Trainer(
funasr/datasets/audio_datasets/datasets.py
@@ -24,6 +24,17 @@
        super().__init__()
        index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
        self.index_ds = index_ds_class(path)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text
        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
@@ -49,8 +60,13 @@
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
funasr/models/ct_transformer/encoder.py
funasr/models/ct_transformer/model.py
New file
@@ -0,0 +1,212 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.utils.register import register_class, registry_tables
@register_class("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf

    Token ids are embedded, passed through an encoder looked up by name in
    ``registry_tables.encoder_classes`` and projected by a linear layer to one
    logit per entry of ``punc_list``.
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: dict = None,
        vocab_size: int = -1,
        punc_list: list = None,
        punc_weight: list = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        dropout_rate: float = 0.5,
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
        **kwargs,
    ):
        """Build the embedding, the registered encoder and the output layer.

        Args:
            encoder: Registry name of the encoder class (looked up lower-cased).
            encoder_conf: Keyword arguments forwarded to the encoder class.
            vocab_size: Number of rows of the token embedding.
            punc_list: Punctuation labels; its length is the output size.
            punc_weight: Per-label loss weights; defaults to all ones.
            embed_unit: Embedding dimension.
            att_unit: Input dimension of the output projection.
            dropout_rate: Accepted but not used by this constructor.
            ignore_id: Target id ignored by the training loss.
            sos: Stored as ``self.sos``.
            eos: Stored as ``self.eos``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        punc_size = len(punc_list)
        if punc_weight is None:
            punc_weight = [1] * punc_size

        self.embed = nn.Embedding(vocab_size, embed_unit)
        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
        self.encoder = encoder_class(**encoder_conf)
        self.decoder = nn.Linear(att_unit, punc_size)
        self.punc_list = punc_list
        self.punc_weight = punc_weight
        self.ignore_id = ignore_id
        self.sos = sos
        self.eos = eos

    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for a batch of token ids.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Lengths passed to the encoder together
                with the embedded ids.

        Returns:
            Tuple of the decoder output for every position and ``None``.
        """
        x = self.embed(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        """Return False: this model takes no VAD indexes in ``punc_forward``."""
        return False

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys (unused here).

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 log-probabilities from the output layer for the
                last position, and next state for ys
        """
        # NOTE(review): _target_mask is not defined in this class as shown and
        # the encoder must offer forward_one_step - confirm both are available.
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat);
                unused here.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied log-probabilities from the output layer for the last
                position, and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        # NOTE(review): _target_mask is not defined in this class as shown - confirm.
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)

        Normally, this function is called in batchify_nll.

        In eval mode no loss is computed: the micro-averaged F1 score of the
        top-1 predictions is returned instead, repeated ``text_lengths.sum()``
        times.

        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            punc_lengths: Accepted but not used.
            max_length: Optional truncation length; defaults to the longest text.
            vad_indexes: Required only when ``with_vad()`` is true.
            vad_indexes_lengths: Accepted but not used.

        Returns:
            Tuple of the per-token values described above and ``text_lengths``.
        """
        batch_size = text.size(0)
        # For data parallel
        limit = text_lengths.max() if max_length is None else max_length
        text = text[:, :limit]
        punc = punc[:, :limit]

        if self.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
        else:
            # Should be TargetDelayTransformer,
            y, _ = self.punc_forward(text, text_lengths)

        if not self.training:
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            from sklearn.metrics import f1_score
            # Keep the score under its own name instead of rebinding f1_score.
            micro_f1 = f1_score(punc.view(-1).detach().cpu().numpy(),
                                indices.squeeze(-1).detach().cpu().numpy(),
                                average='micro')
            nll = torch.Tensor([micro_f1]).repeat(text_lengths.sum())
            return nll, text_lengths

        # Calc negative log likelihood
        # nll: (BxL,)
        # punc_weight is a plain list by default, so it has no .to(); build
        # the weight tensor here (an existing tensor is simply moved/cast).
        weight = torch.as_tensor(self.punc_weight, dtype=y.dtype, device=punc.device)
        nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), weight, reduction="none",
                              ignore_index=self.ignore_id)
        # nll: (BxL,) -> (BxL,)
        # NOTE(review): make_pad_mask is not imported in the visible import
        # block of this file - confirm it is in scope.
        if max_length is None:
            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
        else:
            nll.masked_fill_(
                make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                0.0,
            )
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, text_lengths

    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Return ``(loss, stats, weight)`` where loss is ``nll.sum() / ntokens``.

        ``stats`` holds the detached loss and ``weight`` is the token count
        ``text_lengths.sum()``.
        """
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        # NOTE(review): force_gatherable is not imported in the visible import
        # block of this file - confirm it is in scope.
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def generate(self,
                  text: torch.Tensor,
                  text_lengths: torch.Tensor,
                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
        """Return the output of ``punc_forward`` for the given text batch.

        ``vad_indexes`` is only forwarded (and required) when ``with_vad()``
        is true.
        """
        if self.with_vad():
            assert vad_indexes is not None
            return self.punc_forward(text, text_lengths, vad_indexes)
        else:
            return self.punc_forward(text, text_lengths)
funasr/models/ct_transformer/target_delay_transformer.py
File was deleted
funasr/models/fsmn_vad/encoder.py
File was renamed from funasr/models/fsmn_vad/fsmn_encoder.py
@@ -6,6 +6,8 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.utils.register import register_class, registry_tables
class LinearTransform(nn.Module):
    def __init__(self, input_dim, output_dim):
@@ -156,7 +158,7 @@
fsmn_layers:            no. of sequential fsmn layers
'''
@register_class("encoder_classes", "FSMN")
class FSMN(nn.Module):
    def __init__(
            self,
@@ -227,7 +229,7 @@
rstride:                right stride
'''
@register_class("encoder_classes", "DFSMN")
class DFSMN(nn.Module):
    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
funasr/models/fsmn_vad/model.py
@@ -1,33 +1,244 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import logging
import os
import json
import torch
from torch import nn
import math
from typing import Optional
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.base_model import FunASRModel
from funasr.models.model_class_factory import *
import time
from funasr.utils.register import register_class, registry_tables
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
class VadStateMachine(Enum):
    """States of the VAD state machine.

    Only the member names and values are defined here; the transitions
    between them are implemented by code outside this definition.
    """
    kVadInStateStartPointNotDetected = 1  # per its name: no start point detected yet
    kVadInStateInSpeechSegment = 2  # per its name: inside a speech segment
    kVadInStateEndPointDetected = 3  # per its name: end point detected
class FrameState(Enum):
    """Per-frame classification consumed by ``WindowDetector.DetectOneFrame``.

    ``DetectOneFrame`` counts ``kFrameStateSpeech`` as 1 and ``kFrameStateSil``
    as 0 in its sliding window; any other value makes it return
    ``AudioChangeState.kChangeStateInvalid``.
    """
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
    """Transition reported for a frame by ``WindowDetector.DetectOneFrame``."""
    kChangeStateSpeech2Speech = 0  # window state was speech and stays speech
    kChangeStateSpeech2Sil = 1  # window state flips from speech to silence
    kChangeStateSil2Sil = 2  # window state was silence and stays silence
    kChangeStateSil2Speech = 3  # window state flips from silence to speech
    kChangeStateNoBegin = 4  # not produced by DetectOneFrame; usage not visible here
    kChangeStateInvalid = 5  # returned for frame states that are neither speech nor silence
class VadDetectMode(Enum):
    """Detection modes; ``VADXOptions.detect_mode`` stores the ``.value`` of one member."""
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Plain option container for VAD post-processing. Every named argument is
    stored unchanged as an attribute of the same name; no validation is done.

    ``sil_pdf_ids`` defaults to a fresh ``[0]`` per instance (``None`` selects
    that default), so mutating one instance's list never leaks into another.
    Unknown keyword arguments are accepted and ignored, which lets callers pass
    a larger configuration dict straight through.
    """
    def __init__(
            self,
            sample_rate: int = 16000,
            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
            snr_mode: int = 0,
            max_end_silence_time: int = 800,
            max_start_silence_time: int = 3000,
            do_start_point_detection: bool = True,
            do_end_point_detection: bool = True,
            window_size_ms: int = 200,
            sil_to_speech_time_thres: int = 150,
            speech_to_sil_time_thres: int = 150,
            speech_2_noise_ratio: float = 1.0,
            do_extend: int = 1,
            lookback_time_start_point: int = 200,
            lookahead_time_end_point: int = 100,
            max_single_segment_time: int = 60000,
            nn_eval_block_size: int = 8,
            dcd_block_size: int = 4,
            snr_thres: float = -100.0,
            noise_frame_num_used_for_snr: int = 100,
            decibel_thres: float = -100.0,
            speech_noise_thres: float = 0.6,
            fe_prior_thres: float = 1e-4,
            silence_pdf_num: int = 1,
            sil_pdf_ids: Optional[List[int]] = None,
            speech_noise_thresh_low: float = -0.1,
            speech_noise_thresh_high: float = 0.3,
            output_frame_probs: bool = False,
            frame_in_ms: int = 10,
            frame_length_ms: int = 25,
            **kwargs,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        # Build the default list per instance instead of sharing one list
        # object through the function's default argument.
        self.sil_pdf_ids = [0] if sil_pdf_ids is None else sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Mutable record with a start/end time, a buffer list, two segment
    boundary flags and a ``doa`` value. A new instance and a reset instance
    hold the same initial values.
    """
    def __init__(self):
        # Construction and Reset() share the same initial state.
        self.Reset()

    def Reset(self):
        """Put every field back to its initial value, with a new empty buffer."""
        self.start_ms, self.end_ms = 0, 0
        self.buffer = []
        self.contain_seg_start_point = self.contain_seg_end_point = False
        self.doa = 0
class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Record of per-frame values: three float fields (noise probability, speech
    probability, score) and two integer fields (frame id, frame state), all
    starting at zero.
    """
    def __init__(self):
        # float-valued fields
        self.noise_prob = self.speech_prob = self.score = 0.0
        # integer-valued fields
        self.frame_id = self.frm_state = 0
class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Sliding-window smoother over per-frame speech/silence decisions. It keeps
    a ring buffer of 0/1 votes plus their running sum, and switches its
    remembered state when the sum crosses the configured frame-count
    thresholds.
    """
    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        # Raw configuration values are stored as given.
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms
        # Durations expressed as whole frame counts (truncated).
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
        # All remaining fields are the mutable state that Reset() initialises.
        self.Reset()

    def Reset(self) -> None:
        """Clear the window and return to the initial silence state."""
        self.win_state = [0] * self.win_size_frame  # one 0/1 vote per window slot
        self.win_sum = 0
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        # These counters are only initialised here; no method of this class reads them.
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        """Return the window length in frames."""
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
        """Feed one frame decision into the window and report the transition.

        Args:
            frameState: Speech or silence decision for the frame; any other
                value yields ``kChangeStateInvalid`` and leaves the window
                untouched.
            frame_count: Not used by this method.

        Returns:
            The ``AudioChangeState`` describing how the remembered state moved.
        """
        if frameState == FrameState.kFrameStateSpeech:
            vote = 1
        elif frameState == FrameState.kFrameStateSil:
            vote = 0
        else:
            return AudioChangeState.kChangeStateInvalid

        # Replace the oldest vote in the ring buffer and keep the sum in sync.
        slot = self.cur_win_pos
        self.win_sum += vote - self.win_state[slot]
        self.win_state[slot] = vote
        self.cur_win_pos = (slot + 1) % self.win_size_frame

        if self.pre_frame_state == FrameState.kFrameStateSil:
            if self.win_sum >= self.sil_to_speech_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSpeech
                return AudioChangeState.kChangeStateSil2Speech
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            if self.win_sum <= self.speech_to_sil_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSil
                return AudioChangeState.kChangeStateSpeech2Sil
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        """Return the frame size in milliseconds as an int."""
        return int(self.frame_size_ms)
@register_class("model_classes", "FsmnVAD")
class FsmnVAD(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self, encoder: str = None,
    def __init__(self,
                 encoder: str = None,
                 encoder_conf: Optional[Dict] = None,
                 vad_post_args: Dict[str, Any] = None,
                 frontend=None):
                 **kwargs,
                 ):
        super().__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.vad_opts = VADXOptions(**kwargs)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                               self.vad_opts.sil_to_speech_time_thres,
                                               self.vad_opts.speech_to_sil_time_thres,
                                               self.vad_opts.frame_in_ms)
        
        encoder_class = encoder_classes.get_class(encoder)
        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
        # init variables
@@ -57,7 +268,6 @@
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.frontend = frontend
        self.last_drop_frames = 0
    def AllResetDetection(self):
@@ -239,7 +449,7 @@
            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency
    def GetFrameState(self, t: int) -> FrameState:
    def GetFrameState(self, t: int):
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = self.decibel[t]
        cur_snr = cur_decibel - self.noise_average_decibel
@@ -285,7 +495,7 @@
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
                ):
        if not in_cache:
            self.AllResetDetection()
        self.waveform = waveform  # compute decibel for each frame
@@ -312,6 +522,87 @@
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments, in_cache
    def generate(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
                 tokenizer=None,
                 frontend=None,
                 **kwargs,
                 ):
        meta_data = {}
        audio_sample_list = [data_in]
        if isinstance(data_in, torch.Tensor):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data[
                "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
        # b. Forward Encoder streaming
        t_offset = 0
        feats = speech
        feats_len = speech_lengths.max().item()
        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
        in_cache = kwargs.get("in_cache", {})
        batch_size = kwargs.get("batch_size", 1)
        step = min(feats_len, 6000)
        segments = [[]] * batch_size
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
                is_final = True
            else:
                is_final = False
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
                "is_final": is_final,
                "in_cache": in_cache
            }
            segments_part, in_cache = self.forward(**batch)
            if segments_part:
                for batch_num in range(0, batch_size):
                    segments[batch_num] += segments_part[batch_num]
        ibest_writer = None
        if ibest_writer is None and kwargs.get("output_dir") is not None:
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
        results = []
        for i in range(batch_size):
            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
                results[i] = json.dumps(results[i])
            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = segments[i]
            result_i = {"key": key[i], "value": segments[i]}
            results.append(result_i)
        return results, meta_data
    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                       is_final: bool = False, max_end_sil: int = 800
@@ -481,209 +772,5 @@
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection()
class VadStateMachine(Enum):
    """States of the VAD state machine.

    Only the member names and values are defined here; the transitions
    between them are implemented by code outside this definition.
    """
    kVadInStateStartPointNotDetected = 1  # per its name: no start point detected yet
    kVadInStateInSpeechSegment = 2  # per its name: inside a speech segment
    kVadInStateEndPointDetected = 3  # per its name: end point detected
class FrameState(Enum):
    """Per-frame classification consumed by ``WindowDetector.DetectOneFrame``.

    ``DetectOneFrame`` counts ``kFrameStateSpeech`` as 1 and ``kFrameStateSil``
    as 0 in its sliding window; any other value makes it return
    ``AudioChangeState.kChangeStateInvalid``.
    """
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
    """Transition reported for a frame by ``WindowDetector.DetectOneFrame``."""
    kChangeStateSpeech2Speech = 0  # window state was speech and stays speech
    kChangeStateSpeech2Sil = 1  # window state flips from speech to silence
    kChangeStateSil2Sil = 2  # window state was silence and stays silence
    kChangeStateSil2Speech = 3  # window state flips from silence to speech
    kChangeStateNoBegin = 4  # not produced by DetectOneFrame; usage not visible here
    kChangeStateInvalid = 5  # returned for frame states that are neither speech nor silence
class VadDetectMode(Enum):
    """Detection modes; ``VADXOptions.detect_mode`` stores the ``.value`` of one member."""
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Plain option container for VAD post-processing. Every argument is stored
    unchanged as an attribute of the same name; no validation is done.

    ``sil_pdf_ids`` defaults to a fresh ``[0]`` per instance (``None`` selects
    that default), so mutating one instance's list never leaks into another.
    """
    def __init__(
            self,
            sample_rate: int = 16000,
            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
            snr_mode: int = 0,
            max_end_silence_time: int = 800,
            max_start_silence_time: int = 3000,
            do_start_point_detection: bool = True,
            do_end_point_detection: bool = True,
            window_size_ms: int = 200,
            sil_to_speech_time_thres: int = 150,
            speech_to_sil_time_thres: int = 150,
            speech_2_noise_ratio: float = 1.0,
            do_extend: int = 1,
            lookback_time_start_point: int = 200,
            lookahead_time_end_point: int = 100,
            max_single_segment_time: int = 60000,
            nn_eval_block_size: int = 8,
            dcd_block_size: int = 4,
            snr_thres: float = -100.0,
            noise_frame_num_used_for_snr: int = 100,
            decibel_thres: float = -100.0,
            speech_noise_thres: float = 0.6,
            fe_prior_thres: float = 1e-4,
            silence_pdf_num: int = 1,
            sil_pdf_ids: Optional[List[int]] = None,
            speech_noise_thresh_low: float = -0.1,
            speech_noise_thresh_high: float = 0.3,
            output_frame_probs: bool = False,
            frame_in_ms: int = 10,
            frame_length_ms: int = 25,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        # Build the default list per instance instead of sharing one list
        # object through the function's default argument.
        self.sil_pdf_ids = [0] if sil_pdf_ids is None else sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Mutable record with a start/end time, a buffer list, two segment
    boundary flags and a ``doa`` value. A new instance and a reset instance
    hold the same initial values.
    """
    def __init__(self):
        # Construction and Reset() share the same initial state.
        self.Reset()

    def Reset(self):
        """Put every field back to its initial value, with a new empty buffer."""
        self.start_ms, self.end_ms = 0, 0
        self.buffer = []
        self.contain_seg_start_point = self.contain_seg_end_point = False
        self.doa = 0
class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Record of per-frame values: three float fields (noise probability, speech
    probability, score) and two integer fields (frame id, frame state), all
    starting at zero.
    """
    def __init__(self):
        # float-valued fields
        self.noise_prob = self.speech_prob = self.score = 0.0
        # integer-valued fields
        self.frame_id = self.frm_state = 0
class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Sliding-window smoother over per-frame speech/silence decisions. It keeps
    a ring buffer of 0/1 votes plus their running sum, and switches its
    remembered state when the sum crosses the configured frame-count
    thresholds.
    """
    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        # Raw configuration values are stored as given.
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms
        # Durations expressed as whole frame counts (truncated).
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
        # All remaining fields are the mutable state that Reset() initialises.
        self.Reset()

    def Reset(self) -> None:
        """Clear the window and return to the initial silence state."""
        self.win_state = [0] * self.win_size_frame  # one 0/1 vote per window slot
        self.win_sum = 0
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        # These counters are only initialised here; no method of this class reads them.
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        """Return the window length in frames."""
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
        """Feed one frame decision into the window and report the transition.

        Args:
            frameState: Speech or silence decision for the frame; any other
                value yields ``kChangeStateInvalid`` and leaves the window
                untouched.
            frame_count: Not used by this method.

        Returns:
            The ``AudioChangeState`` describing how the remembered state moved.
        """
        if frameState == FrameState.kFrameStateSpeech:
            vote = 1
        elif frameState == FrameState.kFrameStateSil:
            vote = 0
        else:
            return AudioChangeState.kChangeStateInvalid

        # Replace the oldest vote in the ring buffer and keep the sum in sync.
        slot = self.cur_win_pos
        self.win_sum += vote - self.win_state[slot]
        self.win_state[slot] = vote
        self.cur_win_pos = (slot + 1) % self.win_size_frame

        if self.pre_frame_state == FrameState.kFrameStateSil:
            if self.win_sum >= self.sil_to_speech_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSpeech
                return AudioChangeState.kChangeStateSil2Speech
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            if self.win_sum <= self.speech_to_sil_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSil
                return AudioChangeState.kChangeStateSpeech2Sil
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        """Return the frame size in milliseconds as an int."""
        return int(self.frame_size_ms)
funasr/tokenizer/abs_tokenizer.py
@@ -42,8 +42,9 @@
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []
                with open('data.json', 'r', encoding='utf-8') as f:
                    self.token_list = json.loads(f.read())
                with open(token_list, 'r', encoding='utf-8') as f:
                    self.token_list = json.load(f)
            else:
                self.token_list: List[str] = list(token_list)
funasr/train_utils/load_pretrained_model.py
@@ -120,6 +120,7 @@
    if ignore_init_mismatch:
        src_state = filter_state_dict(dst_state, src_state)
    # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)
setup.py
@@ -10,14 +10,11 @@
requirements = {
    "install": [
        # "setuptools>=38.5.1",
        "humanfriendly",
        "scipy>=1.4.1",
        "librosa",
        "jamo",  # For kss
        "PyYAML>=5.1.2",
        # "soundfile>=0.12.1",
        # "h5py>=3.1.0",
        "kaldiio>=2.17.0",
        "torch_complex",
        # "nltk>=3.4.5",
@@ -32,7 +29,6 @@
        # ENH
        "pytorch_wpe",
        "editdistance>=0.5.2",
        "tensorboard",
        # "g2p",
        # "nara_wpe",
        # PAI
@@ -44,6 +40,7 @@
        "hdbscan",
        "umap",
        "jaconv",
        "hydra-core",
    ],
    # train: The modules invoked when training only.
    "train": [