From c8bae0ec85eee25d66de6b1e4502eff74d750b24 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Dec 2023 13:29:37 +0800
Subject: [PATCH] funasr2
---
funasr/bin/inference.py | 5
funasr/models/fsmn_vad/encoder.py | 6
funasr/models/ct_transformer/target_delay_transformer.py | 130 --------
funasr/models/fsmn_vad/model.py | 517 ++++++++++++++++++-------------
funasr/bin/train.py | 4
funasr/tokenizer/abs_tokenizer.py | 5
funasr/models/ct_transformer/model.py | 212 +++++++++++++
examples/industrial_data_pretraining/fsmn-vad/infer.sh | 8
setup.py | 5
funasr/models/ct_transformer/encoder.py | 0
funasr/datasets/audio_datasets/datasets.py | 16 +
funasr/train_utils/load_pretrained_model.py | 3
12 files changed, 552 insertions(+), 359 deletions(-)
diff --git a/examples/industrial_data_pretraining/fsmn-vad/infer.sh b/examples/industrial_data_pretraining/fsmn-vad/infer.sh
new file mode 100644
index 0000000..9bfd8ba
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn-vad/infer.sh
@@ -0,0 +1,8 @@
+
+cmd="funasr/bin/inference.py"
+
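+# each "+key=value" argument below is a hydra-style override that adds the key
+# to the config (this patch adds hydra-core to setup.py's install requirements)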
+python $cmd \
++model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_vad" \
++device="cpu" \
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index fd884cd..50ea4d4 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -101,6 +101,7 @@
tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
+ kwargs["token_list"] = tokenizer.token_list
# build frontend
frontend = kwargs.get("frontend", None)
@@ -112,11 +113,9 @@
# build model
model_class = registry_tables.model_classes.get(kwargs["model"].lower())
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
model.eval()
model.to(device)
-
- kwargs["token_list"] = tokenizer.token_list
# init_param
init_param = kwargs.get("init_param", None)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8112002..1e06c50 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -145,7 +145,8 @@
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
- batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+ if batch_sampler is not None:
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@
pin_memory=True)
-
trainer = Trainer(
model=model,
optim=optim,
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 353a3a0..d69d0b5 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -24,6 +24,17 @@
super().__init__()
index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
self.index_ds = index_ds_class(path)
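+        # optional registry-configured preprocessors; their kwargs come from
+        # the matching "preprocessor_*_conf" entries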
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+ preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+ self.preprocessor_speech = preprocessor_speech
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+ self.preprocessor_text = preprocessor_text
+
self.frontend = frontend
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
@@ -49,8 +60,13 @@
# pdb.set_trace()
source = item["source"]
data_src = load_audio(source, fs=self.fs)
+ if self.preprocessor_speech:
+ data_src = self.preprocessor_speech(data_src)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+
target = item["target"]
+ if self.preprocessor_text:
+ target = self.preprocessor_text(target)
ids = self.tokenizer.encode(target)
ids_lengths = len(ids)
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
diff --git a/funasr/models/ct_transformer/sanm_encoder.py b/funasr/models/ct_transformer/encoder.py
similarity index 100%
rename from funasr/models/ct_transformer/sanm_encoder.py
rename to funasr/models/ct_transformer/encoder.py
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
new file mode 100644
index 0000000..31b2af2
--- /dev/null
+++ b/funasr/models/ct_transformer/model.py
@@ -0,0 +1,212 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# NOTE: the two imports below are assumed locations within funasr for the
+# utilities used by nll()/forward(); adjust if they live elsewhere.
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("model_classes", "CTTransformer")
+class CTTransformer(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
+ def __init__(
+ self,
+ encoder: str = None,
+        encoder_conf: dict = None,
+ vocab_size: int = -1,
+ punc_list: list = None,
+ punc_weight: list = None,
+ embed_unit: int = 128,
+ att_unit: int = 256,
+ dropout_rate: float = 0.5,
+ ignore_id: int = -1,
+ sos: int = 1,
+ eos: int = 2,
+ **kwargs,
+ ):
+ super().__init__()
+
+ punc_size = len(punc_list)
+ if punc_weight is None:
+ punc_weight = [1] * punc_size
+
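+        # pipeline: token ids -> embedding -> encoder -> linear head over punctuation classes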
+ self.embed = nn.Embedding(vocab_size, embed_unit)
+ encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+ encoder = encoder_class(**encoder_conf)
+
+ self.decoder = nn.Linear(att_unit, punc_size)
+ self.encoder = encoder
+ self.punc_list = punc_list
+        self.punc_weight = torch.tensor(punc_weight, dtype=torch.float32)
+ self.ignore_id = ignore_id
+ self.sos = sos
+ self.eos = eos
+
+ def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Predict punctuation logits for a padded token sequence.
+
+        Args:
+            input (torch.Tensor): Input token ids. (batch, len)
+            text_lengths (torch.Tensor): Valid sequence lengths. (batch,)
+
+        """
+ x = self.embed(input)
+ # mask = self._target_mask(input)
+ h, _, _ = self.encoder(x, text_lengths)
+ y = self.decoder(h)
+ return y, None
+
+ def with_vad(self):
+ return False
+
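+    # score()/batch_score() call self._target_mask(), which this refactor did
+    # not carry over (the old TargetDelayTransformer kept it commented out).
+    # A minimal sketch is restored here, assuming the CT-Transformer mask lets
+    # each position attend to its past plus up to `n` future tokens.
+    def _target_mask(self, ys_in_pad: torch.Tensor, n: int = 5) -> torch.Tensor:
+        ys_mask = ys_in_pad != 0
+        size = ys_mask.size(-1)
+        # band mask: position i may attend to positions j <= i + n
+        m = torch.tril(torch.ones(size, size, dtype=torch.bool, device=ys_mask.device), diagonal=n).unsqueeze(0)
+        return ys_mask.unsqueeze(-2) & m
+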
+ def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
+ """Score new token.
+
+ Args:
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
+ state: Scorer state for prefix tokens
+ x (torch.Tensor): encoder feature that generates ys.
+
+ Returns:
+ tuple[torch.Tensor, Any]: Tuple of
+ torch.float32 scores for next token (vocab_size)
+ and next state for ys
+
+ """
+ y = y.unsqueeze(0)
+ h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+ h = self.decoder(h[:, -1])
+ logp = h.log_softmax(dim=-1).squeeze(0)
+ return logp, cache
+
+ def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+ """Score new token batch.
+
+ Args:
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+ states (List[Any]): Scorer states for prefix tokens.
+ xs (torch.Tensor):
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+ Returns:
+ tuple[torch.Tensor, List[Any]]: Tuple of
+ batchfied scores for next token with shape of `(n_batch, vocab_size)`
+ and next state list for ys.
+
+ """
+ # merge states
+ n_batch = len(ys)
+ n_layers = len(self.encoder.encoders)
+ if states[0] is None:
+ batch_state = None
+ else:
+ # transpose state of [batch, layer] into [layer, batch]
+ batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+
+ # batch decoding
+ h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+ h = self.decoder(h[:, -1])
+ logp = h.log_softmax(dim=-1)
+
+ # transpose state of [layer, batch] into [batch, layer]
+ state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+ return logp, state_list
+
+ def nll(
+ self,
+ text: torch.Tensor,
+ punc: torch.Tensor,
+ text_lengths: torch.Tensor,
+ punc_lengths: torch.Tensor,
+ max_length: Optional[int] = None,
+ vad_indexes: Optional[torch.Tensor] = None,
+ vad_indexes_lengths: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute negative log likelihood(nll)
+
+ Normally, this function is called in batchify_nll.
+ Args:
+ text: (Batch, Length)
+ punc: (Batch, Length)
+ text_lengths: (Batch,)
+            max_length: int
+ """
+ batch_size = text.size(0)
+ # For data parallel
+ if max_length is None:
+ text = text[:, :text_lengths.max()]
+ punc = punc[:, :text_lengths.max()]
+ else:
+ text = text[:, :max_length]
+ punc = punc[:, :max_length]
+
+ if self.with_vad():
+ # Should be VadRealtimeTransformer
+ assert vad_indexes is not None
+ y, _ = self.punc_forward(text, text_lengths, vad_indexes)
+ else:
+ # Should be TargetDelayTransformer,
+ y, _ = self.punc_forward(text, text_lengths)
+
+ # Calc negative log likelihood
+ # nll: (BxL,)
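+        # at eval time, nll() is repurposed to report micro-F1 of the argmax
+        # punctuation predictions, broadcast to one value per token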
+        if not self.training:
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            from sklearn.metrics import f1_score
+            f1 = f1_score(punc.view(-1).detach().cpu().numpy(),
+                          indices.squeeze(-1).detach().cpu().numpy(),
+                          average='micro')
+            nll = torch.Tensor([f1]).repeat(text_lengths.sum())
+ return nll, text_lengths
+ else:
+ self.punc_weight = self.punc_weight.to(punc.device)
+ nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+ ignore_index=self.ignore_id)
+ # nll: (BxL,) -> (BxL,)
+ if max_length is None:
+ nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
+ else:
+ nll.masked_fill_(
+ make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
+ 0.0,
+ )
+ # nll: (BxL,) -> (B, L)
+ nll = nll.view(batch_size, -1)
+ return nll, text_lengths
+
+
+ def forward(
+ self,
+ text: torch.Tensor,
+ punc: torch.Tensor,
+ text_lengths: torch.Tensor,
+ punc_lengths: torch.Tensor,
+ vad_indexes: Optional[torch.Tensor] = None,
+ vad_indexes_lengths: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
+ ntokens = y_lengths.sum()
+ loss = nll.sum() / ntokens
+ stats = dict(loss=loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
+ return loss, stats, weight
+
+ def generate(self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
+ if self.with_vad():
+ assert vad_indexes is not None
+ return self.punc_forward(text, text_lengths, vad_indexes)
+ else:
+ return self.punc_forward(text, text_lengths)
\ No newline at end of file
diff --git a/funasr/models/ct_transformer/target_delay_transformer.py b/funasr/models/ct_transformer/target_delay_transformer.py
deleted file mode 100644
index 59884a3..0000000
--- a/funasr/models/ct_transformer/target_delay_transformer.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.sanm.encoder import SANMEncoder as Encoder
-
-
-class TargetDelayTransformer(torch.nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
- https://arxiv.org/pdf/2003.01309.pdf
- """
- def __init__(
- self,
- vocab_size: int,
- punc_size: int,
- pos_enc: str = None,
- embed_unit: int = 128,
- att_unit: int = 256,
- head: int = 2,
- unit: int = 1024,
- layer: int = 4,
- dropout_rate: float = 0.5,
- ):
- super().__init__()
- if pos_enc == "sinusoidal":
- # pos_enc_class = PositionalEncoding
- pos_enc_class = SinusoidalPositionEncoder
- elif pos_enc is None:
-
- def pos_enc_class(*args, **kwargs):
- return nn.Sequential() # indentity
-
- else:
- raise ValueError(f"unknown pos-enc option: {pos_enc}")
-
- self.embed = nn.Embedding(vocab_size, embed_unit)
- self.encoder = Encoder(
- input_size=embed_unit,
- output_size=att_unit,
- attention_heads=head,
- linear_units=unit,
- num_blocks=layer,
- dropout_rate=dropout_rate,
- input_layer="pe",
- # pos_enc_class=pos_enc_class,
- padding_idx=0,
- )
- self.decoder = nn.Linear(att_unit, punc_size)
-
-
-# def _target_mask(self, ys_in_pad):
-# ys_mask = ys_in_pad != 0
-# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
-# return ys_mask.unsqueeze(-2) & m
-
- def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
- """Compute loss value from buffer sequences.
-
- Args:
- input (torch.Tensor): Input ids. (batch, len)
- hidden (torch.Tensor): Target ids. (batch, len)
-
- """
- x = self.embed(input)
- # mask = self._target_mask(input)
- h, _, _ = self.encoder(x, text_lengths)
- y = self.decoder(h)
- return y, None
-
- def with_vad(self):
- return False
-
- def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
- """Score new token.
-
- Args:
- y (torch.Tensor): 1D torch.int64 prefix tokens.
- state: Scorer state for prefix tokens
- x (torch.Tensor): encoder feature that generates ys.
-
- Returns:
- tuple[torch.Tensor, Any]: Tuple of
- torch.float32 scores for next token (vocab_size)
- and next state for ys
-
- """
- y = y.unsqueeze(0)
- h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
- h = self.decoder(h[:, -1])
- logp = h.log_softmax(dim=-1).squeeze(0)
- return logp, cache
-
- def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
- """Score new token batch.
-
- Args:
- ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
- states (List[Any]): Scorer states for prefix tokens.
- xs (torch.Tensor):
- The encoder feature that generates ys (n_batch, xlen, n_feat).
-
- Returns:
- tuple[torch.Tensor, List[Any]]: Tuple of
- batchfied scores for next token with shape of `(n_batch, vocab_size)`
- and next state list for ys.
-
- """
- # merge states
- n_batch = len(ys)
- n_layers = len(self.encoder.encoders)
- if states[0] is None:
- batch_state = None
- else:
- # transpose state of [batch, layer] into [layer, batch]
- batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
-
- # batch decoding
- h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
- h = self.decoder(h[:, -1])
- logp = h.log_softmax(dim=-1)
-
- # transpose state of [layer, batch] into [batch, layer]
- state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
- return logp, state_list
diff --git a/funasr/models/fsmn_vad/fsmn_encoder.py b/funasr/models/fsmn_vad/encoder.py
similarity index 98%
rename from funasr/models/fsmn_vad/fsmn_encoder.py
rename to funasr/models/fsmn_vad/encoder.py
index 38d164d..50e31fc 100755
--- a/funasr/models/fsmn_vad/fsmn_encoder.py
+++ b/funasr/models/fsmn_vad/encoder.py
@@ -6,6 +6,8 @@
import torch.nn as nn
import torch.nn.functional as F
+from funasr.utils.register import register_class, registry_tables
+
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
@@ -156,7 +158,7 @@
fsmn_layers: no. of sequential fsmn layers
'''
-
+@register_class("encoder_classes", "FSMN")
class FSMN(nn.Module):
def __init__(
self,
@@ -227,7 +229,7 @@
rstride: right stride
'''
-
+@register_class("encoder_classes", "DFSMN")
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index cc3c87e..16f21dc 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -1,33 +1,244 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
-
+import logging
+import os
+import json
import torch
from torch import nn
import math
from typing import Optional
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.models.base_model import FunASRModel
-from funasr.models.model_class_factory import *
+import time
+from funasr.utils.register import register_class, registry_tables
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from torch.nn.utils.rnn import pad_sequence
+
+class VadStateMachine(Enum):
+ kVadInStateStartPointNotDetected = 1
+ kVadInStateInSpeechSegment = 2
+ kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+ kFrameStateInvalid = -1
+ kFrameStateSpeech = 1
+ kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+ kChangeStateSpeech2Speech = 0
+ kChangeStateSpeech2Sil = 1
+ kChangeStateSil2Sil = 2
+ kChangeStateSil2Speech = 3
+ kChangeStateNoBegin = 4
+ kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+ kVadSingleUtteranceDetectMode = 0
+ kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
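+    # options named *_time / *_ms are in milliseconds; frame_in_ms is the
+    # frame shift and frame_length_ms the analysis window length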
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+ snr_mode: int = 0,
+ max_end_silence_time: int = 800,
+ max_start_silence_time: int = 3000,
+ do_start_point_detection: bool = True,
+ do_end_point_detection: bool = True,
+ window_size_ms: int = 200,
+ sil_to_speech_time_thres: int = 150,
+ speech_to_sil_time_thres: int = 150,
+ speech_2_noise_ratio: float = 1.0,
+ do_extend: int = 1,
+ lookback_time_start_point: int = 200,
+ lookahead_time_end_point: int = 100,
+ max_single_segment_time: int = 60000,
+ nn_eval_block_size: int = 8,
+ dcd_block_size: int = 4,
+        snr_thres: float = -100.0,
+        noise_frame_num_used_for_snr: int = 100,
+        decibel_thres: float = -100.0,
+ speech_noise_thres: float = 0.6,
+ fe_prior_thres: float = 1e-4,
+ silence_pdf_num: int = 1,
+        sil_pdf_ids: List[int] = None,
+ speech_noise_thresh_low: float = -0.1,
+ speech_noise_thresh_high: float = 0.3,
+ output_frame_probs: bool = False,
+ frame_in_ms: int = 10,
+ frame_length_ms: int = 25,
+ **kwargs,
+ ):
+ self.sample_rate = sample_rate
+ self.detect_mode = detect_mode
+ self.snr_mode = snr_mode
+ self.max_end_silence_time = max_end_silence_time
+ self.max_start_silence_time = max_start_silence_time
+ self.do_start_point_detection = do_start_point_detection
+ self.do_end_point_detection = do_end_point_detection
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time_thres = sil_to_speech_time_thres
+ self.speech_to_sil_time_thres = speech_to_sil_time_thres
+ self.speech_2_noise_ratio = speech_2_noise_ratio
+ self.do_extend = do_extend
+ self.lookback_time_start_point = lookback_time_start_point
+ self.lookahead_time_end_point = lookahead_time_end_point
+ self.max_single_segment_time = max_single_segment_time
+ self.nn_eval_block_size = nn_eval_block_size
+ self.dcd_block_size = dcd_block_size
+ self.snr_thres = snr_thres
+ self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+ self.decibel_thres = decibel_thres
+ self.speech_noise_thres = speech_noise_thres
+ self.fe_prior_thres = fe_prior_thres
+ self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids if sil_pdf_ids is not None else [0]
+ self.speech_noise_thresh_low = speech_noise_thresh_low
+ self.speech_noise_thresh_high = speech_noise_thresh_high
+ self.output_frame_probs = output_frame_probs
+ self.frame_in_ms = frame_in_ms
+ self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+ def Reset(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+
+class E2EVadFrameProb(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self):
+ self.noise_prob = 0.0
+ self.speech_prob = 0.0
+ self.score = 0.0
+ self.frame_id = 0
+ self.frm_state = 0
+
+
+class WindowDetector(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+ speech_to_sil_time: int, frame_size_ms: int):
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time = sil_to_speech_time
+ self.speech_to_sil_time = speech_to_sil_time
+ self.frame_size_ms = frame_size_ms
+
+ self.win_size_frame = int(window_size_ms / frame_size_ms)
+ self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # initialize the window state buffer
+
+ self.cur_win_pos = 0
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+ self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def Reset(self) -> None:
+ self.cur_win_pos = 0
+ self.win_sum = 0
+ self.win_state = [0] * self.win_size_frame
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def GetWinSize(self) -> int:
+ return int(self.win_size_frame)
+
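+    # DetectOneFrame keeps a majority vote over the last win_size_frame frame
+    # states in a ring buffer: silence flips to speech once win_sum reaches
+    # sil_to_speech_frmcnt_thres, and speech flips back to silence once it
+    # falls to speech_to_sil_frmcnt_thres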
+ def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+        cur_frame_state = 0
+ if frameState == FrameState.kFrameStateSpeech:
+ cur_frame_state = 1
+ elif frameState == FrameState.kFrameStateSil:
+ cur_frame_state = 0
+ else:
+ return AudioChangeState.kChangeStateInvalid
+ self.win_sum -= self.win_state[self.cur_win_pos]
+ self.win_sum += cur_frame_state
+ self.win_state[self.cur_win_pos] = cur_frame_state
+ self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+ if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSpeech
+ return AudioChangeState.kChangeStateSil2Speech
+
+ if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSil
+ return AudioChangeState.kChangeStateSpeech2Sil
+
+ if self.pre_frame_state == FrameState.kFrameStateSil:
+ return AudioChangeState.kChangeStateSil2Sil
+ if self.pre_frame_state == FrameState.kFrameStateSpeech:
+ return AudioChangeState.kChangeStateSpeech2Speech
+ return AudioChangeState.kChangeStateInvalid
+
+ def FrameSizeMs(self) -> int:
+ return int(self.frame_size_ms)
+
+
+@register_class("model_classes", "FsmnVAD")
class FsmnVAD(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
- def __init__(self, encoder: str = None,
+ def __init__(self,
+ encoder: str = None,
encoder_conf: Optional[Dict] = None,
vad_post_args: Dict[str, Any] = None,
- frontend=None):
+ **kwargs,
+ ):
super().__init__()
- self.vad_opts = VADXOptions(**vad_post_args)
+ self.vad_opts = VADXOptions(**kwargs)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
- encoder_class = encoder_classes.get_class(encoder)
+ encoder_class = registry_tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
# init variables
@@ -57,7 +268,6 @@
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.frontend = frontend
self.last_drop_frames = 0
def AllResetDetection(self):
@@ -239,7 +449,7 @@
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
- def GetFrameState(self, t: int) -> FrameState:
+ def GetFrameState(self, t: int):
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
@@ -285,7 +495,7 @@
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
- ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+ ):
if not in_cache:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
@@ -312,6 +522,87 @@
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments, in_cache
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ meta_data = {}
+ audio_sample_list = [data_in]
+ if isinstance(data_in, torch.Tensor): # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # b. Forward Encoder streaming
+ t_offset = 0
+ feats = speech
+ feats_len = speech_lengths.max().item()
+ waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
+ in_cache = kwargs.get("in_cache", {})
+ batch_size = kwargs.get("batch_size", 1)
+ step = min(feats_len, 6000)
+        segments = [[] for _ in range(batch_size)]  # one independent list per batch item
+
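+        # stream the features in chunks of at most `step` frames; at 16 kHz a
+        # 10 ms frame shift is 160 samples and a 25 ms window is 400 samples,
+        # which is what the waveform slice below accounts for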
+        for t_offset in range(0, feats_len, step):
+ if t_offset + step >= feats_len - 1:
+ step = feats_len - t_offset
+ is_final = True
+ else:
+ is_final = False
+ batch = {
+ "feats": feats[:, t_offset:t_offset + step, :],
+ "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
+ "is_final": is_final,
+ "in_cache": in_cache
+ }
+
+ segments_part, in_cache = self.forward(**batch)
+ if segments_part:
+ for batch_num in range(0, batch_size):
+ segments[batch_num] += segments_part[batch_num]
+
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+
+        results = []
+        for i in range(batch_size):
+            result_i = {"key": key[i], "value": segments[i]}
+            results.append(result_i)
+
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+
+            if ibest_writer is not None:
+                ibest_writer["text"][key[i]] = segments[i]
+
+ return results, meta_data
def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False, max_end_sil: int = 800
@@ -481,209 +772,5 @@
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()
-
-
-class VadStateMachine(Enum):
- kVadInStateStartPointNotDetected = 1
- kVadInStateInSpeechSegment = 2
- kVadInStateEndPointDetected = 3
-
-
-class FrameState(Enum):
- kFrameStateInvalid = -1
- kFrameStateSpeech = 1
- kFrameStateSil = 0
-
-
-# final voice/unvoice state per frame
-class AudioChangeState(Enum):
- kChangeStateSpeech2Speech = 0
- kChangeStateSpeech2Sil = 1
- kChangeStateSil2Sil = 2
- kChangeStateSil2Speech = 3
- kChangeStateNoBegin = 4
- kChangeStateInvalid = 5
-
-
-class VadDetectMode(Enum):
- kVadSingleUtteranceDetectMode = 0
- kVadMutipleUtteranceDetectMode = 1
-
-
-class VADXOptions:
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Deep-FSMN for Large Vocabulary Continuous Speech Recognition
- https://arxiv.org/abs/1803.05030
- """
- def __init__(
- self,
- sample_rate: int = 16000,
- detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
- snr_mode: int = 0,
- max_end_silence_time: int = 800,
- max_start_silence_time: int = 3000,
- do_start_point_detection: bool = True,
- do_end_point_detection: bool = True,
- window_size_ms: int = 200,
- sil_to_speech_time_thres: int = 150,
- speech_to_sil_time_thres: int = 150,
- speech_2_noise_ratio: float = 1.0,
- do_extend: int = 1,
- lookback_time_start_point: int = 200,
- lookahead_time_end_point: int = 100,
- max_single_segment_time: int = 60000,
- nn_eval_block_size: int = 8,
- dcd_block_size: int = 4,
- snr_thres: int = -100.0,
- noise_frame_num_used_for_snr: int = 100,
- decibel_thres: int = -100.0,
- speech_noise_thres: float = 0.6,
- fe_prior_thres: float = 1e-4,
- silence_pdf_num: int = 1,
- sil_pdf_ids: List[int] = [0],
- speech_noise_thresh_low: float = -0.1,
- speech_noise_thresh_high: float = 0.3,
- output_frame_probs: bool = False,
- frame_in_ms: int = 10,
- frame_length_ms: int = 25,
- ):
- self.sample_rate = sample_rate
- self.detect_mode = detect_mode
- self.snr_mode = snr_mode
- self.max_end_silence_time = max_end_silence_time
- self.max_start_silence_time = max_start_silence_time
- self.do_start_point_detection = do_start_point_detection
- self.do_end_point_detection = do_end_point_detection
- self.window_size_ms = window_size_ms
- self.sil_to_speech_time_thres = sil_to_speech_time_thres
- self.speech_to_sil_time_thres = speech_to_sil_time_thres
- self.speech_2_noise_ratio = speech_2_noise_ratio
- self.do_extend = do_extend
- self.lookback_time_start_point = lookback_time_start_point
- self.lookahead_time_end_point = lookahead_time_end_point
- self.max_single_segment_time = max_single_segment_time
- self.nn_eval_block_size = nn_eval_block_size
- self.dcd_block_size = dcd_block_size
- self.snr_thres = snr_thres
- self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
- self.decibel_thres = decibel_thres
- self.speech_noise_thres = speech_noise_thres
- self.fe_prior_thres = fe_prior_thres
- self.silence_pdf_num = silence_pdf_num
- self.sil_pdf_ids = sil_pdf_ids
- self.speech_noise_thresh_low = speech_noise_thresh_low
- self.speech_noise_thresh_high = speech_noise_thresh_high
- self.output_frame_probs = output_frame_probs
- self.frame_in_ms = frame_in_ms
- self.frame_length_ms = frame_length_ms
-
-
-class E2EVadSpeechBufWithDoa(object):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Deep-FSMN for Large Vocabulary Continuous Speech Recognition
- https://arxiv.org/abs/1803.05030
- """
- def __init__(self):
- self.start_ms = 0
- self.end_ms = 0
- self.buffer = []
- self.contain_seg_start_point = False
- self.contain_seg_end_point = False
- self.doa = 0
-
- def Reset(self):
- self.start_ms = 0
- self.end_ms = 0
- self.buffer = []
- self.contain_seg_start_point = False
- self.contain_seg_end_point = False
- self.doa = 0
-
-
-class E2EVadFrameProb(object):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Deep-FSMN for Large Vocabulary Continuous Speech Recognition
- https://arxiv.org/abs/1803.05030
- """
- def __init__(self):
- self.noise_prob = 0.0
- self.speech_prob = 0.0
- self.score = 0.0
- self.frame_id = 0
- self.frm_state = 0
-
-
-class WindowDetector(object):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Deep-FSMN for Large Vocabulary Continuous Speech Recognition
- https://arxiv.org/abs/1803.05030
- """
- def __init__(self, window_size_ms: int, sil_to_speech_time: int,
- speech_to_sil_time: int, frame_size_ms: int):
- self.window_size_ms = window_size_ms
- self.sil_to_speech_time = sil_to_speech_time
- self.speech_to_sil_time = speech_to_sil_time
- self.frame_size_ms = frame_size_ms
-
- self.win_size_frame = int(window_size_ms / frame_size_ms)
- self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame  # initialize the window state buffer
-
- self.cur_win_pos = 0
- self.pre_frame_state = FrameState.kFrameStateSil
- self.cur_frame_state = FrameState.kFrameStateSil
- self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
- self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
-
- self.voice_last_frame_count = 0
- self.noise_last_frame_count = 0
- self.hydre_frame_count = 0
-
- def Reset(self) -> None:
- self.cur_win_pos = 0
- self.win_sum = 0
- self.win_state = [0] * self.win_size_frame
- self.pre_frame_state = FrameState.kFrameStateSil
- self.cur_frame_state = FrameState.kFrameStateSil
- self.voice_last_frame_count = 0
- self.noise_last_frame_count = 0
- self.hydre_frame_count = 0
-
- def GetWinSize(self) -> int:
- return int(self.win_size_frame)
-
- def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
- cur_frame_state = FrameState.kFrameStateSil
- if frameState == FrameState.kFrameStateSpeech:
- cur_frame_state = 1
- elif frameState == FrameState.kFrameStateSil:
- cur_frame_state = 0
- else:
- return AudioChangeState.kChangeStateInvalid
- self.win_sum -= self.win_state[self.cur_win_pos]
- self.win_sum += cur_frame_state
- self.win_state[self.cur_win_pos] = cur_frame_state
- self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
-
- if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
- self.pre_frame_state = FrameState.kFrameStateSpeech
- return AudioChangeState.kChangeStateSil2Speech
-
- if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
- self.pre_frame_state = FrameState.kFrameStateSil
- return AudioChangeState.kChangeStateSpeech2Sil
-
- if self.pre_frame_state == FrameState.kFrameStateSil:
- return AudioChangeState.kChangeStateSil2Sil
- if self.pre_frame_state == FrameState.kFrameStateSpeech:
- return AudioChangeState.kChangeStateSpeech2Speech
- return AudioChangeState.kChangeStateInvalid
-
- def FrameSizeMs(self) -> int:
- return int(self.frame_size_ms)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index 349ebc0..d43d7b2 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -42,8 +42,9 @@
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
- with open('data.json', 'r', encoding='utf-8') as f:
- self.token_list = json.loads(f.read())
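+        # token_list is a filesystem path here: load the vocabulary from the JSON file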
+ with open(token_list, 'r', encoding='utf-8') as f:
+ self.token_list = json.load(f)
+
else:
self.token_list: List[str] = list(token_list)
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index b54f777..963d734 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -120,6 +120,7 @@
if ignore_init_mismatch:
src_state = filter_state_dict(dst_state, src_state)
- # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
+ logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
+ logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
dst_state.update(src_state)
obj.load_state_dict(dst_state)
diff --git a/setup.py b/setup.py
index a1e47af..ecd3d3d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,14 +10,11 @@
requirements = {
"install": [
- # "setuptools>=38.5.1",
- "humanfriendly",
"scipy>=1.4.1",
"librosa",
"jamo", # For kss
"PyYAML>=5.1.2",
# "soundfile>=0.12.1",
- # "h5py>=3.1.0",
"kaldiio>=2.17.0",
"torch_complex",
# "nltk>=3.4.5",
@@ -32,7 +29,6 @@
# ENH
"pytorch_wpe",
"editdistance>=0.5.2",
- "tensorboard",
# "g2p",
# "nara_wpe",
# PAI
@@ -44,6 +40,7 @@
"hdbscan",
"umap",
"jaconv",
+ "hydra-core",
],
# train: The modules invoked when training only.
"train": [
--
Gitblit v1.9.1