From 96bae0153cb04c82d6e7ca7cb9654d55eb987567 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 15 Mar 2023 17:34:34 +0800
Subject: [PATCH] Fix RNN-T (Transducer) inference bugs
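
Fix several issues in the RNN-T (Transducer) inference path:

- funasr/bin/asr_inference_rnnt.py: add a --cmvn_file option, pass it
  through to ASRTransducerTask.build_model_from_file(), and build a
  WavFrontend from the training config instead of re-deriving STFT
  parameters by hand. The hand-rolled apply_frontend() waveform
  buffering is removed; __call__ now takes pre-extracted features.
- funasr/models_transducer/encoder/blocks/conv_input.py: correct
  min_frame_length (2 -> 7) and simplify get_size_before_subsampling()
  to size * subsampling_factor.
- funasr/tasks/abs_task.py: fall back to a 16 kHz dest_sample_rate when
  frontend_conf is not set.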
---
 funasr/bin/asr_inference_rnnt.py                      | 145 +++------------------
 funasr/models_transducer/encoder/blocks/conv_input.py |   9 ++-------
 funasr/tasks/abs_task.py                              |   2 +-
 3 files changed, 20 insertions(+), 136 deletions(-)
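
Review note: a minimal usage sketch of the patched entry point, assuming
the existing device keyword; the file paths below are placeholders, not
part of this patch:

    from funasr.bin.asr_inference_rnnt import Speech2Text

    speech2text = Speech2Text(
        asr_train_config="exp/rnnt/config.yaml",  # hypothetical path
        asr_model_file="exp/rnnt/model.pb",       # hypothetical path
        cmvn_file="exp/rnnt/am.mvn",              # new argument from this patch
        device="cpu",
    )
    # When asr_train_args.frontend and frontend_conf are set, Speech2Text
    # now builds WavFrontend(cmvn_file=..., **frontend_conf) internally.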
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index f651f11..c8a2916 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -31,7 +31,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2bool, str2triple_str, str_or_none
from funasr.utils.cli_utils import get_commandline_args
-
+from funasr.models.frontend.wav_frontend import WavFrontend

class Speech2Text:
"""Speech2Text class for Transducer models.
@@ -62,6 +62,7 @@
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
beam_search_config: Dict[str, Any] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
@@ -86,10 +87,13 @@
super().__init__()
assert check_argument_types()
-
asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
- asr_train_config, asr_model_file, device
+ asr_train_config, asr_model_file, cmvn_file, device
)
+
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
if quantize_asr_model:
if quantize_modules is not None:
@@ -156,7 +160,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.device = device
@@ -181,23 +185,13 @@
self.simu_streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
- self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
- self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
-
- if asr_train_args.frontend_conf.get("win_length", None) is not None:
- self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
- else:
- self.frontend_window_size = self.n_fft
-
+ self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
- self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
- self.window_size, self.hop_length
- )
+
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
-
#self.last_chunk_length = (
# self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
#) * self.hop_length
@@ -217,112 +211,6 @@
self.beam_search.reset_inference_cache()
self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
- def apply_frontend(
- self, speech: torch.Tensor, is_final: bool = False
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward frontend.
- Args:
- speech: Speech data. (S)
- is_final: Whether speech corresponds to the final (or only) chunk of data.
- Returns:
- feats: Features sequence. (1, T_in, F)
- feats_lengths: Features sequence length. (1, T_in, F)
- """
- if self.frontend_cache is not None:
- speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
-
- if is_final:
- if self.streaming and speech.size(0) < self.last_chunk_length:
- pad = torch.zeros(
- self.last_chunk_length - speech.size(0), dtype=speech.dtype
- )
- speech = torch.cat([speech, pad], dim=0)
-
- speech_to_process = speech
- waveform_buffer = None
- else:
- n_frames = (
- speech.size(0) - (self.frontend_window_size - self.hop_length)
- ) // self.hop_length
-
- n_residual = (
- speech.size(0) - (self.frontend_window_size - self.hop_length)
- ) % self.hop_length
-
- speech_to_process = speech.narrow(
- 0,
- 0,
- (self.frontend_window_size - self.hop_length)
- + n_frames * self.hop_length,
- )
-
- waveform_buffer = speech.narrow(
- 0,
- speech.size(0)
- - (self.frontend_window_size - self.hop_length)
- - n_residual,
- (self.frontend_window_size - self.hop_length) + n_residual,
- ).clone()
-
- speech_to_process = speech_to_process.unsqueeze(0).to(
- getattr(torch, self.dtype)
- )
- lengths = speech_to_process.new_full(
- [1], dtype=torch.long, fill_value=speech_to_process.size(1)
- )
- batch = {"speech": speech_to_process, "speech_lengths": lengths}
- batch = to_device(batch, device=self.device)
-
- feats, feats_lengths = self.asr_model._extract_feats(**batch)
- if self.asr_model.normalize is not None:
- feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
- if is_final:
- if self.frontend_cache is None:
- pass
- else:
- feats = feats.narrow(
- 1,
- math.ceil(
- math.ceil(self.frontend_window_size / self.hop_length) / 2
- ),
- feats.size(1)
- - math.ceil(
- math.ceil(self.frontend_window_size / self.hop_length) / 2
- ),
- )
- else:
- if self.frontend_cache is None:
- feats = feats.narrow(
- 1,
- 0,
- feats.size(1)
- - math.ceil(
- math.ceil(self.frontend_window_size / self.hop_length) / 2
- ),
- )
- else:
- feats = feats.narrow(
- 1,
- math.ceil(
- math.ceil(self.frontend_window_size / self.hop_length) / 2
- ),
- feats.size(1)
- - 2
- * math.ceil(
- math.ceil(self.frontend_window_size / self.hop_length) / 2
- ),
- )
-
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- if is_final:
- self.frontend_cache = None
- else:
- self.frontend_cache = {"waveform_buffer": waveform_buffer}
-
- return feats, feats_lengths
@torch.no_grad()
def streaming_decode(
@@ -410,14 +298,9 @@
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
- # lengths: (1,)
- # feats, feats_length = self.apply_frontend(speech)
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- # lengths: (1,)
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
- # print(feats.shape)
- # print(feats_lengths)
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -495,6 +378,7 @@
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
beam_search_config: Optional[dict],
lm_train_config: Optional[str],
lm_file: Optional[str],
@@ -562,7 +446,6 @@
device = "cuda"
else:
device = "cpu"
-
# 1. Set random-seed
set_all_random_seed(seed)
@@ -570,6 +453,7 @@
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
beam_search_config=beam_search_config,
lm_train_config=lm_train_config,
lm_file=lm_file,
@@ -720,6 +604,11 @@
help="ASR model parameter file",
)
group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global cmvn file",
+ )
+ group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py
index 931d0f0..c68c73b 100644
--- a/funasr/models_transducer/encoder/blocks/conv_input.py
+++ b/funasr/models_transducer/encoder/blocks/conv_input.py
@@ -120,7 +120,7 @@
self.create_new_mask = self.create_new_conv2d_mask
self.vgg_like = vgg_like
- self.min_frame_length = 2
+ self.min_frame_length = 7
if output_size is not None:
self.output = torch.nn.Linear(output_proj, output_size)
@@ -218,9 +218,4 @@
: Number of frames before subsampling.
"""
- if self.subsampling_factor > 1:
- if self.vgg_like:
- return ((size * 2) * self.stride_1) + 1
-
- return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
- return size
+ return size * self.subsampling_factor
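
Review note: get_size_before_subsampling() is now the exact inverse of
the encoder's frame-rate reduction, and min_frame_length = 7 appears
consistent with two kernel-3/stride-2 convolutions, where 7 input frames
are the minimum that yields one output frame. A quick sanity check with
illustrative numbers:

    # frames_before = frames_after * subsampling_factor
    subsampling_factor = 4  # illustrative; typical two-layer Conv2d frontend
    encoder_window = 16     # frames after subsampling (chunk_size + right_context)
    print(encoder_window * subsampling_factor)  # 64 feature frames consumed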
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index e0884ce..cc5b708 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -1576,7 +1576,7 @@
preprocess=iter_options.preprocess_fn,
max_cache_size=iter_options.max_cache_size,
max_cache_fd=iter_options.max_cache_fd,
- dest_sample_rate=args.frontend_conf["fs"],
+ dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000,
)
cls.check_task_requirements(
dataset, args.allow_variable_data_keys, train=iter_options.train
--
Gitblit v1.9.1
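
Follow-up note on the abs_task.py hunk: the iterator no longer raises a
TypeError when frontend_conf is None (e.g. features precomputed
offline). A minimal sketch of the fallback, where 16000 is the default
taken from this patch:

    frontend_conf = None  # no frontend configured
    dest_sample_rate = frontend_conf["fs"] if frontend_conf else 16000
    assert dest_sample_rate == 16000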