From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/eend/utils/feature.py | 120 ++++++++++++++++++++++++------------------------------------
1 files changed, 48 insertions(+), 72 deletions(-)
diff --git a/funasr/models/eend/utils/feature.py b/funasr/models/eend/utils/feature.py
index 544a352..1dd55dc 100644
--- a/funasr/models/eend/utils/feature.py
+++ b/funasr/models/eend/utils/feature.py
@@ -8,13 +8,13 @@
def get_input_dim(
- frame_size,
- context_size,
- transform_type,
+ frame_size,
+ context_size,
+ transform_type,
):
- if transform_type.startswith('logmel23'):
+ if transform_type.startswith("logmel23"):
frame_size = 23
- elif transform_type.startswith('logmel'):
+ elif transform_type.startswith("logmel"):
frame_size = 40
else:
fft_size = 1 << (frame_size - 1).bit_length()
@@ -23,11 +23,8 @@
return input_dim
-def transform(
- Y,
- transform_type=None,
- dtype=np.float32):
- """ Transform STFT feature
+def transform(Y, transform_type=None, dtype=np.float32):
+ """Transform STFT feature
Args:
Y: STFT
@@ -42,37 +39,37 @@
Y = np.abs(Y)
if not transform_type:
pass
- elif transform_type == 'log':
+ elif transform_type == "log":
Y = np.log(np.maximum(Y, 1e-10))
- elif transform_type == 'logmel':
+ elif transform_type == "logmel":
n_fft = 2 * (Y.shape[1] - 1)
sr = 16000
n_mels = 40
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
- Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.dot(Y**2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
- elif transform_type == 'logmel23':
+ elif transform_type == "logmel23":
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
- Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.dot(Y**2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
- elif transform_type == 'logmel23_mn':
+ elif transform_type == "logmel23_mn":
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
- Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.dot(Y**2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
mean = np.mean(Y, axis=0)
Y = Y - mean
- elif transform_type == 'logmel23_swn':
+ elif transform_type == "logmel23_swn":
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
- Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.dot(Y**2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
# b = np.ones(300)/300
# mean = scipy.signal.convolve2d(Y, b[:, None], mode='same')
@@ -84,32 +81,31 @@
th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
mean = np.mean(Y[powers > th, :], axis=0)
Y = Y - mean
- elif transform_type == 'logmel23_mvn':
+ elif transform_type == "logmel23_mvn":
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
- Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.dot(Y**2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
mean = np.mean(Y, axis=0)
Y = Y - mean
std = np.maximum(np.std(Y, axis=0), 1e-10)
Y = Y / std
else:
- raise ValueError('Unknown transform_type: %s' % transform_type)
+ raise ValueError("Unknown transform_type: %s" % transform_type)
return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
- """ Frame subsampling
- """
+ """Frame subsampling"""
Y_ss = Y[::subsampling]
T_ss = T[::subsampling]
return Y_ss, T_ss
def splice(Y, context_size=0):
- """ Frame splicing
+ """Frame splicing
Args:
Y: feature
@@ -122,22 +118,18 @@
Y_spliced: spliced feature
(n_frames, n_featdim * (2 * context_size + 1))-shaped
"""
- Y_pad = np.pad(
- Y,
- [(context_size, context_size), (0, 0)],
- 'constant')
+ Y_pad = np.pad(Y, [(context_size, context_size), (0, 0)], "constant")
Y_spliced = np.lib.stride_tricks.as_strided(
np.ascontiguousarray(Y_pad),
(Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
- (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
+ (Y.itemsize * Y.shape[1], Y.itemsize),
+ writeable=False,
+ )
return Y_spliced
-def stft(
- data,
- frame_size=1024,
- frame_shift=256):
- """ Compute STFT features
+def stft(data, frame_size=1024, frame_shift=256):
+ """Compute STFT features
Args:
data: audio signal
@@ -154,11 +146,11 @@
# HACK: The last frame is ommited
# as librosa.stft produces such an excessive frame
if len(data) % frame_shift == 0:
- return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
- hop_length=frame_shift).T[:-1]
+ return librosa.stft(data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift).T[
+ :-1
+ ]
else:
- return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
- hop_length=frame_shift).T
+ return librosa.stft(data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift).T
def _count_frames(data_len, size, shift):
@@ -170,14 +162,9 @@
def get_frame_labels(
- kaldi_obj,
- rec,
- start=0,
- end=None,
- frame_size=1024,
- frame_shift=256,
- n_speakers=None):
- """ Get frame-aligned labels of given recording
+ kaldi_obj, rec, start=0, end=None, frame_size=1024, frame_shift=256, n_speakers=None
+):
+ """Get frame-aligned labels of given recording
Args:
kaldi_obj (KaldiData)
rec (str): recording id
@@ -192,10 +179,8 @@
T: label
(n_frames, n_speakers)-shaped np.int32 array
"""
- filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
- speakers = np.unique(
- [kaldi_obj.utt2spk[seg['utt']] for seg
- in filtered_segments]).tolist()
+ filtered_segments = kaldi_obj.segments[kaldi_obj.segments["rec"] == rec]
+ speakers = np.unique([kaldi_obj.utt2spk[seg["utt"]] for seg in filtered_segments]).tolist()
if n_speakers is None:
n_speakers = len(speakers)
es = end * frame_shift if end is not None else None
@@ -206,11 +191,9 @@
end = n_frames
for seg in filtered_segments:
- speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
- start_frame = np.rint(
- seg['st'] * rate / frame_shift).astype(int)
- end_frame = np.rint(
- seg['et'] * rate / frame_shift).astype(int)
+ speaker_index = speakers.index(kaldi_obj.utt2spk[seg["utt"]])
+ start_frame = np.rint(seg["st"] * rate / frame_shift).astype(int)
+ end_frame = np.rint(seg["et"] * rate / frame_shift).astype(int)
rel_start = rel_end = None
if start <= start_frame and start_frame < end:
rel_start = start_frame - start
@@ -222,11 +205,9 @@
def get_labeledSTFT(
- kaldi_obj,
- rec, start, end, frame_size, frame_shift,
- n_speakers=None,
- use_speaker_id=False):
- """ Extracts STFT and corresponding labels
+ kaldi_obj, rec, start, end, frame_size, frame_shift, n_speakers=None, use_speaker_id=False
+):
+ """Extracts STFT and corresponding labels
Extracts STFT and corresponding diarization labels for
given recording id and start/end times
@@ -246,14 +227,11 @@
T: label
(n_frmaes, n_speakers)-shaped np.int32 array.
"""
- data, rate = kaldi_obj.load_wav(
- rec, start * frame_shift, end * frame_shift)
+ data, rate = kaldi_obj.load_wav(rec, start * frame_shift, end * frame_shift)
Y = stft(data, frame_size, frame_shift)
filtered_segments = kaldi_obj.segments[rec]
# filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
- speakers = np.unique(
- [kaldi_obj.utt2spk[seg['utt']] for seg
- in filtered_segments]).tolist()
+ speakers = np.unique([kaldi_obj.utt2spk[seg["utt"]] for seg in filtered_segments]).tolist()
if n_speakers is None:
n_speakers = len(speakers)
T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)
@@ -263,13 +241,11 @@
S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)
for seg in filtered_segments:
- speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
+ speaker_index = speakers.index(kaldi_obj.utt2spk[seg["utt"]])
if use_speaker_id:
- all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
- start_frame = np.rint(
- seg['st'] * rate / frame_shift).astype(int)
- end_frame = np.rint(
- seg['et'] * rate / frame_shift).astype(int)
+ all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg["utt"]])
+ start_frame = np.rint(seg["st"] * rate / frame_shift).astype(int)
+ end_frame = np.rint(seg["et"] * rate / frame_shift).astype(int)
rel_start = rel_end = None
if start <= start_frame and start_frame < end:
rel_start = start_frame - start
--
Gitblit v1.9.1