From cdf117b9746fdb72c6d0a2aa1ada4e1a131895ec Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: Tue, 27 Jun 2023 09:59:50 +0800
Subject: [PATCH] bug fix (#667)
---
funasr/fileio/sound_scp.py | 70 +++++++++++++++++++++++
funasr/models/encoder/conformer_encoder.py | 55 ++++++++++++++++-
funasr/modules/subsampling.py | 21 +++---
 egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 4 ++--
4 files changed, 132 insertions(+), 18 deletions(-)
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index 59f9936..a1f27a3 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -6,7 +6,7 @@
 unified_model_training: true
 default_chunk_size: 16
 jitter_range: 4
-left_chunk_size: 0
+left_chunk_size: 1
 embed_vgg_like: false
 subsampling_factor: 4
 linear_units: 2048
@@ -51,7 +51,7 @@
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 200
+max_epoch: 120
 val_scheduler_criterion:
 - valid
 - loss
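
These settings drive the unified streaming/offline training touched below: chunks of default_chunk_size=16 frames are jittered by up to ±jitter_range=4 during training, and left_chunk_size=1 lets each chunk attend to one chunk of left context. As a rough illustration only, the hypothetical chunk_mask below sketches the kind of mask this corresponds to; it is not funasr's make_chunk_mask:

    import torch

    def chunk_mask(seq_len: int, chunk_size: int, left_chunks: int) -> torch.Tensor:
        """True where frame i may attend to frame j."""
        chunk_idx = torch.arange(seq_len) // chunk_size    # chunk id of each frame
        diff = chunk_idx[:, None] - chunk_idx[None, :]     # i's chunk minus j's chunk
        return (diff >= 0) & (diff <= left_chunks)         # current chunk + left context

    mask = chunk_mask(seq_len=64, chunk_size=16, left_chunks=1)  # left_chunk_size: 1
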
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index c752fe6..9b25fe5 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -1,6 +1,6 @@
 import collections.abc
 from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
 import random

 import numpy as np
@@ -13,6 +13,74 @@
 from funasr.fileio.read_text import read_2column_text


+def soundfile_read(
+    wavs: Union[str, List[str]],
+    dtype=None,
+    always_2d: bool = False,
+    concat_axis: int = 1,
+    start: int = 0,
+    end: int = None,
+    return_subtype: bool = False,
+) -> Tuple[np.ndarray, int]:
+    if isinstance(wavs, str):
+        wavs = [wavs]
+
+    arrays = []
+    subtypes = []
+    prev_rate = None
+    prev_wav = None
+    for wav in wavs:
+        with soundfile.SoundFile(wav) as f:
+            f.seek(start)
+            if end is not None:
+                frames = end - start
+            else:
+                frames = -1
+            if dtype == "float16":
+                array = f.read(
+                    frames,
+                    dtype="float32",
+                    always_2d=always_2d,
+                ).astype(dtype)
+            else:
+                array = f.read(frames, dtype=dtype, always_2d=always_2d)
+            rate = f.samplerate
+            subtype = f.subtype
+            subtypes.append(subtype)
+
+        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+            # array: (Time, Channel)
+            array = array[:, None]
+
+        if prev_wav is not None:
+            if prev_rate != rate:
+                raise RuntimeError(
+                    f"'{prev_wav}' and '{wav}' have mismatched sampling rates: "
+                    f"{prev_rate} != {rate}"
+                )
+
+            dim1 = arrays[0].shape[1 - concat_axis]
+            dim2 = array.shape[1 - concat_axis]
+            if dim1 != dim2:
+                raise RuntimeError(
+                    f"Shapes must match along axis {1 - concat_axis}, "
+                    f"but got {dim1} and {dim2}"
+                )
+
+        prev_rate = rate
+        prev_wav = wav
+        arrays.append(array)
+
+    if len(arrays) == 1:
+        array = arrays[0]
+    else:
+        array = np.concatenate(arrays, axis=concat_axis)
+
+    if return_subtype:
+        return array, rate, subtypes
+    else:
+        return array, rate
+
+
 class SoundScpReader(collections.abc.Mapping):
     """Reader class for 'wav.scp'.
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20dee..994607f 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -1081,7 +1081,10 @@
         mask = make_source_mask(x_len).to(x.device)

         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
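
The effect, with the aishell config above (default_chunk_size=16, jitter_range=4): training samples a fresh chunk size per forward pass, while evaluation now pins it to the default, so validation losses stay comparable across epochs. A standalone sketch of the selection:

    import torch

    default_chunk_size, jitter_range = 16, 4   # values from the yaml above
    training = True
    if training:
        chunk_size = default_chunk_size + torch.randint(-jitter_range, jitter_range + 1, (1,)).item()
    else:
        chunk_size = default_chunk_size
    assert 12 <= chunk_size <= 20              # eval would always give 16
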
@@ -1113,12 +1116,15 @@
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
             else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
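
For reference, the distribution this implements during training: a draw above short_chunk_threshold * max_len trains on the full utterance, anything else collapses to a short chunk in [1, short_chunk_size]; evaluation now always uses the configured default. A standalone sketch (the threshold and size values here are illustrative, not funasr defaults):

    import torch

    def draw_chunk_size(max_len: int, training: bool, default: int = 16,
                        threshold: float = 0.75, short: int = 25) -> int:
        if not training:
            return default                   # eval: deterministic chunks
        c = torch.randint(1, max_len, (1,)).item()
        if c > max_len * threshold:
            return max_len                   # ~25% of draws: full utterance
        return (c % short) + 1               # otherwise: short chunk in [1, short]

    sizes = [draw_chunk_size(200, training=True) for _ in range(5)]
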
@@ -1147,6 +1153,45 @@

         return x, olens, None

+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode input sequences in full-utterance (non-streaming) mode.
+
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+            x_utt: Encoder outputs. (B, T_out, D_enc)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:, ::self.time_reduction_factor, :]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,
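
The last step of full_utt_forward drops frames by plain slicing when time_reduction_factor > 1; a self-contained illustration for a factor of 2:

    import torch

    x_utt = torch.randn(1, 50, 256)                # encoder output, (B, T_out, D_enc)
    time_reduction_factor = 2
    x_utt = x_utt[:, ::time_reduction_factor, :]   # keep every 2nd frame
    assert x_utt.shape == (1, 25, 256)
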
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index a2b91a7..77aa422 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -427,6 +427,7 @@
         conv_size: Union[int, Tuple],
         subsampling_factor: int = 4,
         vgg_like: bool = True,
+        conv_kernel_size: int = 3,
         output_size: Optional[int] = None,
     ) -> None:
         """Construct a ConvInput object."""
@@ -436,14 +437,14 @@
             conv_size1, conv_size2 = conv_size

             self.conv = torch.nn.Sequential(
-                torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((1, 2)),
-                torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((1, 2)),
             )
@@ -462,14 +463,14 @@
             kernel_1 = int(subsampling_factor / 2)

             self.conv = torch.nn.Sequential(
-                torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((kernel_1, 2)),
-                torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((2, 2)),
             )
@@ -487,14 +488,14 @@

             self.conv = torch.nn.Sequential(
                 torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [(conv_kernel_size-1)//2, 0]),
                 torch.nn.ReLU(),
             )
             output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)

         self.subsampling_factor = subsampling_factor
-        self.kernel_2 = 3
+        self.kernel_2 = conv_kernel_size
         self.stride_2 = 1
         self.create_new_mask = self.create_new_conv2d_mask
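
A note on the recurring padding=(conv_kernel_size-1)//2 expression: for an odd kernel at stride 1 it is exactly "same" padding (k=3 gives pad 1, k=5 gives pad 2), so subsampling remains controlled solely by the strides and MaxPool2d layers. A quick check:

    import torch

    x = torch.randn(1, 1, 100, 80)   # (batch, channel, time, freq)
    for k in (3, 5, 7):
        conv = torch.nn.Conv2d(1, 8, k, stride=1, padding=(k - 1) // 2)
        assert conv(x).shape[2:] == x.shape[2:]   # time/freq sizes preserved
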
--
Gitblit v1.9.1