From 69dcdbcfc0c21627c40fb8f7c435136c11691574 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 27 六月 2023 17:19:49 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
---
funasr/fileio/sound_scp.py | 70 +++++++
funasr/version.txt | 2
funasr/utils/wav_utils.py | 15 +
.gitignore | 3
funasr/models/encoder/conformer_encoder.py | 55 +++++
funasr/utils/runtime_sdk_download_tool.py | 38 ++++
setup.py | 2
funasr/datasets/iterable_dataset.py | 12 +
funasr/modules/subsampling.py | 21 +-
funasr/modules/eend_ola/utils/__init__.py | 0
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 4
README.md | 11
funasr/datasets/large_datasets/dataset.py | 11 +
funasr/utils/prepare_data.py | 7
docs/installation/installation.md | 14
funasr/datasets/small_datasets/__init__.py | 0
docs/benchmark/benchmark_pipeline_cer.md | 254 ++++++++++++++--------------
funasr/bin/asr_inference_launch.py | 9
funasr/utils/asr_utils.py | 6
19 files changed, 368 insertions(+), 166 deletions(-)
diff --git a/.gitignore b/.gitignore
index 58bee36..d47674c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,5 @@
build
funasr.egg-info
docs/_build
-modelscope
\ No newline at end of file
+modelscope
+samples
\ No newline at end of file
diff --git a/README.md b/README.md
index 8368b3b..26cf940 100644
--- a/README.md
+++ b/README.md
@@ -34,9 +34,9 @@
Install from pip
```shell
-pip install -U funasr
+pip3 install -U funasr
# For the users in China, you could install with the command:
-# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
Or install from source code
@@ -96,14 +96,17 @@
### runtime
An example with websocket:
+
For the server:
```shell
+cd funasr/runtime/python/websocket
python wss_srv_asr.py --port 10095
```
+
For the client:
```shell
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "5,10,5"
-#python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
+python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
+#python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
```
More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2)
## Contact
diff --git a/docs/benchmark/benchmark_pipeline_cer.md b/docs/benchmark/benchmark_pipeline_cer.md
index 97776a6..d978f3e 100644
--- a/docs/benchmark/benchmark_pipeline_cer.md
+++ b/docs/benchmark/benchmark_pipeline_cer.md
@@ -45,156 +45,156 @@
### Chinese Dataset
-<table>
+<table border="1">
<tr align="center">
- <td>Model</td>
- <td>Offline/Online</td>
- <td colspan="2">Aishell1</td>
- <td colspan="4">Aishell2</td>
- <td colspan="3">WenetSpeech</td>
+ <td style="border: 1px solid">Model</td>
+ <td style="border: 1px solid">Offline/Online</td>
+ <td colspan="2" style="border: 1px solid">Aishell1</td>
+ <td colspan="4" style="border: 1px solid">Aishell2</td>
+ <td colspan="3" style="border: 1px solid">WenetSpeech</td>
</tr>
<tr align="center">
- <td></td>
- <td></td>
- <td>dev</td>
- <td>test</td>
- <td>dev_ios</td>
- <td>test_ios</td>
- <td>test_android</td>
- <td>test_mic</td>
- <td>dev</td>
- <td>test_meeting</td>
- <td>test_net</td>
+ <td style="border: 1px solid"></td>
+ <td style="border: 1px solid"></td>
+ <td style="border: 1px solid">dev</td>
+ <td style="border: 1px solid">test</td>
+ <td style="border: 1px solid">dev_ios</td>
+ <td style="border: 1px solid">test_ios</td>
+ <td style="border: 1px solid">test_android</td>
+ <td style="border: 1px solid">test_mic</td>
+ <td style="border: 1px solid">dev</td>
+ <td style="border: 1px solid">test_meeting</td>
+ <td style="border: 1px solid">test_net</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
- <td>Offline</td>
- <td>1.76</td>
- <td>1.94</td>
- <td>2.79</td>
- <td>2.84</td>
- <td>3.08</td>
- <td>3.03</td>
- <td>3.43</td>
- <td>7.01</td>
- <td>6.66</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">1.76</td>
+ <td style="border: 1px solid">1.94</td>
+ <td style="border: 1px solid">2.79</td>
+ <td style="border: 1px solid">2.84</td>
+ <td style="border: 1px solid">3.08</td>
+ <td style="border: 1px solid">3.03</td>
+ <td style="border: 1px solid">3.43</td>
+ <td style="border: 1px solid">7.01</td>
+ <td style="border: 1px solid">6.66</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td>
- <td>Offline</td>
- <td>1.80</td>
- <td>2.10</td>
- <td>2.78</td>
- <td>2.87</td>
- <td>3.12</td>
- <td>3.11</td>
- <td>3.44</td>
- <td>13.28</td>
- <td>7.08</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">1.80</td>
+ <td style="border: 1px solid">2.10</td>
+ <td style="border: 1px solid">2.78</td>
+ <td style="border: 1px solid">2.87</td>
+ <td style="border: 1px solid">3.12</td>
+ <td style="border: 1px solid">3.11</td>
+ <td style="border: 1px solid">3.44</td>
+ <td style="border: 1px solid">13.28</td>
+ <td style="border: 1px solid">7.08</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
- <td>Offline</td>
- <td>1.76</td>
- <td>2.02</td>
- <td>2.73</td>
- <td>2.85</td>
- <td>2.98</td>
- <td>2.95</td>
- <td>3.42</td>
- <td>7.16</td>
- <td>6.72</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">1.76</td>
+ <td style="border: 1px solid">2.02</td>
+ <td style="border: 1px solid">2.73</td>
+ <td style="border: 1px solid">2.85</td>
+ <td style="border: 1px solid">2.98</td>
+ <td style="border: 1px solid">2.95</td>
+ <td style="border: 1px solid">3.42</td>
+ <td style="border: 1px solid">7.16</td>
+ <td style="border: 1px solid">6.72</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td>
- <td>Offline</td>
- <td>3.24</td>
- <td>3.69</td>
- <td>4.58</td>
- <td>4.63</td>
- <td>4.83</td>
- <td>4.71</td>
- <td>4.19</td>
- <td>8.32</td>
- <td>9.19</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">3.24</td>
+ <td style="border: 1px solid">3.69</td>
+ <td style="border: 1px solid">4.58</td>
+ <td style="border: 1px solid">4.63</td>
+ <td style="border: 1px solid">4.83</td>
+ <td style="border: 1px solid">4.71</td>
+ <td style="border: 1px solid">4.19</td>
+ <td style="border: 1px solid">8.32</td>
+ <td style="border: 1px solid">9.19</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td>
- <td>Online</td>
- <td>3.34</td>
- <td>3.99</td>
- <td>4.62</td>
- <td>4.52</td>
- <td>4.77</td>
- <td>4.73</td>
- <td>4.51</td>
- <td>10.63</td>
- <td>9.70</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td>
+ <td style="border: 1px solid">Online</td>
+ <td style="border: 1px solid">3.34</td>
+ <td style="border: 1px solid">3.99</td>
+ <td style="border: 1px solid">4.62</td>
+ <td style="border: 1px solid">4.52</td>
+ <td style="border: 1px solid">4.77</td>
+ <td style="border: 1px solid">4.73</td>
+ <td style="border: 1px solid">4.51</td>
+ <td style="border: 1px solid">10.63</td>
+ <td style="border: 1px solid">9.70</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td>
- <td>Offline</td>
- <td>2.93</td>
- <td>3.48</td>
- <td>3.95</td>
- <td>3.87</td>
- <td>4.11</td>
- <td>4.11</td>
- <td>4.16</td>
- <td>10.09</td>
- <td>8.69</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">2.93</td>
+ <td style="border: 1px solid">3.48</td>
+ <td style="border: 1px solid">3.95</td>
+ <td style="border: 1px solid">3.87</td>
+ <td style="border: 1px solid">4.11</td>
+ <td style="border: 1px solid">4.11</td>
+ <td style="border: 1px solid">4.16</td>
+ <td style="border: 1px solid">10.09</td>
+ <td style="border: 1px solid">8.69</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
- <td>Offline</td>
- <td>4.88</td>
- <td>5.43</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">4.88</td>
+ <td style="border: 1px solid">5.43</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
- <td>Offline</td>
- <td>6.14</td>
- <td>7.01</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">6.14</td>
+ <td style="border: 1px solid">7.01</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td>
- <td>Offline</td>
- <td>-</td>
- <td>-</td>
- <td>5.82</td>
- <td>6.30</td>
- <td>6.60</td>
- <td>5.83</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">5.82</td>
+ <td style="border: 1px solid">6.30</td>
+ <td style="border: 1px solid">6.60</td>
+ <td style="border: 1px solid">5.83</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td>
- <td>Offline</td>
- <td>-</td>
- <td>-</td>
- <td>4.95</td>
- <td>5.45</td>
- <td>5.59</td>
- <td>5.83</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">4.95</td>
+ <td style="border: 1px solid">5.45</td>
+ <td style="border: 1px solid">5.59</td>
+ <td style="border: 1px solid">5.83</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
</table>
diff --git a/docs/installation/installation.md b/docs/installation/installation.md
index d020b51..f81ae83 100755
--- a/docs/installation/installation.md
+++ b/docs/installation/installation.md
@@ -32,7 +32,7 @@
### Install Pytorch (version >= 1.11.0):
```sh
-pip install torch torchaudio
+pip3 install torch torchaudio
```
If there exists CUDAs in your environments, you should install the pytorch with the version matching the CUDA. The matching list could be found in [docs](https://pytorch.org/get-started/previous-versions/).
### Install funasr
@@ -40,27 +40,27 @@
#### Install from pip
```shell
-pip install -U funasr
+pip3 install -U funasr
# For the users in China, you could install with the command:
-# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
#### Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install -e ./
+pip3 install -e ./
# For the users in China, you could install with the command:
-# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### Install modelscope (Optional)
If you want to use the pretrained models in ModelScope, you should install the modelscope:
```shell
-pip install -U modelscope
+pip3 install -U modelscope
# For the users in China, you could install with the command:
-# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### FQA
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index 59f9936..a1f27a3 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -6,7 +6,7 @@
unified_model_training: true
default_chunk_size: 16
jitter_range: 4
- left_chunk_size: 0
+ left_chunk_size: 1
embed_vgg_like: false
subsampling_factor: 4
linear_units: 2048
@@ -51,7 +51,7 @@
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 200
+max_epoch: 120
val_scheduler_criterion:
- valid
- loss
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 656a965..5d1b804 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import yaml
from typeguard import check_argument_types
@@ -863,7 +864,13 @@
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ try:
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ except:
+ raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+ if raw_inputs.ndim == 2:
+ raw_inputs = raw_inputs[:, 0]
+ raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 4b2fb1a..d240d93 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,17 @@
bytes = f.read()
return load_bytes(bytes)
+def load_wav(input):
+ try:
+ return torchaudio.load(input)[0].numpy()
+ except:
+ waveform, _ = soundfile.read(input, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ return np.expand_dims(waveform, axis=0)
+
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0].numpy(),
+ "sound": load_wav,
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 68b63e1..5f2c2c6 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
+import numpy as np
+import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,14 @@
sample_dict["key"] = key
elif data_type == "sound":
key, path = item.strip().split()
- waveform, sampling_rate = torchaudio.load(path)
+ try:
+ waveform, sampling_rate = torchaudio.load(path)
+ except:
+ waveform, sampling_rate = soundfile.read(path, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ waveform = np.expand_dims(waveform, axis=0)
+ waveform = torch.tensor(waveform)
if self.frontend_conf is not None:
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
diff --git a/funasr/datasets/small_datasets/__init__.py b/funasr/datasets/small_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/small_datasets/__init__.py
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index c752fe6..9b25fe5 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -1,6 +1,6 @@
import collections.abc
from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
import random
import numpy as np
@@ -13,6 +13,74 @@
from funasr.fileio.read_text import read_2column_text
+def soundfile_read(
+ wavs: Union[str, List[str]],
+ dtype=None,
+ always_2d: bool = False,
+ concat_axis: int = 1,
+ start: int = 0,
+ end: int = None,
+ return_subtype: bool = False,
+) -> Tuple[np.array, int]:
+ if isinstance(wavs, str):
+ wavs = [wavs]
+
+ arrays = []
+ subtypes = []
+ prev_rate = None
+ prev_wav = None
+ for wav in wavs:
+ with soundfile.SoundFile(wav) as f:
+ f.seek(start)
+ if end is not None:
+ frames = end - start
+ else:
+ frames = -1
+ if dtype == "float16":
+ array = f.read(
+ frames,
+ dtype="float32",
+ always_2d=always_2d,
+ ).astype(dtype)
+ else:
+ array = f.read(frames, dtype=dtype, always_2d=always_2d)
+ rate = f.samplerate
+ subtype = f.subtype
+ subtypes.append(subtype)
+
+ if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+ # array: (Time, Channel)
+ array = array[:, None]
+
+ if prev_wav is not None:
+ if prev_rate != rate:
+ raise RuntimeError(
+ f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+ f"{prev_rate} != {rate}"
+ )
+
+ dim1 = arrays[0].shape[1 - concat_axis]
+ dim2 = array.shape[1 - concat_axis]
+ if dim1 != dim2:
+ raise RuntimeError(
+ "Shapes must match with "
+ f"{1 - concat_axis} axis, but gut {dim1} and {dim2}"
+ )
+
+ prev_rate = rate
+ prev_wav = wav
+ arrays.append(array)
+
+ if len(arrays) == 1:
+ array = arrays[0]
+ else:
+ array = np.concatenate(arrays, axis=concat_axis)
+
+ if return_subtype:
+ return array, rate, subtypes
+ else:
+ return array, rate
+
class SoundScpReader(collections.abc.Mapping):
"""Reader class for 'wav.scp'.
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20dee..994607f 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -1081,7 +1081,10 @@
mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ if self.training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ else:
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@
elif self.dynamic_chunk_training:
max_len = x.size(1)
- chunk_size = torch.randint(1, max_len, (1,)).item()
+ if self.training:
+ chunk_size = torch.randint(1, max_len, (1,)).item()
- if chunk_size > (max_len * self.short_chunk_threshold):
- chunk_size = max_len
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
else:
- chunk_size = (chunk_size % self.short_chunk_size) + 1
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@
return x, olens, None
+ def full_utt_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ x_len: Encoder outputs lenghts. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len).to(x.device)
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ return x_utt
+
def simu_chunk_forward(
self,
x: torch.Tensor,
diff --git a/funasr/modules/eend_ola/utils/__init__.py b/funasr/modules/eend_ola/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/__init__.py
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index a2b91a7..77aa422 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -427,6 +427,7 @@
conv_size: Union[int, Tuple],
subsampling_factor: int = 4,
vgg_like: bool = True,
+ conv_kernel_size: int = 3,
output_size: Optional[int] = None,
) -> None:
"""Construct a ConvInput object."""
@@ -436,14 +437,14 @@
conv_size1, conv_size2 = conv_size
self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((1, 2)),
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((1, 2)),
)
@@ -462,14 +463,14 @@
kernel_1 = int(subsampling_factor / 2)
self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((kernel_1, 2)),
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((2, 2)),
)
@@ -487,14 +488,14 @@
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+ torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
torch.nn.ReLU(),
)
output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
self.subsampling_factor = subsampling_factor
- self.kernel_2 = 3
+ self.kernel_2 = conv_kernel_size
self.stride_2 = 1
self.create_new_mask = self.create_new_conv2d_mask
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
index 4067b04..5aa40ec 100644
--- a/funasr/utils/asr_utils.py
+++ b/funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
+import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
if support_audio_type == "pcm":
fs = None
else:
- audio, fs = torchaudio.load(fname)
+ try:
+ audio, fs = torchaudio.load(fname)
+ except:
+ audio, fs = soundfile.read(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index 7602740..0e773bb 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
+import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ except:
+ waveform, sampling_rate = soundfile.read(wav_path)
+ waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
return n_frames, feature_dim
diff --git a/funasr/utils/runtime_sdk_download_tool.py b/funasr/utils/runtime_sdk_download_tool.py
new file mode 100644
index 0000000..dbddd55
--- /dev/null
+++ b/funasr/utils/runtime_sdk_download_tool.py
@@ -0,0 +1,38 @@
+from pathlib import Path
+import os
+import argparse
+from funasr.utils.types import str2bool
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model-name', type=str, required=True)
+parser.add_argument('--export-dir', type=str, required=True)
+parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+args = parser.parse_args()
+
+model_dir = args.model_name
+if not Path(args.model_name).exists():
+ from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir)
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
+ (model_dir)
+
+model_file = os.path.join(model_dir, 'model.onnx')
+if args.quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+if not os.path.exists(model_file):
+ print(".onnx is not exist, begin to export onnx")
+ from funasr.export.export_model import ModelExport
+ export_model = ModelExport(
+ cache_dir=args.export_dir,
+ onnx=True,
+ device="cpu",
+ quant=args.quantize,
+ )
+ export_model.export(model_dir)
\ No newline at end of file
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index ebb80d2..bd067c2 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,13 @@
waveform = torch.from_numpy(waveform.reshape(1, -1))
else:
# load pcm from wav, and resample
- waveform, audio_sr = torchaudio.load(wav_file)
+ try:
+ waveform, audio_sr = torchaudio.load(wav_file)
+ except:
+ waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
waveform = waveform * (1 << 15)
waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +188,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ except:
+ waveform, sampling_rate = soundfile.read(wav_path)
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
diff --git a/funasr/version.txt b/funasr/version.txt
index 844f6a9..ef5e445 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.6.3
+0.6.5
diff --git a/setup.py b/setup.py
index 5b49d06..f13a2c2 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
"librosa",
"jamo==0.4.1", # For kss
"PyYAML>=5.1.2",
- "soundfile>=0.10.2",
+ "soundfile>=0.11.0",
"h5py>=2.10.0",
"kaldiio>=2.17.0",
"torch_complex",
--
Gitblit v1.9.1