From c644ac8f58895b9e29e9cfca79465fd2c0efaa5a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 21 十一月 2023 14:09:01 +0800
Subject: [PATCH] funasr v2 setup
---
funasr/fileio/sound_scp.py | 4
funasr/datasets/small_datasets/preprocessor.py | 50 ++--
funasr/build_utils/build_trainer.py | 9
funasr/train/trainer.py | 9
setup.py | 36 +-
funasr/datasets/iterable_dataset.py | 6
funasr/bin/vad_infer.py | 8
funasr/bin/diar_infer.py | 8
egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py | 5
funasr/datasets/preprocessor.py | 51 ++--
funasr/models/encoder/mossformer_encoder.py | 6
egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py | 5
funasr/datasets/dataset.py | 6
funasr/layers/stft.py | 5
funasr/utils/wav_utils.py | 6
funasr/layers/complex_utils.py | 8
funasr/bin/diar_inference_launch.py | 7
funasr/bin/ss_infer.py | 4
funasr/bin/asr_infer.py | 28 +-
funasr/datasets/large_datasets/dataset.py | 6
/dev/null | 77 -------
funasr/utils/prepare_data.py | 4
funasr/utils/speaker_utils.py | 4
egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py | 5
funasr/utils/timestamp_tools.py | 222 ++++++++++----------
egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py | 5
funasr/models/frontend/default.py | 5
funasr/bin/sv_infer.py | 4
funasr/utils/whisper_utils/audio.py | 5
funasr/bin/ss_inference_launch.py | 9
funasr/modules/eend_ola/utils/kaldi_data.py | 8
egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py | 5
egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py | 7
funasr/bin/asr_inference_launch.py | 6
funasr/utils/asr_utils.py | 4
35 files changed, 305 insertions(+), 332 deletions(-)
diff --git a/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py b/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py
index f83c572..3b6373c 100755
--- a/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py
+++ b/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py
@@ -1,7 +1,10 @@
import argparse
import tqdm
import codecs
-import textgrid
+try:
+ import textgrid
+except:
+ raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):
diff --git a/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py b/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py
index 1b09d0a..8dc9890 100755
--- a/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py
+++ b/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py
@@ -6,7 +6,10 @@
import codecs
from distutils.util import strtobool
from pathlib import Path
-import textgrid
+try:
+ import textgrid
+except:
+ raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):
diff --git a/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
index 8ece757..769003d 100755
--- a/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
+++ b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
@@ -6,7 +6,10 @@
import codecs
from distutils.util import strtobool
from pathlib import Path
-import textgrid
+try:
+ import textgrid
+except:
+ raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):
diff --git a/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
index 81c1965..b6d0157 100755
--- a/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
+++ b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
@@ -6,7 +6,10 @@
import codecs
from distutils.util import strtobool
from pathlib import Path
-import textgrid
+try:
+ import textgrid
+except:
+ raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):
diff --git a/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
index 488344f..c26ba32 100755
--- a/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
+++ b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
@@ -6,7 +6,10 @@
import codecs
from distutils.util import strtobool
from pathlib import Path
-import textgrid
+try:
+ import textgrid
+except:
+ raise "Please install textgrid firstly: pip install textgrid"
import pdb
def get_args():
diff --git a/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
index fdf2460..b72ddc9 100755
--- a/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
+++ b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
@@ -6,7 +6,12 @@
import codecs
from distutils.util import strtobool
from pathlib import Path
-import textgrid
+
+try:
+ import textgrid
+except:
+ raise "Please install textgrid firstly: pip install textgrid"
+
import pdb
import numpy as np
import sys
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 7015eb8..c1d08df 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -44,9 +44,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -251,9 +251,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -625,9 +625,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -876,9 +876,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -1106,9 +1106,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -1637,9 +1637,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -1885,9 +1885,9 @@
"""Speech2Text class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e1a32c5..7dd27fc 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -20,7 +20,8 @@
import numpy as np
import torch
import torchaudio
-import soundfile
+# import librosa
+import librosa
import yaml
from funasr.bin.asr_infer import Speech2Text
@@ -1281,7 +1282,8 @@
try:
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
except:
- raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+ # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
+ raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
if raw_inputs.ndim == 2:
raw_inputs = raw_inputs[:, 0]
raw_inputs = torch.tensor(raw_inputs)
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
index 6fc1da1..bb40f5e 100755
--- a/funasr/bin/diar_infer.py
+++ b/funasr/bin/diar_infer.py
@@ -27,11 +27,11 @@
"""Speech2Diarlization class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> import numpy as np
>>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}
@@ -109,11 +109,11 @@
"""Speech2Xvector class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> import numpy as np
>>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index b655df5..f5a11b1 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -15,7 +15,8 @@
from typing import Union
import numpy as np
-import soundfile
+# import librosa
+import librosa
import torch
from scipy.signal import medfilt
@@ -144,7 +145,9 @@
# read waveform file
example = [load_bytes(x) if isinstance(x, bytes) else x
for x in example]
- example = [soundfile.read(x)[0] if isinstance(x, str) else x
+ # example = [librosa.load(x)[0] if isinstance(x, str) else x
+ # for x in example]
+ example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
for x in example]
# convert torch tensor to numpy array
example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
diff --git a/funasr/bin/ss_infer.py b/funasr/bin/ss_infer.py
index 483967b..a3eca11 100644
--- a/funasr/bin/ss_infer.py
+++ b/funasr/bin/ss_infer.py
@@ -20,9 +20,9 @@
"""SpeechSeparator class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> separated_wavs = speech_separator(audio)
"""
diff --git a/funasr/bin/ss_inference_launch.py b/funasr/bin/ss_inference_launch.py
index 64503a0..0c02419 100644
--- a/funasr/bin/ss_inference_launch.py
+++ b/funasr/bin/ss_inference_launch.py
@@ -13,7 +13,7 @@
import numpy as np
import torch
-import soundfile as sf
+import librosa
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
@@ -104,7 +104,12 @@
ss_results = speech_separator(**batch)
for spk in range(num_spks):
- sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
+ # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
+ try:
+ librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
+ except:
+ print("To write wav by librosa, you should install librosa<=0.8.0")
+ raise
torch.cuda.empty_cache()
return ss_results
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 346440a..19cfc2e 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -22,9 +22,9 @@
"""Speech2Xvector class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2xvector(audio)
[(text, token, token_int, hypothesis object), ...]
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
index 73e1f3f..5763873 100644
--- a/funasr/bin/vad_infer.py
+++ b/funasr/bin/vad_infer.py
@@ -23,9 +23,9 @@
"""Speech2VadSegment class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
@@ -118,9 +118,9 @@
"""Speech2VadSegmentOnline class
Examples:
- >>> import soundfile
+ >>> import librosa
>>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
- >>> audio, rate = soundfile.read("speech.wav")
+ >>> audio, rate = librosa.load("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 03aa780..498d05d 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -246,14 +246,11 @@
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
- "{}/{}epoch started. Estimated time to finish: {}".format(
+ "{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch,
trainer_options.max_epoch,
- humanfriendly.format_timespan(
- (time.perf_counter() - start_time)
- / (iepoch - start_epoch)
- * (trainer_options.max_epoch - iepoch + 1)
- ),
+ (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
+ trainer_options.max_epoch - iepoch + 1),
)
)
else:
diff --git a/funasr/datasets/dataset.py b/funasr/datasets/dataset.py
index 407f6aa..673a9b2 100644
--- a/funasr/datasets/dataset.py
+++ b/funasr/datasets/dataset.py
@@ -16,8 +16,10 @@
from typing import Mapping
from typing import Tuple
from typing import Union
-
-import h5py
+try:
+ import h5py
+except:
+ print("If you want use h5py dataset, please pip install h5py, and try it again")
import humanfriendly
import kaldiio
import numpy as np
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 6398e0c..b2cc283 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -14,7 +14,8 @@
import numpy as np
import torch
import torchaudio
-import soundfile
+# import librosa
+import librosa
from torch.utils.data.dataset import IterableDataset
import os.path
@@ -70,7 +71,8 @@
try:
return torchaudio.load(input)[0].numpy()
except:
- waveform, _ = soundfile.read(input, dtype='float32')
+ # waveform, _ = librosa.load(input, dtype='float32')
+ waveform, _ = librosa.load(input, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
return np.expand_dims(waveform, axis=0)
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index adfe4f6..d3489c1 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -7,7 +7,8 @@
import torch.distributed as dist
import torchaudio
import numpy as np
-import soundfile
+# import librosa
+import librosa
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -128,7 +129,8 @@
try:
waveform, sampling_rate = torchaudio.load(path)
except:
- waveform, sampling_rate = soundfile.read(path, dtype='float32')
+ # waveform, sampling_rate = librosa.load(path, dtype='float32')
+ waveform, sampling_rate = librosa.load(path, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
waveform = np.expand_dims(waveform, axis=0)
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 9b5c4e7..26e062c 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -10,7 +10,7 @@
import numpy as np
import scipy.signal
-import soundfile
+import librosa
import jieba
from funasr.text.build_tokenizer import build_tokenizer
@@ -284,7 +284,7 @@
if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs)
if rir_path is not None:
- rir, _ = soundfile.read(
+ rir, _ = librosa.load(
rir_path, dtype=np.float64, always_2d=True
)
@@ -310,28 +310,31 @@
noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high
)
- with soundfile.SoundFile(noise_path) as f:
- if f.frames == nsamples:
- noise = f.read(dtype=np.float64, always_2d=True)
- elif f.frames < nsamples:
- offset = np.random.randint(0, nsamples - f.frames)
- # noise: (Time, Nmic)
- noise = f.read(dtype=np.float64, always_2d=True)
- # Repeat noise
- noise = np.pad(
- noise,
- [(offset, nsamples - f.frames - offset), (0, 0)],
- mode="wrap",
- )
- else:
- offset = np.random.randint(0, f.frames - nsamples)
- f.seek(offset)
- # noise: (Time, Nmic)
- noise = f.read(
- nsamples, dtype=np.float64, always_2d=True
- )
- if len(noise) != nsamples:
- raise RuntimeError(f"Something wrong: {noise_path}")
+
+ audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
+ frames = len(audio_data[0])
+ if frames == nsamples:
+ noise = audio_data
+ elif frames < nsamples:
+ offset = np.random.randint(0, nsamples - frames)
+ # noise: (Time, Nmic)
+ noise = audio_data
+ # Repeat noise
+ noise = np.pad(
+ noise,
+ [(offset, nsamples - frames - offset), (0, 0)],
+ mode="wrap",
+ )
+ else:
+ noise = audio_data[:, nsamples]
+ # offset = np.random.randint(0, frames - nsamples)
+ # f.seek(offset)
+ # noise: (Time, Nmic)
+ # noise = f.read(
+ # nsamples, dtype=np.float64, always_2d=True
+ # )
+ # if len(noise) != nsamples:
+ # raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time)
noise = noise.T
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index 0ebf325..f0d3c9a 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -9,7 +9,7 @@
import numpy as np
import scipy.signal
-import soundfile
+import librosa
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@@ -275,7 +275,7 @@
if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs)
if rir_path is not None:
- rir, _ = soundfile.read(
+ rir, _ = librosa.load(
rir_path, dtype=np.float64, always_2d=True
)
@@ -301,28 +301,30 @@
noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high
)
- with soundfile.SoundFile(noise_path) as f:
- if f.frames == nsamples:
- noise = f.read(dtype=np.float64, always_2d=True)
- elif f.frames < nsamples:
- offset = np.random.randint(0, nsamples - f.frames)
- # noise: (Time, Nmic)
- noise = f.read(dtype=np.float64, always_2d=True)
- # Repeat noise
- noise = np.pad(
- noise,
- [(offset, nsamples - f.frames - offset), (0, 0)],
- mode="wrap",
- )
- else:
- offset = np.random.randint(0, f.frames - nsamples)
- f.seek(offset)
- # noise: (Time, Nmic)
- noise = f.read(
- nsamples, dtype=np.float64, always_2d=True
- )
- if len(noise) != nsamples:
- raise RuntimeError(f"Something wrong: {noise_path}")
+ audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
+ frames = len(audio_data[0])
+ if frames == nsamples:
+ noise = audio_data
+ elif frames < nsamples:
+ offset = np.random.randint(0, nsamples - frames)
+ # noise: (Time, Nmic)
+ noise = audio_data
+ # Repeat noise
+ noise = np.pad(
+ noise,
+ [(offset, nsamples - frames - offset), (0, 0)],
+ mode="wrap",
+ )
+ else:
+ noise = audio_data[:, nsamples]
+ # offset = np.random.randint(0, frames - nsamples)
+ # f.seek(offset)
+ # noise: (Time, Nmic)
+ # noise = f.read(
+ # nsamples, dtype=np.float64, always_2d=True
+ # )
+ # if len(noise) != nsamples:
+ # raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time)
noise = noise.T
diff --git a/funasr/export/export_conformer.py b/funasr/export/export_conformer.py
deleted file mode 100644
index 4980775..0000000
--- a/funasr/export/export_conformer.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import json
-from typing import Union, Dict
-from pathlib import Path
-
-import os
-import logging
-import torch
-
-from funasr.export.models import get_model
-import numpy as np
-import random
-from funasr.utils.types import str2bool, str2triple_str
-# torch_version = float(".".join(torch.__version__.split(".")[:2]))
-# assert torch_version > 1.9
-
-class ModelExport:
- def __init__(
- self,
- cache_dir: Union[Path, str] = None,
- onnx: bool = True,
- device: str = "cpu",
- quant: bool = True,
- fallback_num: int = 0,
- audio_in: str = None,
- calib_num: int = 200,
- model_revision: str = None,
- ):
- self.set_all_random_seed(0)
-
- self.cache_dir = cache_dir
- self.export_config = dict(
- feats_dim=560,
- onnx=False,
- )
-
- self.onnx = onnx
- self.device = device
- self.quant = quant
- self.fallback_num = fallback_num
- self.frontend = None
- self.audio_in = audio_in
- self.calib_num = calib_num
- self.model_revision = model_revision
-
- def _export(
- self,
- model,
- model_dir: str = None,
- verbose: bool = False,
- ):
-
- export_dir = model_dir
- os.makedirs(export_dir, exist_ok=True)
-
- self.export_config["model_name"] = "model"
- model = get_model(
- model,
- self.export_config,
- )
- model.eval()
-
- if self.onnx:
- self._export_onnx(model, verbose, export_dir)
-
- print("output dir: {}".format(export_dir))
-
- def _export_onnx(self, model, verbose, path):
- model._export_onnx(verbose, path)
-
- def set_all_random_seed(self, seed: int):
- random.seed(seed)
- np.random.seed(seed)
- torch.random.manual_seed(seed)
-
- def parse_audio_in(self, audio_in):
-
- wav_list, name_list = [], []
- if audio_in.endswith(".scp"):
- f = open(audio_in, 'r')
- lines = f.readlines()[:self.calib_num]
- for line in lines:
- name, path = line.strip().split()
- name_list.append(name)
- wav_list.append(path)
- else:
- wav_list = [audio_in,]
- name_list = ["test",]
- return wav_list, name_list
-
- def load_feats(self, audio_in: str = None):
- import torchaudio
-
- wav_list, name_list = self.parse_audio_in(audio_in)
- feats = []
- feats_len = []
- for line in wav_list:
- path = line.strip()
- waveform, sampling_rate = torchaudio.load(path)
- if sampling_rate != self.frontend.fs:
- waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
- new_freq=self.frontend.fs)(waveform)
- fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
- feats.append(fbank)
- feats_len.append(fbank_len)
- return feats, feats_len
-
- def export(self,
- mode: str = None,
- ):
-
- if mode.startswith('conformer'):
- from funasr.tasks.asr import ASRTask
- config = os.path.join(model_dir, 'config.yaml')
- model_file = os.path.join(model_dir, 'model.pb')
- cmvn_file = os.path.join(model_dir, 'am.mvn')
- model, asr_train_args = ASRTask.build_model_from_file(
- config, model_file, cmvn_file, 'cpu'
- )
- self.frontend = model.frontend
- self.export_config["feats_dim"] = 560
-
- self._export(model, self.cache_dir)
-
-if __name__ == '__main__':
- import argparse
- parser = argparse.ArgumentParser()
- # parser.add_argument('--model-name', type=str, required=True)
- parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
- parser.add_argument('--export-dir', type=str, required=True)
- parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
- parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
- parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
- parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
- parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
- parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
- parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
- args = parser.parse_args()
-
- export_model = ModelExport(
- cache_dir=args.export_dir,
- onnx=args.type == 'onnx',
- device=args.device,
- quant=args.quantize,
- fallback_num=args.fallback_num,
- audio_in=args.audio_in,
- calib_num=args.calib_num,
- model_revision=args.model_revision,
- )
- for model_name in args.model_name:
- print("export model: {}".format(model_name))
- export_model.export(model_name)
diff --git a/funasr/export/models/language_models/__init__.py b/funasr/export/models/language_models/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/language_models/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/language_models/embed.py b/funasr/export/models/language_models/embed.py
deleted file mode 100644
index 57748f2..0000000
--- a/funasr/export/models/language_models/embed.py
+++ /dev/null
@@ -1,403 +0,0 @@
-"""Positional Encoding Module."""
-
-import math
-
-import torch
-import torch.nn as nn
-from funasr.modules.embedding import (
- LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
- ScaledPositionalEncoding, StreamPositionalEncoding)
-from funasr.modules.subsampling import (
- Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
- Conv2dSubsampling8)
-from funasr.modules.subsampling_without_posenc import \
- Conv2dSubsamplingWOPosEnc
-
-from funasr.export.models.language_models.subsampling import (
- OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
- OnnxConv2dSubsampling8)
-
-
-def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
- if isinstance(pos_emb, LegacyRelPositionalEncoding):
- return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
- elif isinstance(pos_emb, ScaledPositionalEncoding):
- return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache)
- elif isinstance(pos_emb, RelPositionalEncoding):
- return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
- elif isinstance(pos_emb, PositionalEncoding):
- return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache)
- elif isinstance(pos_emb, StreamPositionalEncoding):
- return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
- elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
- isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
- ):
- return pos_emb
- else:
- raise ValueError("Embedding model is not supported.")
-
-
-class Embedding(nn.Module):
- def __init__(self, model, max_seq_len=512, use_cache=True):
- super().__init__()
- self.model = model
- if not isinstance(model, nn.Embedding):
- if isinstance(model, Conv2dSubsampling):
- self.model = OnnxConv2dSubsampling(model)
- self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
- elif isinstance(model, Conv2dSubsampling2):
- self.model = OnnxConv2dSubsampling2(model)
- self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
- elif isinstance(model, Conv2dSubsampling6):
- self.model = OnnxConv2dSubsampling6(model)
- self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
- elif isinstance(model, Conv2dSubsampling8):
- self.model = OnnxConv2dSubsampling8(model)
- self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
- else:
- self.model[-1] = get_pos_emb(model[-1], max_seq_len)
-
- def forward(self, x, mask=None):
- if mask is None:
- return self.model(x)
- else:
- return self.model(x, mask)
-
-
-def _pre_hook(
- state_dict,
- prefix,
- local_metadata,
- strict,
- missing_keys,
- unexpected_keys,
- error_msgs,
-):
- """Perform pre-hook in load_state_dict for backward compatibility.
-
- Note:
- We saved self.pe until v.0.5.2 but we have omitted it later.
- Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
-
- """
- k = prefix + "pe"
- if k in state_dict:
- state_dict.pop(k)
-
-
-class OnnxPositionalEncoding(torch.nn.Module):
- """Positional encoding.
-
- Args:
- d_model (int): Embedding dimension.
- dropout_rate (float): Dropout rate.
- max_seq_len (int): Maximum input length.
- reverse (bool): Whether to reverse the input position. Only for
- the class LegacyRelPositionalEncoding. We remove it in the current
- class RelPositionalEncoding.
- """
-
- def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
- """Construct an PositionalEncoding object."""
- super(OnnxPositionalEncoding, self).__init__()
- self.d_model = model.d_model
- self.reverse = reverse
- self.max_seq_len = max_seq_len
- self.xscale = math.sqrt(self.d_model)
- self._register_load_state_dict_pre_hook(_pre_hook)
- self.pe = model.pe
- self.use_cache = use_cache
- self.model = model
- if self.use_cache:
- self.extend_pe()
- else:
- self.div_term = torch.exp(
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
- * -(math.log(10000.0) / self.d_model)
- )
-
- def extend_pe(self):
- """Reset the positional encodings."""
- pe_length = len(self.pe[0])
- if self.max_seq_len < pe_length:
- self.pe = self.pe[:, : self.max_seq_len]
- else:
- self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
- self.pe = self.model.pe
-
- def _add_pe(self, x):
- """Computes positional encoding"""
- if self.reverse:
- position = torch.arange(
- x.size(1) - 1, -1, -1.0, dtype=torch.float32
- ).unsqueeze(1)
- else:
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-
- x = x * self.xscale
- x[:, :, 0::2] += torch.sin(position * self.div_term)
- x[:, :, 1::2] += torch.cos(position * self.div_term)
- return x
-
- def forward(self, x: torch.Tensor):
- """Add positional encoding.
-
- Args:
- x (torch.Tensor): Input tensor (batch, time, `*`).
-
- Returns:
- torch.Tensor: Encoded tensor (batch, time, `*`).
- """
- if self.use_cache:
- x = x * self.xscale + self.pe[:, : x.size(1)]
- else:
- x = self._add_pe(x)
- return x
-
-
-class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
- """Scaled positional encoding module.
-
- See Sec. 3.2 https://arxiv.org/abs/1809.08895
-
- Args:
- d_model (int): Embedding dimension.
- dropout_rate (float): Dropout rate.
- max_seq_len (int): Maximum input length.
-
- """
-
- def __init__(self, model, max_seq_len=512, use_cache=True):
- """Initialize class."""
- super().__init__(model, max_seq_len, use_cache=use_cache)
- self.alpha = torch.nn.Parameter(torch.tensor(1.0))
-
- def reset_parameters(self):
- """Reset parameters."""
- self.alpha.data = torch.tensor(1.0)
-
- def _add_pe(self, x):
- """Computes positional encoding"""
- if self.reverse:
- position = torch.arange(
- x.size(1) - 1, -1, -1.0, dtype=torch.float32
- ).unsqueeze(1)
- else:
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-
- x = x * self.alpha
- x[:, :, 0::2] += torch.sin(position * self.div_term)
- x[:, :, 1::2] += torch.cos(position * self.div_term)
- return x
-
- def forward(self, x):
- """Add positional encoding.
-
- Args:
- x (torch.Tensor): Input tensor (batch, time, `*`).
-
- Returns:
- torch.Tensor: Encoded tensor (batch, time, `*`).
-
- """
- if self.use_cache:
- x = x + self.alpha * self.pe[:, : x.size(1)]
- else:
- x = self._add_pe(x)
- return x
-
-
-class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
- """Relative positional encoding module (old version).
-
- Details can be found in https://github.com/espnet/espnet/pull/2816.
-
- See : Appendix B in https://arxiv.org/abs/1901.02860
-
- Args:
- d_model (int): Embedding dimension.
- dropout_rate (float): Dropout rate.
- max_seq_len (int): Maximum input length.
-
- """
-
- def __init__(self, model, max_seq_len=512, use_cache=True):
- """Initialize class."""
- super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)
-
- def _get_pe(self, x):
- """Computes positional encoding"""
- if self.reverse:
- position = torch.arange(
- x.size(1) - 1, -1, -1.0, dtype=torch.float32
- ).unsqueeze(1)
- else:
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-
- pe = torch.zeros(x.shape)
- pe[:, :, 0::2] += torch.sin(position * self.div_term)
- pe[:, :, 1::2] += torch.cos(position * self.div_term)
- return pe
-
- def forward(self, x):
- """Compute positional encoding.
-
- Args:
- x (torch.Tensor): Input tensor (batch, time, `*`).
-
- Returns:
- torch.Tensor: Encoded tensor (batch, time, `*`).
- torch.Tensor: Positional embedding tensor (1, time, `*`).
-
- """
- x = x * self.xscale
- if self.use_cache:
- pos_emb = self.pe[:, : x.size(1)]
- else:
- pos_emb = self._get_pe(x)
- return x, pos_emb
-
-
-class OnnxRelPositionalEncoding(torch.nn.Module):
- """Relative positional encoding module (new implementation).
- Details can be found in https://github.com/espnet/espnet/pull/2816.
- See : Appendix B in https://arxiv.org/abs/1901.02860
- Args:
- d_model (int): Embedding dimension.
- dropout_rate (float): Dropout rate.
- max_seq_len (int): Maximum input length.
- """
-
- def __init__(self, model, max_seq_len=512, use_cache=True):
- """Construct an PositionalEncoding object."""
- super(OnnxRelPositionalEncoding, self).__init__()
- self.d_model = model.d_model
- self.xscale = math.sqrt(self.d_model)
- self.pe = None
- self.use_cache = use_cache
- if self.use_cache:
- self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
- else:
- self.div_term = torch.exp(
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
- * -(math.log(10000.0) / self.d_model)
- )
-
- def extend_pe(self, x):
- """Reset the positional encodings."""
- if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1:
- # self.pe contains both positive and negative parts
- # the length of self.pe is 2 * input_len - 1
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
- return
- # Suppose `i` means to the position of query vecotr and `j` means the
- # position of key vector. We use position relative positions when keys
- # are to the left (i>j) and negative relative positions otherwise (i<j).
- pe_positive = torch.zeros(x.size(1), self.d_model)
- pe_negative = torch.zeros(x.size(1), self.d_model)
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
- * -(math.log(10000.0) / self.d_model)
- )
- pe_positive[:, 0::2] = torch.sin(position * div_term)
- pe_positive[:, 1::2] = torch.cos(position * div_term)
- pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
- pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
-
- # Reserve the order of positive indices and concat both positive and
- # negative indices. This is used to support the shifting trick
- # as in https://arxiv.org/abs/1901.02860
- pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
- pe_negative = pe_negative[1:].unsqueeze(0)
- pe = torch.cat([pe_positive, pe_negative], dim=1)
- self.pe = pe.to(device=x.device, dtype=x.dtype)
-
- def _get_pe(self, x):
- pe_positive = torch.zeros(x.size(1), self.d_model)
- pe_negative = torch.zeros(x.size(1), self.d_model)
- theta = (
- torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) * self.div_term
- )
- pe_positive[:, 0::2] = torch.sin(theta)
- pe_positive[:, 1::2] = torch.cos(theta)
- pe_negative[:, 0::2] = -1 * torch.sin(theta)
- pe_negative[:, 1::2] = torch.cos(theta)
-
- # Reserve the order of positive indices and concat both positive and
- # negative indices. This is used to support the shifting trick
- # as in https://arxiv.org/abs/1901.02860
- pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
- pe_negative = pe_negative[1:].unsqueeze(0)
- return torch.cat([pe_positive, pe_negative], dim=1)
-
- def forward(self, x: torch.Tensor, use_cache=True):
- """Add positional encoding.
- Args:
- x (torch.Tensor): Input tensor (batch, time, `*`).
- Returns:
- torch.Tensor: Encoded tensor (batch, time, `*`).
- """
- x = x * self.xscale
- if self.use_cache:
- pos_emb = self.pe[
- :,
- self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
- ]
- else:
- pos_emb = self._get_pe(x)
- return x, pos_emb
-
-
-class OnnxStreamPositionalEncoding(torch.nn.Module):
- """Streaming Positional encoding."""
-
- def __init__(self, model, max_seq_len=5000, use_cache=True):
- """Construct an PositionalEncoding object."""
- super(StreamPositionalEncoding, self).__init__()
- self.use_cache = use_cache
- self.d_model = model.d_model
- self.xscale = model.xscale
- self.pe = model.pe
- self.use_cache = use_cache
- self.max_seq_len = max_seq_len
- if self.use_cache:
- self.extend_pe()
- else:
- self.div_term = torch.exp(
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
- * -(math.log(10000.0) / self.d_model)
- )
- self._register_load_state_dict_pre_hook(_pre_hook)
-
- def extend_pe(self):
- """Reset the positional encodings."""
- pe_length = len(self.pe[0])
- if self.max_seq_len < pe_length:
- self.pe = self.pe[:, : self.max_seq_len]
- else:
- self.model.extend_pe(self.max_seq_len)
- self.pe = self.model.pe
-
- def _add_pe(self, x, start_idx):
- position = torch.arange(start_idx, x.size(1), dtype=torch.float32).unsqueeze(1)
- x = x * self.xscale
- x[:, :, 0::2] += torch.sin(position * self.div_term)
- x[:, :, 1::2] += torch.cos(position * self.div_term)
- return x
-
- def forward(self, x: torch.Tensor, start_idx: int = 0):
- """Add positional encoding.
-
- Args:
- x (torch.Tensor): Input tensor (batch, time, `*`).
-
- Returns:
- torch.Tensor: Encoded tensor (batch, time, `*`).
-
- """
- if self.use_cache:
- return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
- else:
- return self._add_pe(x, start_idx)
diff --git a/funasr/export/models/language_models/seq_rnn.py b/funasr/export/models/language_models/seq_rnn.py
deleted file mode 100644
index ecff4b8..0000000
--- a/funasr/export/models/language_models/seq_rnn.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-
-class SequentialRNNLM(nn.Module):
- def __init__(self, model, **kwargs):
- super().__init__()
- self.encoder = model.encoder
- self.rnn = model.rnn
- self.rnn_type = model.rnn_type
- self.decoder = model.decoder
- self.nlayers = model.nlayers
- self.nhid = model.nhid
- self.model_name = "seq_rnnlm"
-
- def forward(self, y, hidden1, hidden2=None):
- # batch_score function.
- emb = self.encoder(y)
- if self.rnn_type == "LSTM":
- output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
- else:
- output, hidden1 = self.rnn(emb, hidden1)
-
- decoded = self.decoder(
- output.contiguous().view(output.size(0) * output.size(1), output.size(2))
- )
- if self.rnn_type == "LSTM":
- return (
- decoded.view(output.size(0), output.size(1), decoded.size(1)),
- hidden1,
- hidden2,
- )
- else:
- return (
- decoded.view(output.size(0), output.size(1), decoded.size(1)),
- hidden1,
- )
-
- def get_dummy_inputs(self):
- tgt = torch.LongTensor([0, 1]).unsqueeze(0)
- hidden = torch.randn(self.nlayers, 1, self.nhid)
- if self.rnn_type == "LSTM":
- return (tgt, hidden, hidden)
- else:
- return (tgt, hidden)
-
- def get_input_names(self):
- if self.rnn_type == "LSTM":
- return ["x", "in_hidden1", "in_hidden2"]
- else:
- return ["x", "in_hidden1"]
-
- def get_output_names(self):
- if self.rnn_type == "LSTM":
- return ["y", "out_hidden1", "out_hidden2"]
- else:
- return ["y", "out_hidden1"]
-
- def get_dynamic_axes(self):
- ret = {
- "x": {0: "x_batch", 1: "x_length"},
- "y": {0: "y_batch"},
- "in_hidden1": {1: "hidden1_batch"},
- "out_hidden1": {1: "out_hidden1_batch"},
- }
- if self.rnn_type == "LSTM":
- ret.update(
- {
- "in_hidden2": {1: "hidden2_batch"},
- "out_hidden2": {1: "out_hidden2_batch"},
- }
- )
- return ret
-
- def get_model_config(self, path):
- return {
- "use_lm": True,
- "model_path": os.path.join(path, f"{self.model_name}.onnx"),
- "lm_type": "SequentialRNNLM",
- "rnn_type": self.rnn_type,
- "nhid": self.nhid,
- "nlayers": self.nlayers,
- }
diff --git a/funasr/export/models/language_models/subsampling.py b/funasr/export/models/language_models/subsampling.py
deleted file mode 100644
index e71e127..0000000
--- a/funasr/export/models/language_models/subsampling.py
+++ /dev/null
@@ -1,185 +0,0 @@
-"""Subsampling layer definition."""
-
-import torch
-
-
-class OnnxConv2dSubsampling(torch.nn.Module):
- """Convolutional 2D subsampling (to 1/4 length).
-
- Args:
- idim (int): Input dimension.
- odim (int): Output dimension.
- dropout_rate (float): Dropout rate.
- pos_enc (torch.nn.Module): Custom position encoding layer.
-
- """
-
- def __init__(self, model):
- """Construct an Conv2dSubsampling object."""
- super().__init__()
- self.conv = model.conv
- self.out = model.out
-
- def forward(self, x, x_mask):
- """Subsample x.
-
- Args:
- x (torch.Tensor): Input tensor (#batch, time, idim).
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
-
- Returns:
- torch.Tensor: Subsampled tensor (#batch, time', odim),
- where time' = time // 4.
- torch.Tensor: Subsampled mask (#batch, 1, time'),
- where time' = time // 4.
-
- """
- x = x.unsqueeze(1) # (b, c, t, f)
- x = self.conv(x)
- b, c, t, f = x.size()
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- if x_mask is None:
- return x, None
- return x, x_mask[:, :-2:2][:, :-2:2]
-
- def __getitem__(self, key):
- """Get item.
-
- When reset_parameters() is called, if use_scaled_pos_enc is used,
- return the positioning encoding.
-
- """
- if key != -1:
- raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
- return self.out[key]
-
-
-class OnnxConv2dSubsampling2(torch.nn.Module):
- """Convolutional 2D subsampling (to 1/2 length).
-
- Args:
- idim (int): Input dimension.
- odim (int): Output dimension.
- dropout_rate (float): Dropout rate.
- pos_enc (torch.nn.Module): Custom position encoding layer.
-
- """
-
- def __init__(self, model):
- """Construct an Conv2dSubsampling object."""
- super().__init__()
- self.conv = model.conv
- self.out = model.out
-
- def forward(self, x, x_mask):
- """Subsample x.
-
- Args:
- x (torch.Tensor): Input tensor (#batch, time, idim).
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
-
- Returns:
- torch.Tensor: Subsampled tensor (#batch, time', odim),
- where time' = time // 2.
- torch.Tensor: Subsampled mask (#batch, 1, time'),
- where time' = time // 2.
-
- """
- x = x.unsqueeze(1) # (b, c, t, f)
- x = self.conv(x)
- b, c, t, f = x.size()
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- if x_mask is None:
- return x, None
- return x, x_mask[:, :-2:2][:, :-2:1]
-
- def __getitem__(self, key):
- """Get item.
-
- When reset_parameters() is called, if use_scaled_pos_enc is used,
- return the positioning encoding.
-
- """
- if key != -1:
- raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
- return self.out[key]
-
-
-class OnnxConv2dSubsampling6(torch.nn.Module):
- """Convolutional 2D subsampling (to 1/6 length).
-
- Args:
- idim (int): Input dimension.
- odim (int): Output dimension.
- dropout_rate (float): Dropout rate.
- pos_enc (torch.nn.Module): Custom position encoding layer.
-
- """
-
- def __init__(self, model):
- """Construct an Conv2dSubsampling object."""
- super().__init__()
- self.conv = model.conv
- self.out = model.out
-
- def forward(self, x, x_mask):
- """Subsample x.
-
- Args:
- x (torch.Tensor): Input tensor (#batch, time, idim).
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
-
- Returns:
- torch.Tensor: Subsampled tensor (#batch, time', odim),
- where time' = time // 6.
- torch.Tensor: Subsampled mask (#batch, 1, time'),
- where time' = time // 6.
-
- """
- x = x.unsqueeze(1) # (b, c, t, f)
- x = self.conv(x)
- b, c, t, f = x.size()
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- if x_mask is None:
- return x, None
- return x, x_mask[:, :-2:2][:, :-4:3]
-
-
-class OnnxConv2dSubsampling8(torch.nn.Module):
- """Convolutional 2D subsampling (to 1/8 length).
-
- Args:
- idim (int): Input dimension.
- odim (int): Output dimension.
- dropout_rate (float): Dropout rate.
- pos_enc (torch.nn.Module): Custom position encoding layer.
-
- """
-
- def __init__(self, model):
- """Construct an Conv2dSubsampling object."""
- super().__init__()
- self.conv = model.conv
- self.out = model.out
-
- def forward(self, x, x_mask):
- """Subsample x.
-
- Args:
- x (torch.Tensor): Input tensor (#batch, time, idim).
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
-
- Returns:
- torch.Tensor: Subsampled tensor (#batch, time', odim),
- where time' = time // 8.
- torch.Tensor: Subsampled mask (#batch, 1, time'),
- where time' = time // 8.
-
- """
- x = x.unsqueeze(1) # (b, c, t, f)
- x = self.conv(x)
- b, c, t, f = x.size()
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- if x_mask is None:
- return x, None
- return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
diff --git a/funasr/export/models/language_models/transformer.py b/funasr/export/models/language_models/transformer.py
deleted file mode 100644
index ebf0574..0000000
--- a/funasr/export/models/language_models/transformer.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-from funasr.modules.vgg2l import import VGG2L
-from funasr.modules.attention import MultiHeadedAttention
-from funasr.modules.subsampling import (
- Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
-
-from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
-from funasr.export.models.language_models.embed import Embedding
-from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
-
-from funasr.export.utils.torch_function import MakePadMask
-
-class TransformerLM(nn.Module, AbsExportModel):
- def __init__(self, model, max_seq_len=512, **kwargs):
- super().__init__()
- self.embed = Embedding(model.embed, max_seq_len)
- self.encoder = model.encoder
- self.decoder = model.decoder
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- # replace multihead attention module into customized module.
- for i, d in enumerate(self.encoder.encoders):
- # d is EncoderLayer
- if isinstance(d.self_attn, MultiHeadedAttention):
- d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
- self.encoder.encoders[i] = OnnxEncoderLayer(d)
-
- self.model_name = "transformer_lm"
- self.num_heads = self.encoder.encoders[0].self_attn.h
- self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features
-
- def prepare_mask(self, mask):
- if len(mask.shape) == 2:
- mask = mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask = mask[:, None, :]
- mask = 1 - mask
- return mask * -10000.0
-
- def forward(self, y, cache):
- feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
- mask = self.make_pad_mask(feats_length) # (B, T)
- mask = (y != 0) * mask
-
- xs = self.embed(y)
- # forward_one_step of Encoder
- if isinstance(
- self.encoder.embed,
- (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
- ):
- xs, mask = self.encoder.embed(xs, mask)
- else:
- xs = self.encoder.embed(xs)
-
- new_cache = []
- mask = self.prepare_mask(mask)
- for c, e in zip(cache, self.encoder.encoders):
- xs, mask = e(xs, mask, c)
- new_cache.append(xs)
-
- if self.encoder.normalize_before:
- xs = self.encoder.after_norm(xs)
-
- h = self.decoder(xs[:, -1])
- return h, new_cache
-
- def get_dummy_inputs(self):
- tgt = torch.LongTensor([1]).unsqueeze(0)
- cache = [
- torch.zeros((1, 1, self.encoder.encoders[0].size))
- for _ in range(len(self.encoder.encoders))
- ]
- return (tgt, cache)
-
- def is_optimizable(self):
- return True
-
- def get_input_names(self):
- return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]
-
- def get_output_names(self):
- return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]
-
- def get_dynamic_axes(self):
- ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
- ret.update(
- {
- "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
- for d in range(len(self.encoder.encoders))
- }
- )
- ret.update(
- {
- "out_cache_%d"
- % d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
- for d in range(len(self.encoder.encoders))
- }
- )
- return ret
-
- def get_model_config(self, path):
- return {
- "use_lm": True,
- "model_path": os.path.join(path, f"{self.model_name}.onnx"),
- "lm_type": "TransformerLM",
- "odim": self.encoder.encoders[0].size,
- "nlayers": len(self.encoder.encoders),
- }
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index b912f1e..b9364c6 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -4,7 +4,7 @@
import random
import numpy as np
-import soundfile
+import librosa
import librosa
import torch
@@ -116,7 +116,7 @@
def __getitem__(self, key):
wav = self.data[key]
if self.normalize:
- # soundfile.read normalizes data to [-1,1] if dtype is not given
+ # librosa.load normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=self.always_2d
)
diff --git a/funasr/layers/complex_utils.py b/funasr/layers/complex_utils.py
index bf4799f..d6f7c6d 100644
--- a/funasr/layers/complex_utils.py
+++ b/funasr/layers/complex_utils.py
@@ -5,8 +5,12 @@
from typing import Union
import torch
-from torch_complex import functional as FC
-from torch_complex.tensor import ComplexTensor
+try:
+ from torch_complex import functional as FC
+ from torch_complex.tensor import ComplexTensor
+except:
+ raise "Please install torch_complex firstly"
+
EPS = torch.finfo(torch.double).eps
diff --git a/funasr/layers/stft.py b/funasr/layers/stft.py
index dfb6919..67ebf7a 100644
--- a/funasr/layers/stft.py
+++ b/funasr/layers/stft.py
@@ -4,8 +4,11 @@
from typing import Union
import torch
-from torch_complex.tensor import ComplexTensor
+try:
+ from torch_complex.tensor import ComplexTensor
+except:
+ raise "Please install torch_complex firstly"
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
from funasr.layers.inversible_interface import InversibleInterface
diff --git a/funasr/models/encoder/mossformer_encoder.py b/funasr/models/encoder/mossformer_encoder.py
index 54d80ca..f7d9c47 100644
--- a/funasr/models/encoder/mossformer_encoder.py
+++ b/funasr/models/encoder/mossformer_encoder.py
@@ -1,8 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-from rotary_embedding_torch import RotaryEmbedding
+try:
+ from rotary_embedding_torch import RotaryEmbedding
+except:
+ raise "Please install rotary_embedding_torch by: \n pip install -U funasr[all]"
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index b41af80..8d60e20 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -6,7 +6,10 @@
import humanfriendly
import numpy as np
import torch
-from torch_complex.tensor import ComplexTensor
+try:
+ from torch_complex.tensor import ComplexTensor
+except:
+ raise "Please install torch_complex firstly"
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
diff --git a/funasr/modules/eend_ola/utils/kaldi_data.py b/funasr/modules/eend_ola/utils/kaldi_data.py
index 42f6d5e..53f6230 100644
--- a/funasr/modules/eend_ola/utils/kaldi_data.py
+++ b/funasr/modules/eend_ola/utils/kaldi_data.py
@@ -9,7 +9,7 @@
import sys
import numpy as np
import subprocess
-import soundfile as sf
+import librosa as sf
import io
from functools import lru_cache
@@ -67,18 +67,18 @@
# input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE)
- data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
+ data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
dtype='float32')
# cannot seek
data = data[start:end]
elif wav_rxfilename == '-':
# stdin
- data, samplerate = sf.read(sys.stdin, dtype='float32')
+ data, samplerate = sf.load(sys.stdin, dtype='float32')
# cannot seek
data = data[start:end]
else:
# normal wav file
- data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
+ data, samplerate = sf.load(wav_rxfilename, start=start, stop=end)
return data, samplerate
diff --git a/funasr/modules/frontends/__init__.py b/funasr/modules/frontends/__init__.py
deleted file mode 100644
index b7f1773..0000000
--- a/funasr/modules/frontends/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Initialize sub package."""
diff --git a/funasr/modules/frontends/beamformer.py b/funasr/modules/frontends/beamformer.py
deleted file mode 100644
index f3eccee..0000000
--- a/funasr/modules/frontends/beamformer.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import torch
-from torch_complex import functional as FC
-from torch_complex.tensor import ComplexTensor
-
-
-def get_power_spectral_density_matrix(
- xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
-) -> ComplexTensor:
- """Return cross-channel power spectral density (PSD) matrix
-
- Args:
- xs (ComplexTensor): (..., F, C, T)
- mask (torch.Tensor): (..., F, C, T)
- normalization (bool):
- eps (float):
- Returns
- psd (ComplexTensor): (..., F, C, C)
-
- """
- # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
- psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
-
- # Averaging mask along C: (..., C, T) -> (..., T)
- mask = mask.mean(dim=-2)
-
- # Normalized mask along T: (..., T)
- if normalization:
- # If assuming the tensor is padded with zero, the summation along
- # the time axis is same regardless of the padding length.
- mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
-
- # psd: (..., T, C, C)
- psd = psd_Y * mask[..., None, None]
- # (..., T, C, C) -> (..., C, C)
- psd = psd.sum(dim=-3)
-
- return psd
-
-
-def get_mvdr_vector(
- psd_s: ComplexTensor,
- psd_n: ComplexTensor,
- reference_vector: torch.Tensor,
- eps: float = 1e-15,
-) -> ComplexTensor:
- """Return the MVDR(Minimum Variance Distortionless Response) vector:
-
- h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
-
- Reference:
- On optimal frequency-domain multichannel linear filtering
- for noise reduction; M. Souden et al., 2010;
- https://ieeexplore.ieee.org/document/5089420
-
- Args:
- psd_s (ComplexTensor): (..., F, C, C)
- psd_n (ComplexTensor): (..., F, C, C)
- reference_vector (torch.Tensor): (..., C)
- eps (float):
- Returns:
- beamform_vector (ComplexTensor)r: (..., F, C)
- """
- # Add eps
- C = psd_n.size(-1)
- eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
- shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
- eye = eye.view(*shape)
- psd_n += eps * eye
-
- # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
- numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
- # ws: (..., C, C) / (...,) -> (..., C, C)
- ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
- # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
- beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
- return beamform_vector
-
-
-def apply_beamforming_vector(
- beamform_vector: ComplexTensor, mix: ComplexTensor
-) -> ComplexTensor:
- # (..., C) x (..., C, T) -> (..., T)
- es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
- return es
diff --git a/funasr/modules/frontends/dnn_beamformer.py b/funasr/modules/frontends/dnn_beamformer.py
deleted file mode 100644
index e75d771..0000000
--- a/funasr/modules/frontends/dnn_beamformer.py
+++ /dev/null
@@ -1,172 +0,0 @@
-"""DNN beamformer module."""
-from typing import Tuple
-
-import torch
-from torch.nn import functional as F
-
-from funasr.modules.frontends.beamformer import apply_beamforming_vector
-from funasr.modules.frontends.beamformer import get_mvdr_vector
-from funasr.modules.frontends.beamformer import (
- get_power_spectral_density_matrix, # noqa: H301
-)
-from funasr.modules.frontends.mask_estimator import MaskEstimator
-from torch_complex.tensor import ComplexTensor
-
-
-class DNN_Beamformer(torch.nn.Module):
- """DNN mask based Beamformer
-
- Citation:
- Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
- https://arxiv.org/abs/1703.04783
-
- """
-
- def __init__(
- self,
- bidim,
- btype="blstmp",
- blayers=3,
- bunits=300,
- bprojs=320,
- bnmask=2,
- dropout_rate=0.0,
- badim=320,
- ref_channel: int = -1,
- beamformer_type="mvdr",
- ):
- super().__init__()
- self.mask = MaskEstimator(
- btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
- )
- self.ref = AttentionReference(bidim, badim)
- self.ref_channel = ref_channel
-
- self.nmask = bnmask
-
- if beamformer_type != "mvdr":
- raise ValueError(
- "Not supporting beamformer_type={}".format(beamformer_type)
- )
- self.beamformer_type = beamformer_type
-
- def forward(
- self, data: ComplexTensor, ilens: torch.LongTensor
- ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
- """The forward function
-
- Notation:
- B: Batch
- C: Channel
- T: Time or Sequence length
- F: Freq
-
- Args:
- data (ComplexTensor): (B, T, C, F)
- ilens (torch.Tensor): (B,)
- Returns:
- enhanced (ComplexTensor): (B, T, F)
- ilens (torch.Tensor): (B,)
-
- """
-
- def apply_beamforming(data, ilens, psd_speech, psd_noise):
- # u: (B, C)
- if self.ref_channel < 0:
- u, _ = self.ref(psd_speech, ilens)
- else:
- # (optional) Create onehot vector for fixed reference microphone
- u = torch.zeros(
- *(data.size()[:-3] + (data.size(-2),)), device=data.device
- )
- u[..., self.ref_channel].fill_(1)
-
- ws = get_mvdr_vector(psd_speech, psd_noise, u)
- enhanced = apply_beamforming_vector(ws, data)
-
- return enhanced, ws
-
- # data (B, T, C, F) -> (B, F, C, T)
- data = data.permute(0, 3, 2, 1)
-
- # mask: (B, F, C, T)
- masks, _ = self.mask(data, ilens)
- assert self.nmask == len(masks)
-
- if self.nmask == 2: # (mask_speech, mask_noise)
- mask_speech, mask_noise = masks
-
- psd_speech = get_power_spectral_density_matrix(data, mask_speech)
- psd_noise = get_power_spectral_density_matrix(data, mask_noise)
-
- enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
-
- # (..., F, T) -> (..., T, F)
- enhanced = enhanced.transpose(-1, -2)
- mask_speech = mask_speech.transpose(-1, -3)
- else: # multi-speaker case: (mask_speech1, ..., mask_noise)
- mask_speech = list(masks[:-1])
- mask_noise = masks[-1]
-
- psd_speeches = [
- get_power_spectral_density_matrix(data, mask) for mask in mask_speech
- ]
- psd_noise = get_power_spectral_density_matrix(data, mask_noise)
-
- enhanced = []
- ws = []
- for i in range(self.nmask - 1):
- psd_speech = psd_speeches.pop(i)
- # treat all other speakers' psd_speech as noises
- enh, w = apply_beamforming(
- data, ilens, psd_speech, sum(psd_speeches) + psd_noise
- )
- psd_speeches.insert(i, psd_speech)
-
- # (..., F, T) -> (..., T, F)
- enh = enh.transpose(-1, -2)
- mask_speech[i] = mask_speech[i].transpose(-1, -3)
-
- enhanced.append(enh)
- ws.append(w)
-
- return enhanced, ilens, mask_speech
-
-
-class AttentionReference(torch.nn.Module):
- def __init__(self, bidim, att_dim):
- super().__init__()
- self.mlp_psd = torch.nn.Linear(bidim, att_dim)
- self.gvec = torch.nn.Linear(att_dim, 1)
-
- def forward(
- self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
- """The forward function
-
- Args:
- psd_in (ComplexTensor): (B, F, C, C)
- ilens (torch.Tensor): (B,)
- scaling (float):
- Returns:
- u (torch.Tensor): (B, C)
- ilens (torch.Tensor): (B,)
- """
- B, _, C = psd_in.size()[:3]
- assert psd_in.size(2) == psd_in.size(3), psd_in.size()
- # psd_in: (B, F, C, C)
- psd = psd_in.masked_fill(
- torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
- )
- # psd: (B, F, C, C) -> (B, C, F)
- psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
-
- # Calculate amplitude
- psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
-
- # (B, C, F) -> (B, C, F2)
- mlp_psd = self.mlp_psd(psd_feat)
- # (B, C, F2) -> (B, C, 1) -> (B, C)
- e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
- u = F.softmax(scaling * e, dim=-1)
- return u, ilens
diff --git a/funasr/modules/frontends/dnn_wpe.py b/funasr/modules/frontends/dnn_wpe.py
deleted file mode 100644
index 9596765..0000000
--- a/funasr/modules/frontends/dnn_wpe.py
+++ /dev/null
@@ -1,93 +0,0 @@
-from typing import Tuple
-
-from pytorch_wpe import wpe_one_iteration
-import torch
-from torch_complex.tensor import ComplexTensor
-
-from funasr.modules.frontends.mask_estimator import MaskEstimator
-from funasr.modules.nets_utils import make_pad_mask
-
-
-class DNN_WPE(torch.nn.Module):
- def __init__(
- self,
- wtype: str = "blstmp",
- widim: int = 257,
- wlayers: int = 3,
- wunits: int = 300,
- wprojs: int = 320,
- dropout_rate: float = 0.0,
- taps: int = 5,
- delay: int = 3,
- use_dnn_mask: bool = True,
- iterations: int = 1,
- normalization: bool = False,
- ):
- super().__init__()
- self.iterations = iterations
- self.taps = taps
- self.delay = delay
-
- self.normalization = normalization
- self.use_dnn_mask = use_dnn_mask
-
- self.inverse_power = True
-
- if self.use_dnn_mask:
- self.mask_est = MaskEstimator(
- wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
- )
-
- def forward(
- self, data: ComplexTensor, ilens: torch.LongTensor
- ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
- """The forward function
-
- Notation:
- B: Batch
- C: Channel
- T: Time or Sequence length
- F: Freq or Some dimension of the feature vector
-
- Args:
- data: (B, C, T, F)
- ilens: (B,)
- Returns:
- data: (B, C, T, F)
- ilens: (B,)
- """
- # (B, T, C, F) -> (B, F, C, T)
- enhanced = data = data.permute(0, 3, 2, 1)
- mask = None
-
- for i in range(self.iterations):
- # Calculate power: (..., C, T)
- power = enhanced.real**2 + enhanced.imag**2
- if i == 0 and self.use_dnn_mask:
- # mask: (B, F, C, T)
- (mask,), _ = self.mask_est(enhanced, ilens)
- if self.normalization:
- # Normalize along T
- mask = mask / mask.sum(dim=-1)[..., None]
- # (..., C, T) * (..., C, T) -> (..., C, T)
- power = power * mask
-
- # Averaging along the channel axis: (..., C, T) -> (..., T)
- power = power.mean(dim=-2)
-
- # enhanced: (..., C, T) -> (..., C, T)
- enhanced = wpe_one_iteration(
- data.contiguous(),
- power,
- taps=self.taps,
- delay=self.delay,
- inverse_power=self.inverse_power,
- )
-
- enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
-
- # (B, F, C, T) -> (B, T, C, F)
- enhanced = enhanced.permute(0, 3, 2, 1)
- if mask is not None:
- mask = mask.transpose(-1, -3)
- return enhanced, ilens, mask
diff --git a/funasr/modules/frontends/feature_transform.py b/funasr/modules/frontends/feature_transform.py
deleted file mode 100644
index 353dca1..0000000
--- a/funasr/modules/frontends/feature_transform.py
+++ /dev/null
@@ -1,263 +0,0 @@
-from typing import List
-from typing import Tuple
-from typing import Union
-
-import librosa
-import numpy as np
-import torch
-from torch_complex.tensor import ComplexTensor
-
-from funasr.modules.nets_utils import make_pad_mask
-
-
-class FeatureTransform(torch.nn.Module):
- def __init__(
- self,
- # Mel options,
- fs: int = 16000,
- n_fft: int = 512,
- n_mels: int = 80,
- fmin: float = 0.0,
- fmax: float = None,
- # Normalization
- stats_file: str = None,
- apply_uttmvn: bool = True,
- uttmvn_norm_means: bool = True,
- uttmvn_norm_vars: bool = False,
- ):
- super().__init__()
- self.apply_uttmvn = apply_uttmvn
-
- self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
- self.stats_file = stats_file
- if stats_file is not None:
- self.global_mvn = GlobalMVN(stats_file)
- else:
- self.global_mvn = None
-
- if self.apply_uttmvn is not None:
- self.uttmvn = UtteranceMVN(
- norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
- )
- else:
- self.uttmvn = None
-
- def forward(
- self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
- # (B, T, F) or (B, T, C, F)
- if x.dim() not in (3, 4):
- raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
- if not torch.is_tensor(ilens):
- ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
-
- if x.dim() == 4:
- # h: (B, T, C, F) -> h: (B, T, F)
- if self.training:
- # Select 1ch randomly
- ch = np.random.randint(x.size(2))
- h = x[:, :, ch, :]
- else:
- # Use the first channel
- h = x[:, :, 0, :]
- else:
- h = x
-
- # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
- h = h.real**2 + h.imag**2
-
- h, _ = self.logmel(h, ilens)
- if self.stats_file is not None:
- h, _ = self.global_mvn(h, ilens)
- if self.apply_uttmvn:
- h, _ = self.uttmvn(h, ilens)
-
- return h, ilens
-
-
-class LogMel(torch.nn.Module):
- """Convert STFT to fbank feats
-
- The arguments is same as librosa.filters.mel
-
- Args:
- fs: number > 0 [scalar] sampling rate of the incoming signal
- n_fft: int > 0 [scalar] number of FFT components
- n_mels: int > 0 [scalar] number of Mel bands to generate
- fmin: float >= 0 [scalar] lowest frequency (in Hz)
- fmax: float >= 0 [scalar] highest frequency (in Hz).
- If `None`, use `fmax = fs / 2.0`
- htk: use HTK formula instead of Slaney
- norm: {None, 1, np.inf} [scalar]
- if 1, divide the triangular mel weights by the width of the mel band
- (area normalization). Otherwise, leave all the triangles aiming for
- a peak value of 1.0
-
- """
-
- def __init__(
- self,
- fs: int = 16000,
- n_fft: int = 512,
- n_mels: int = 80,
- fmin: float = 0.0,
- fmax: float = None,
- htk: bool = False,
- norm=1,
- ):
- super().__init__()
-
- _mel_options = dict(
- sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
- )
- self.mel_options = _mel_options
-
- # Note(kamo): The mel matrix of librosa is different from kaldi.
- melmat = librosa.filters.mel(**_mel_options)
- # melmat: (D2, D1) -> (D1, D2)
- self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
-
- def extra_repr(self):
- return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
-
- def forward(
- self, feat: torch.Tensor, ilens: torch.LongTensor
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
- # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
- mel_feat = torch.matmul(feat, self.melmat)
-
- logmel_feat = (mel_feat + 1e-20).log()
- # Zero padding
- logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
- return logmel_feat, ilens
-
-
-class GlobalMVN(torch.nn.Module):
- """Apply global mean and variance normalization
-
- Args:
- stats_file(str): npy file of 1-dim array or text file.
- From the _first element to
- the {(len(array) - 1) / 2}th element are treated as
- the sum of features,
- and the rest excluding the last elements are
- treated as the sum of the square value of features,
- and the last elements eqauls to the number of samples.
- std_floor(float):
- """
-
- def __init__(
- self,
- stats_file: str,
- norm_means: bool = True,
- norm_vars: bool = True,
- eps: float = 1.0e-20,
- ):
- super().__init__()
- self.norm_means = norm_means
- self.norm_vars = norm_vars
-
- self.stats_file = stats_file
- stats = np.load(stats_file)
-
- stats = stats.astype(float)
- assert (len(stats) - 1) % 2 == 0, stats.shape
-
- count = stats.flatten()[-1]
- mean = stats[: (len(stats) - 1) // 2] / count
- var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
- std = np.maximum(np.sqrt(var), eps)
-
- self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
- self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
-
- def extra_repr(self):
- return (
- f"stats_file={self.stats_file}, "
- f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
- )
-
- def forward(
- self, x: torch.Tensor, ilens: torch.LongTensor
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
- # feat: (B, T, D)
- if self.norm_means:
- x += self.bias.type_as(x)
- x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
-
- if self.norm_vars:
- x *= self.scale.type_as(x)
- return x, ilens
-
-
-class UtteranceMVN(torch.nn.Module):
- def __init__(
- self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
- ):
- super().__init__()
- self.norm_means = norm_means
- self.norm_vars = norm_vars
- self.eps = eps
-
- def extra_repr(self):
- return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
-
- def forward(
- self, x: torch.Tensor, ilens: torch.LongTensor
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
- return utterance_mvn(
- x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
- )
-
-
-def utterance_mvn(
- x: torch.Tensor,
- ilens: torch.LongTensor,
- norm_means: bool = True,
- norm_vars: bool = False,
- eps: float = 1.0e-20,
-) -> Tuple[torch.Tensor, torch.LongTensor]:
- """Apply utterance mean and variance normalization
-
- Args:
- x: (B, T, D), assumed zero padded
- ilens: (B, T, D)
- norm_means:
- norm_vars:
- eps:
-
- """
- ilens_ = ilens.type_as(x)
- # mean: (B, D)
- mean = x.sum(dim=1) / ilens_[:, None]
-
- if norm_means:
- x -= mean[:, None, :]
- x_ = x
- else:
- x_ = x - mean[:, None, :]
-
- # Zero padding
- x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
- if norm_vars:
- var = x_.pow(2).sum(dim=1) / ilens_[:, None]
- var = torch.clamp(var, min=eps)
- x /= var.sqrt()[:, None, :]
- x_ = x
- return x_, ilens
-
-
-def feature_transform_for(args, n_fft):
- return FeatureTransform(
- # Mel options,
- fs=args.fbank_fs,
- n_fft=n_fft,
- n_mels=args.n_mels,
- fmin=args.fbank_fmin,
- fmax=args.fbank_fmax,
- # Normalization
- stats_file=args.stats_file,
- apply_uttmvn=args.apply_uttmvn,
- uttmvn_norm_means=args.uttmvn_norm_means,
- uttmvn_norm_vars=args.uttmvn_norm_vars,
- )
diff --git a/funasr/modules/frontends/frontend.py b/funasr/modules/frontends/frontend.py
deleted file mode 100644
index ab5ea3b..0000000
--- a/funasr/modules/frontends/frontend.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import numpy
-import torch
-import torch.nn as nn
-from torch_complex.tensor import ComplexTensor
-
-from funasr.modules.frontends.dnn_beamformer import DNN_Beamformer
-from funasr.modules.frontends.dnn_wpe import DNN_WPE
-
-
-class Frontend(nn.Module):
- def __init__(
- self,
- idim: int,
- # WPE options
- use_wpe: bool = False,
- wtype: str = "blstmp",
- wlayers: int = 3,
- wunits: int = 300,
- wprojs: int = 320,
- wdropout_rate: float = 0.0,
- taps: int = 5,
- delay: int = 3,
- use_dnn_mask_for_wpe: bool = True,
- # Beamformer options
- use_beamformer: bool = False,
- btype: str = "blstmp",
- blayers: int = 3,
- bunits: int = 300,
- bprojs: int = 320,
- bnmask: int = 2,
- badim: int = 320,
- ref_channel: int = -1,
- bdropout_rate=0.0,
- ):
- super().__init__()
-
- self.use_beamformer = use_beamformer
- self.use_wpe = use_wpe
- self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
- # use frontend for all the data,
- # e.g. in the case of multi-speaker speech separation
- self.use_frontend_for_all = bnmask > 2
-
- if self.use_wpe:
- if self.use_dnn_mask_for_wpe:
- # Use DNN for power estimation
- # (Not observed significant gains)
- iterations = 1
- else:
- # Performing as conventional WPE, without DNN Estimator
- iterations = 2
-
- self.wpe = DNN_WPE(
- wtype=wtype,
- widim=idim,
- wunits=wunits,
- wprojs=wprojs,
- wlayers=wlayers,
- taps=taps,
- delay=delay,
- dropout_rate=wdropout_rate,
- iterations=iterations,
- use_dnn_mask=use_dnn_mask_for_wpe,
- )
- else:
- self.wpe = None
-
- if self.use_beamformer:
- self.beamformer = DNN_Beamformer(
- btype=btype,
- bidim=idim,
- bunits=bunits,
- bprojs=bprojs,
- blayers=blayers,
- bnmask=bnmask,
- dropout_rate=bdropout_rate,
- badim=badim,
- ref_channel=ref_channel,
- )
- else:
- self.beamformer = None
-
- def forward(
- self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
- ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
- assert len(x) == len(ilens), (len(x), len(ilens))
- # (B, T, F) or (B, T, C, F)
- if x.dim() not in (3, 4):
- raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
- if not torch.is_tensor(ilens):
- ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
-
- mask = None
- h = x
- if h.dim() == 4:
- if self.training:
- choices = [(False, False)] if not self.use_frontend_for_all else []
- if self.use_wpe:
- choices.append((True, False))
-
- if self.use_beamformer:
- choices.append((False, True))
-
- use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
-
- else:
- use_wpe = self.use_wpe
- use_beamformer = self.use_beamformer
-
- # 1. WPE
- if use_wpe:
- # h: (B, T, C, F) -> h: (B, T, C, F)
- h, ilens, mask = self.wpe(h, ilens)
-
- # 2. Beamformer
- if use_beamformer:
- # h: (B, T, C, F) -> h: (B, T, F)
- h, ilens, mask = self.beamformer(h, ilens)
-
- return h, ilens, mask
-
-
-def frontend_for(args, idim):
- return Frontend(
- idim=idim,
- # WPE options
- use_wpe=args.use_wpe,
- wtype=args.wtype,
- wlayers=args.wlayers,
- wunits=args.wunits,
- wprojs=args.wprojs,
- wdropout_rate=args.wdropout_rate,
- taps=args.wpe_taps,
- delay=args.wpe_delay,
- use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
- # Beamformer options
- use_beamformer=args.use_beamformer,
- btype=args.btype,
- blayers=args.blayers,
- bunits=args.bunits,
- bprojs=args.bprojs,
- bnmask=args.bnmask,
- badim=args.badim,
- ref_channel=args.ref_channel,
- bdropout_rate=args.bdropout_rate,
- )
diff --git a/funasr/modules/frontends/mask_estimator.py b/funasr/modules/frontends/mask_estimator.py
deleted file mode 100644
index 53072bf..0000000
--- a/funasr/modules/frontends/mask_estimator.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from typing import Tuple
-
-import numpy as np
-import torch
-from torch.nn import functional as F
-from torch_complex.tensor import ComplexTensor
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.rnn.encoders import RNN
-from funasr.modules.rnn.encoders import RNNP
-
-
-class MaskEstimator(torch.nn.Module):
- def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
- super().__init__()
- subsample = np.ones(layers + 1, dtype=np.int32)
-
- typ = type.lstrip("vgg").rstrip("p")
- if type[-1] == "p":
- self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
- else:
- self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
-
- self.type = type
- self.nmask = nmask
- self.linears = torch.nn.ModuleList(
- [torch.nn.Linear(projs, idim) for _ in range(nmask)]
- )
-
- def forward(
- self, xs: ComplexTensor, ilens: torch.LongTensor
- ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
- """The forward function
-
- Args:
- xs: (B, F, C, T)
- ilens: (B,)
- Returns:
- hs (torch.Tensor): The hidden vector (B, F, C, T)
- masks: A tuple of the masks. (B, F, C, T)
- ilens: (B,)
- """
- assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
- _, _, C, input_length = xs.size()
- # (B, F, C, T) -> (B, C, T, F)
- xs = xs.permute(0, 2, 3, 1)
-
- # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
- xs = (xs.real**2 + xs.imag**2) ** 0.5
- # xs: (B, C, T, F) -> xs: (B * C, T, F)
- xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
- # ilens: (B,) -> ilens_: (B * C)
- ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
-
- # xs: (B * C, T, F) -> xs: (B * C, T, D)
- xs, _, _ = self.brnn(xs, ilens_)
- # xs: (B * C, T, D) -> xs: (B, C, T, D)
- xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
-
- masks = []
- for linear in self.linears:
- # xs: (B, C, T, D) -> mask:(B, C, T, F)
- mask = linear(xs)
-
- mask = torch.sigmoid(mask)
- # Zero padding
- mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
-
- # (B, C, T, F) -> (B, F, C, T)
- mask = mask.permute(0, 3, 1, 2)
-
- # Take cares of multi gpu cases: If input_length > max(ilens)
- if mask.size(-1) < input_length:
- mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
- masks.append(mask)
-
- return tuple(masks), ilens
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 27d6f9c..a5069d0 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -278,14 +278,11 @@
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
- "{}/{}epoch started. Estimated time to finish: {}".format(
+ "{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch,
trainer_options.max_epoch,
- humanfriendly.format_timespan(
- (time.perf_counter() - start_time)
- / (iepoch - start_epoch)
- * (trainer_options.max_epoch - iepoch + 1)
- ),
+ (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
+ trainer_options.max_epoch - iepoch + 1),
)
)
else:
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
index 5aa40ec..364746a 100644
--- a/funasr/utils/asr_utils.py
+++ b/funasr/utils/asr_utils.py
@@ -5,7 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
-import soundfile
+import librosa
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -139,7 +139,7 @@
try:
audio, fs = torchaudio.load(fname)
except:
- audio, fs = soundfile.read(fname)
+ audio, fs = librosa.load(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index 702c7f3..36eebdc 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -5,7 +5,7 @@
import kaldiio
import numpy as np
-import soundfile
+import librosa
import torch.distributed as dist
import torchaudio
@@ -46,7 +46,7 @@
try:
waveform, sampling_rate = torchaudio.load(wav_path)
except:
- waveform, sampling_rate = soundfile.read(wav_path)
+ waveform, sampling_rate = librosa.load(wav_path)
waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py
index a1c610f..38ef11c 100644
--- a/funasr/utils/speaker_utils.py
+++ b/funasr/utils/speaker_utils.py
@@ -12,7 +12,7 @@
from typing import Any, Dict, List, Union
import numpy as np
-import soundfile as sf
+import librosa as sf
import torch
import torchaudio
import logging
@@ -43,7 +43,7 @@
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
- data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
+ data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 6594273..c463f0c 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -3,7 +3,7 @@
import logging
import argparse
import numpy as np
-import edit_distance
+# import edit_distance
from itertools import zip_longest
@@ -160,112 +160,112 @@
return res
-class AverageShiftCalculator():
- def __init__(self):
- logging.warning("Calculating average shift.")
- def __call__(self, file1, file2):
- uttid_list1, ts_dict1 = self.read_timestamps(file1)
- uttid_list2, ts_dict2 = self.read_timestamps(file2)
- uttid_intersection = self._intersection(uttid_list1, uttid_list2)
- res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
- logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
- logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
-
- def _intersection(self, list1, list2):
- set1 = set(list1)
- set2 = set(list2)
- if set1 == set2:
- logging.warning("Uttid same checked.")
- return set1
- itsc = list(set1 & set2)
- logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
- return itsc
-
- def read_timestamps(self, file):
- # read timestamps file in standard format
- uttid_list = []
- ts_dict = {}
- with codecs.open(file, 'r') as fin:
- for line in fin.readlines():
- text = ''
- ts_list = []
- line = line.rstrip()
- uttid = line.split()[0]
- uttid_list.append(uttid)
- body = " ".join(line.split()[1:])
- for pd in body.split(';'):
- if not len(pd): continue
- # pdb.set_trace()
- char, start, end = pd.lstrip(" ").split(' ')
- text += char + ','
- ts_list.append((float(start), float(end)))
- # ts_lists.append(ts_list)
- ts_dict[uttid] = (text[:-1], ts_list)
- logging.warning("File {} read done.".format(file))
- return uttid_list, ts_dict
-
- def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
- shift_time = 0
- for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
- shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
- num_tokens = len(filtered_timestamp_list1)
- return shift_time, num_tokens
-
- def as_cal(self, uttid_list, ts_dict1, ts_dict2):
- # calculate average shift between timestamp1 and timestamp2
- # when characters differ, use edit distance alignment
- # and calculate the error between the same characters
- self._accumlated_shift = 0
- self._accumlated_tokens = 0
- self.max_shift = 0
- self.max_shift_uttid = None
- for uttid in uttid_list:
- (t1, ts1) = ts_dict1[uttid]
- (t2, ts2) = ts_dict2[uttid]
- _align, _align2, _align3 = [], [], []
- fts1, fts2 = [], []
- _t1, _t2 = [], []
- sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
- s = sm.get_opcodes()
- for j in range(len(s)):
- if s[j][0] == "replace" or s[j][0] == "insert":
- _align.append(0)
- if s[j][0] == "replace" or s[j][0] == "delete":
- _align3.append(0)
- elif s[j][0] == "equal":
- _align.append(1)
- _align3.append(1)
- else:
- continue
- # use s to index t2
- for a, ts , t in zip(_align, ts2, t2.split(',')):
- if a:
- fts2.append(ts)
- _t2.append(t)
- sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
- s = sm2.get_opcodes()
- for j in range(len(s)):
- if s[j][0] == "replace" or s[j][0] == "insert":
- _align2.append(0)
- elif s[j][0] == "equal":
- _align2.append(1)
- else:
- continue
- # use s2 tp index t1
- for a, ts, t in zip(_align3, ts1, t1.split(',')):
- if a:
- fts1.append(ts)
- _t1.append(t)
- if len(fts1) == len(fts2):
- shift_time, num_tokens = self._shift(fts1, fts2)
- self._accumlated_shift += shift_time
- self._accumlated_tokens += num_tokens
- if shift_time/num_tokens > self.max_shift:
- self.max_shift = shift_time/num_tokens
- self.max_shift_uttid = uttid
- else:
- logging.warning("length mismatch")
- return self._accumlated_shift / self._accumlated_tokens
+# class AverageShiftCalculator():
+# def __init__(self):
+# logging.warning("Calculating average shift.")
+# def __call__(self, file1, file2):
+# uttid_list1, ts_dict1 = self.read_timestamps(file1)
+# uttid_list2, ts_dict2 = self.read_timestamps(file2)
+# uttid_intersection = self._intersection(uttid_list1, uttid_list2)
+# res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
+# logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
+# logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
+#
+# def _intersection(self, list1, list2):
+# set1 = set(list1)
+# set2 = set(list2)
+# if set1 == set2:
+# logging.warning("Uttid same checked.")
+# return set1
+# itsc = list(set1 & set2)
+# logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
+# return itsc
+#
+# def read_timestamps(self, file):
+# # read timestamps file in standard format
+# uttid_list = []
+# ts_dict = {}
+# with codecs.open(file, 'r') as fin:
+# for line in fin.readlines():
+# text = ''
+# ts_list = []
+# line = line.rstrip()
+# uttid = line.split()[0]
+# uttid_list.append(uttid)
+# body = " ".join(line.split()[1:])
+# for pd in body.split(';'):
+# if not len(pd): continue
+# # pdb.set_trace()
+# char, start, end = pd.lstrip(" ").split(' ')
+# text += char + ','
+# ts_list.append((float(start), float(end)))
+# # ts_lists.append(ts_list)
+# ts_dict[uttid] = (text[:-1], ts_list)
+# logging.warning("File {} read done.".format(file))
+# return uttid_list, ts_dict
+#
+# def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
+# shift_time = 0
+# for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
+# shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
+# num_tokens = len(filtered_timestamp_list1)
+# return shift_time, num_tokens
+#
+# # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
+# # # calculate average shift between timestamp1 and timestamp2
+# # # when characters differ, use edit distance alignment
+# # # and calculate the error between the same characters
+# # self._accumlated_shift = 0
+# # self._accumlated_tokens = 0
+# # self.max_shift = 0
+# # self.max_shift_uttid = None
+# # for uttid in uttid_list:
+# # (t1, ts1) = ts_dict1[uttid]
+# # (t2, ts2) = ts_dict2[uttid]
+# # _align, _align2, _align3 = [], [], []
+# # fts1, fts2 = [], []
+# # _t1, _t2 = [], []
+# # sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
+# # s = sm.get_opcodes()
+# # for j in range(len(s)):
+# # if s[j][0] == "replace" or s[j][0] == "insert":
+# # _align.append(0)
+# # if s[j][0] == "replace" or s[j][0] == "delete":
+# # _align3.append(0)
+# # elif s[j][0] == "equal":
+# # _align.append(1)
+# # _align3.append(1)
+# # else:
+# # continue
+# # # use s to index t2
+# # for a, ts , t in zip(_align, ts2, t2.split(',')):
+# # if a:
+# # fts2.append(ts)
+# # _t2.append(t)
+# # sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
+# # s = sm2.get_opcodes()
+# # for j in range(len(s)):
+# # if s[j][0] == "replace" or s[j][0] == "insert":
+# # _align2.append(0)
+# # elif s[j][0] == "equal":
+# # _align2.append(1)
+# # else:
+# # continue
+# # # use s2 tp index t1
+# # for a, ts, t in zip(_align3, ts1, t1.split(',')):
+# # if a:
+# # fts1.append(ts)
+# # _t1.append(t)
+# # if len(fts1) == len(fts2):
+# # shift_time, num_tokens = self._shift(fts1, fts2)
+# # self._accumlated_shift += shift_time
+# # self._accumlated_tokens += num_tokens
+# # if shift_time/num_tokens > self.max_shift:
+# # self.max_shift = shift_time/num_tokens
+# # self.max_shift_uttid = uttid
+# # else:
+# # logging.warning("length mismatch")
+# # return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
@@ -311,10 +311,10 @@
def main(args):
- if args.mode == 'cal_aas':
- asc = AverageShiftCalculator()
- asc(args.input, args.input2)
- elif args.mode == 'read_ext_alphas':
+ # if args.mode == 'cal_aas':
+ # asc = AverageShiftCalculator()
+ # asc(args.input, args.input2)
+ if args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index bd067c2..8c2dc68 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -11,7 +11,7 @@
import numpy as np
import torch
import torchaudio
-import soundfile
+import librosa
import torchaudio.compliance.kaldi as kaldi
@@ -166,7 +166,7 @@
try:
waveform, audio_sr = torchaudio.load(wav_file)
except:
- waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
+ waveform, audio_sr = librosa.load(wav_file, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
@@ -191,7 +191,7 @@
try:
waveform, sampling_rate = torchaudio.load(wav_path)
except:
- waveform, sampling_rate = soundfile.read(wav_path)
+ waveform, sampling_rate = librosa.load(wav_path)
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
diff --git a/funasr/utils/whisper_utils/audio.py b/funasr/utils/whisper_utils/audio.py
index 004bd0d..6dd4cb1 100644
--- a/funasr/utils/whisper_utils/audio.py
+++ b/funasr/utils/whisper_utils/audio.py
@@ -1,8 +1,11 @@
import os
from functools import lru_cache
from typing import Union
+try:
+ import ffmpeg
+except:
+ print("Please Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.")
-import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
diff --git a/setup.py b/setup.py
index dd485d3..5b7b83c 100644
--- a/setup.py
+++ b/setup.py
@@ -10,36 +10,36 @@
requirements = {
"install": [
- "setuptools>=38.5.1",
+ # "setuptools>=38.5.1",
"humanfriendly",
"scipy>=1.4.1",
"librosa",
- "jamo", # For kss
+ # "jamo", # For kss
"PyYAML>=5.1.2",
- "soundfile>=0.12.1",
- "h5py>=3.1.0",
+ # "soundfile>=0.12.1",
+ # "h5py>=3.1.0",
"kaldiio>=2.17.0",
- "torch_complex",
- "nltk>=3.4.5",
+ # "torch_complex",
+ # "nltk>=3.4.5",
# ASR
- "sentencepiece",
+ "sentencepiece", # train
"jieba",
- "rotary_embedding_torch",
- "ffmpeg",
+ # "rotary_embedding_torch",
+ # "ffmpeg-python",
# TTS
- "pypinyin>=0.44.0",
- "espnet_tts_frontend",
+ # "pypinyin>=0.44.0",
+ # "espnet_tts_frontend",
# ENH
- "pytorch_wpe",
+ # "pytorch_wpe",
"editdistance>=0.5.2",
"tensorboard",
- "g2p",
- "nara_wpe",
+ # "g2p",
+ # "nara_wpe",
# PAI
"oss2",
- "edit-distance",
- "textgrid",
- "protobuf",
+ # "edit-distance",
+ # "textgrid",
+ # "protobuf",
"tqdm",
"hdbscan",
"umap",
@@ -104,7 +104,7 @@
name="funasr",
version=version,
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab of DAMO Academy, Alibaba Group",
+ author="Speech Lab of Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
--
Gitblit v1.9.1