雾聪
2023-06-27 69dcdbcfc0c21627c40fb8f7c435136c11691574
Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
16个文件已修改
3个文件已添加
534 ■■■■■ 已修改文件
.gitignore 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/benchmark/benchmark_pipeline_cer.md 254 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/installation/installation.md 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/iterable_dataset.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/dataset.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/fileio/sound_scp.py 70 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/conformer_encoder.py 55 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/eend_ola/utils/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/subsampling.py 21 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/asr_utils.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/prepare_data.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/runtime_sdk_download_tool.py 38 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/wav_utils.py 15 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/version.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
.gitignore
@@ -18,4 +18,5 @@
build
funasr.egg-info
docs/_build
modelscope
modelscope
samples
README.md
@@ -34,9 +34,9 @@
Install from pip
```shell
pip install -U funasr
pip3 install -U funasr
# For the users in China, you could install with the command:
# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
Or install from source code
@@ -96,14 +96,17 @@
### runtime
An example with websocket:
For the server:
```shell
cd funasr/runtime/python/websocket
python wss_srv_asr.py --port 10095
```
For the client:
```shell
python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "5,10,5"
#python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
#python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
```
More examples can be found in the [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2).
## Contact
docs/benchmark/benchmark_pipeline_cer.md
@@ -45,156 +45,156 @@
### Chinese Dataset
<table>
<table border="1">
    <tr align="center">
        <td>Model</td>
        <td>Offline/Online</td>
        <td colspan="2">Aishell1</td>
        <td colspan="4">Aishell2</td>
        <td colspan="3">WenetSpeech</td>
        <td style="border: 1px solid">Model</td>
        <td style="border: 1px solid">Offline/Online</td>
        <td colspan="2" style="border: 1px solid">Aishell1</td>
        <td colspan="4" style="border: 1px solid">Aishell2</td>
        <td colspan="3" style="border: 1px solid">WenetSpeech</td>
    </tr>
    <tr align="center">
        <td></td>
        <td></td>
        <td>dev</td>
        <td>test</td>
        <td>dev_ios</td>
        <td>test_ios</td>
        <td>test_android</td>
        <td>test_mic</td>
        <td>dev</td>
        <td>test_meeting</td>
        <td>test_net</td>
        <td style="border: 1px solid"></td>
        <td style="border: 1px solid"></td>
        <td style="border: 1px solid">dev</td>
        <td style="border: 1px solid">test</td>
        <td style="border: 1px solid">dev_ios</td>
        <td style="border: 1px solid">test_ios</td>
        <td style="border: 1px solid">test_android</td>
        <td style="border: 1px solid">test_mic</td>
        <td style="border: 1px solid">dev</td>
        <td style="border: 1px solid">test_meeting</td>
        <td style="border: 1px solid">test_net</td>
    </tr>
    <tr align="center">
        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
        <td>Offline</td>
        <td>1.76</td>
        <td>1.94</td>
        <td>2.79</td>
        <td>2.84</td>
        <td>3.08</td>
        <td>3.03</td>
        <td>3.43</td>
        <td>7.01</td>
        <td>6.66</td>
        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">1.76</td>
        <td style="border: 1px solid">1.94</td>
        <td style="border: 1px solid">2.79</td>
        <td style="border: 1px solid">2.84</td>
        <td style="border: 1px solid">3.08</td>
        <td style="border: 1px solid">3.03</td>
        <td style="border: 1px solid">3.43</td>
        <td style="border: 1px solid">7.01</td>
        <td style="border: 1px solid">6.66</td>
    </tr>
    <tr align="center">
        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td>
        <td>Offline</td>
        <td>1.80</td>
        <td>2.10</td>
        <td>2.78</td>
        <td>2.87</td>
        <td>3.12</td>
        <td>3.11</td>
        <td>3.44</td>
        <td>13.28</td>
        <td>7.08</td>
        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">1.80</td>
        <td style="border: 1px solid">2.10</td>
        <td style="border: 1px solid">2.78</td>
        <td style="border: 1px solid">2.87</td>
        <td style="border: 1px solid">3.12</td>
        <td style="border: 1px solid">3.11</td>
        <td style="border: 1px solid">3.44</td>
        <td style="border: 1px solid">13.28</td>
        <td style="border: 1px solid">7.08</td>
    </tr>
    <tr align="center">
        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
        <td>Offline</td>
        <td>1.76</td>
        <td>2.02</td>
        <td>2.73</td>
        <td>2.85</td>
        <td>2.98</td>
        <td>2.95</td>
        <td>3.42</td>
        <td>7.16</td>
        <td>6.72</td>
        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">1.76</td>
        <td style="border: 1px solid">2.02</td>
        <td style="border: 1px solid">2.73</td>
        <td style="border: 1px solid">2.85</td>
        <td style="border: 1px solid">2.98</td>
        <td style="border: 1px solid">2.95</td>
        <td style="border: 1px solid">3.42</td>
        <td style="border: 1px solid">7.16</td>
        <td style="border: 1px solid">6.72</td>
    </tr>
    <tr align="center">
        <td> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td>
        <td>Offline</td>
        <td>3.24</td>
        <td>3.69</td>
        <td>4.58</td>
        <td>4.63</td>
        <td>4.83</td>
        <td>4.71</td>
        <td>4.19</td>
        <td>8.32</td>
        <td>9.19</td>
        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">3.24</td>
        <td style="border: 1px solid">3.69</td>
        <td style="border: 1px solid">4.58</td>
        <td style="border: 1px solid">4.63</td>
        <td style="border: 1px solid">4.83</td>
        <td style="border: 1px solid">4.71</td>
        <td style="border: 1px solid">4.19</td>
        <td style="border: 1px solid">8.32</td>
        <td style="border: 1px solid">9.19</td>
    </tr>
   <tr align="center">
        <td> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td>
        <td>Online</td>
        <td>3.34</td>
        <td>3.99</td>
        <td>4.62</td>
        <td>4.52</td>
        <td>4.77</td>
        <td>4.73</td>
        <td>4.51</td>
        <td>10.63</td>
        <td>9.70</td>
        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td>
        <td style="border: 1px solid">Online</td>
        <td style="border: 1px solid">3.34</td>
        <td style="border: 1px solid">3.99</td>
        <td style="border: 1px solid">4.62</td>
        <td style="border: 1px solid">4.52</td>
        <td style="border: 1px solid">4.77</td>
        <td style="border: 1px solid">4.73</td>
        <td style="border: 1px solid">4.51</td>
        <td style="border: 1px solid">10.63</td>
        <td style="border: 1px solid">9.70</td>
    </tr>
   <tr align="center">
        <td> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td>
        <td>Offline</td>
        <td>2.93</td>
        <td>3.48</td>
        <td>3.95</td>
        <td>3.87</td>
        <td>4.11</td>
        <td>4.11</td>
        <td>4.16</td>
        <td>10.09</td>
        <td>8.69</td>
        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">2.93</td>
        <td style="border: 1px solid">3.48</td>
        <td style="border: 1px solid">3.95</td>
        <td style="border: 1px solid">3.87</td>
        <td style="border: 1px solid">4.11</td>
        <td style="border: 1px solid">4.11</td>
        <td style="border: 1px solid">4.16</td>
        <td style="border: 1px solid">10.09</td>
        <td style="border: 1px solid">8.69</td>
    </tr>
    <tr align="center">
        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
        <td>Offline</td>
        <td>4.88</td>
        <td>5.43</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">4.88</td>
        <td style="border: 1px solid">5.43</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
    </tr>
   <tr align="center">
        <td> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
        <td>Offline</td>
        <td>6.14</td>
        <td>7.01</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">6.14</td>
        <td style="border: 1px solid">7.01</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
    </tr>
   <tr align="center">
        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td>
        <td>Offline</td>
        <td>-</td>
        <td>-</td>
        <td>5.82</td>
        <td>6.30</td>
        <td>6.60</td>
        <td>5.83</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">5.82</td>
        <td style="border: 1px solid">6.30</td>
        <td style="border: 1px solid">6.60</td>
        <td style="border: 1px solid">5.83</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
    </tr>
   <tr align="center">
        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td>
        <td>Offline</td>
        <td>-</td>
        <td>-</td>
        <td>4.95</td>
        <td>5.45</td>
        <td>5.59</td>
        <td>5.83</td>
        <td>-</td>
        <td>-</td>
        <td>-</td>
        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td>
        <td style="border: 1px solid">Offline</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">4.95</td>
        <td style="border: 1px solid">5.45</td>
        <td style="border: 1px solid">5.59</td>
        <td style="border: 1px solid">5.83</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
        <td style="border: 1px solid">-</td>
    </tr>
</table>
docs/installation/installation.md
@@ -32,7 +32,7 @@
### Install Pytorch (version >= 1.11.0):
```sh
pip install torch torchaudio
pip3 install torch torchaudio
```
If CUDA is available in your environment, install the PyTorch build that matches your CUDA version. The compatibility list can be found in the [docs](https://pytorch.org/get-started/previous-versions/).
### Install funasr
@@ -40,27 +40,27 @@
#### Install from pip
```shell
pip install -U funasr
pip3 install -U funasr
# For the users in China, you could install with the command:
# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
#### Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install -e ./
pip3 install -e ./
# For the users in China, you could install with the command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### Install modelscope (Optional)
To use the pretrained models from ModelScope, install modelscope:
```shell
pip install -U modelscope
pip3 install -U modelscope
# For the users in China, you could install with the command:
# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
# pip3 install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### FAQ
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -6,7 +6,7 @@
      unified_model_training: true
      default_chunk_size: 16
      jitter_range: 4
      left_chunk_size: 0
      left_chunk_size: 1
      embed_vgg_like: false
      subsampling_factor: 4
      linear_units: 2048
@@ -51,7 +51,7 @@
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 200
max_epoch: 120
val_scheduler_criterion:
    - valid
    - loss
funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
import yaml
from typeguard import check_argument_types
@@ -863,7 +864,13 @@
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            try:
                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            except:
                raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
                if raw_inputs.ndim == 2:
                    raw_inputs = raw_inputs[:, 0]
                raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,17 @@
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    """Load an audio file and return its samples as a numpy array.

    ``torchaudio.load`` is tried first. If it raises (for example because no
    torchaudio backend can decode the file), the file is read with
    ``soundfile`` instead.

    Args:
        input: Audio source passed unchanged to ``torchaudio.load`` /
            ``soundfile.read`` (presumably a file path).

    Returns:
        A 2-D array with channels on the first axis. The torchaudio path keeps
        every channel. The soundfile fallback keeps only the first channel, as
        float32, with shape ``(1, num_samples)``.
    """
    try:
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # ``except Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt / SystemExit still propagate.
        waveform, _ = soundfile.read(input, dtype='float32')
        if waveform.ndim == 2:
            # soundfile yields (frames, channels) for multi-channel audio;
            # keep only the first channel.
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,14 @@
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        waveform, sampling_rate = torchaudio.load(path)
                        try:
                            waveform, sampling_rate = torchaudio.load(path)
                        except:
                            waveform, sampling_rate = soundfile.read(path, dtype='float32')
                            if waveform.ndim == 2:
                                waveform = waveform[:, 0]
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
funasr/datasets/small_datasets/__init__.py
funasr/fileio/sound_scp.py
@@ -1,6 +1,6 @@
import collections.abc
from pathlib import Path
from typing import Union
from typing import List, Tuple, Union
import random
import numpy as np
@@ -13,6 +13,74 @@
from funasr.fileio.read_text import read_2column_text
def soundfile_read(
    wavs: Union[str, List[str]],
    dtype=None,
    always_2d: bool = False,
    concat_axis: int = 1,
    start: int = 0,
    end: Union[int, None] = None,
    return_subtype: bool = False,
) -> Union[Tuple[np.ndarray, int], Tuple[np.ndarray, int, List[str]]]:
    """Read one or more audio files with soundfile and join them into one array.

    Args:
        wavs: A single path, or a list of paths whose arrays are concatenated.
        dtype: dtype passed to ``SoundFile.read``. ``"float16"`` is read as
            float32 and cast afterwards.
        always_2d: Passed to ``SoundFile.read``.
        concat_axis: Axis along which the arrays of multiple files are
            concatenated. With ``concat_axis == 1`` and several mono files,
            each file becomes one channel.
        start: Frame index to seek to before reading.
        end: Frame index to stop at (exclusive); ``None`` reads to the end.
        return_subtype: If True, also return the list of per-file subtypes.

    Returns:
        ``(array, rate)``, or ``(array, rate, subtypes)`` when
        ``return_subtype`` is True.

    Raises:
        ValueError: If ``wavs`` is an empty list.
        RuntimeError: If the files have mismatched sampling rates, or mismatched
            sizes along the non-concatenation axis.
    """
    if isinstance(wavs, str):
        wavs = [wavs]
    if len(wavs) == 0:
        raise ValueError("soundfile_read() needs at least one audio path")

    arrays = []
    subtypes = []
    prev_rate = None
    prev_wav = None
    for wav in wavs:
        with soundfile.SoundFile(wav) as f:
            f.seek(start)
            # -1 asks SoundFile.read for everything up to the end of the file.
            frames = end - start if end is not None else -1
            if dtype == "float16":
                # soundfile cannot decode straight to float16: read float32, then cast.
                array = f.read(
                    frames,
                    dtype="float32",
                    always_2d=always_2d,
                ).astype(dtype)
            else:
                array = f.read(frames, dtype=dtype, always_2d=always_2d)
            rate = f.samplerate
            subtypes.append(f.subtype)

        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
            # array: (Time,) -> (Time, Channel) so files stack as channels.
            array = array[:, None]

        if prev_wav is not None:
            if prev_rate != rate:
                raise RuntimeError(
                    f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
                    f"{prev_rate} != {rate}"
                )
            # Only 2-D arrays have a second axis to compare. 1-D arrays joined
            # along axis 0 have nothing to check, and indexing shape[1] on
            # them would raise IndexError.
            if arrays[0].ndim > 1 and array.ndim > 1:
                dim1 = arrays[0].shape[1 - concat_axis]
                dim2 = array.shape[1 - concat_axis]
                if dim1 != dim2:
                    raise RuntimeError(
                        "Shapes must match with "
                        f"{1 - concat_axis} axis, but got {dim1} and {dim2}"
                    )
        prev_rate = rate
        prev_wav = wav
        arrays.append(array)

    if len(arrays) == 1:
        array = arrays[0]
    else:
        array = np.concatenate(arrays, axis=concat_axis)

    if return_subtype:
        return array, rate, subtypes
    return array, rate
class SoundScpReader(collections.abc.Mapping):
    """Reader class for 'wav.scp'.
funasr/models/encoder/conformer_encoder.py
@@ -1081,7 +1081,10 @@
        mask = make_source_mask(x_len).to(x.device)
        if self.unified_model_training:
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            if self.training:
                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@
        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()
            if self.training:
                chunk_size = torch.randint(1, max_len, (1,)).item()
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@
        return x, olens, None
    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
        Returns:
           x: Encoder outputs. (B, T_out, D_enc)
           x_len: Encoder outputs lenghts. (B,)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )
        if self.time_reduction_factor > 1:
            x_utt = x_utt[:,::self.time_reduction_factor,:]
        return x_utt
    def simu_chunk_forward(
        self,
        x: torch.Tensor,
funasr/modules/eend_ola/utils/__init__.py
funasr/modules/subsampling.py
@@ -427,6 +427,7 @@
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        conv_kernel_size: int = 3,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
@@ -436,14 +437,14 @@
                conv_size1, conv_size2 = conv_size
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )
@@ -462,14 +463,14 @@
                kernel_1 = int(subsampling_factor / 2)
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )
@@ -487,14 +488,14 @@
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
                    torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                    torch.nn.ReLU(),
                )
                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.kernel_2 = conv_kernel_size
                self.stride_2 = 1
                self.create_new_mask = self.create_new_conv2d_mask
funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
                if support_audio_type == "pcm":
                    fs = None
                else:
                    audio, fs = torchaudio.load(fname)
                    try:
                        audio, fs = torchaudio.load(fname)
                    except:
                        audio, fs = soundfile.read(fname)
                break
        if audio_type.rfind(".scp") >= 0:
            with open(fname, encoding="utf-8") as f:
funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
    """Estimate the number of frontend feature frames and the feature dimension.

    Args:
        wav_path: Path of the audio file.
        frontend_conf: Frontend configuration mapping; the keys ``frame_shift``,
            ``lfr_n``, ``n_mels`` and ``lfr_m`` are read (``frame_shift`` is
            presumably in milliseconds, given the 1000.0 factor below).

    Returns:
        ``(n_frames, feature_dim)``, where ``n_frames`` is a float estimate.
    """
    # Only one load attempt may sit outside the try: an unguarded
    # torchaudio.load before it would raise first and make the soundfile
    # fallback unreachable.
    try:
        waveform, sampling_rate = torchaudio.load(wav_path)
    except Exception:
        waveform, sampling_rate = soundfile.read(wav_path)
        # soundfile puts samples on axis 0; add a leading axis so shape[1] is
        # the sample count on both paths.
        waveform = np.expand_dims(waveform, axis=0)
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
    return n_frames, feature_dim
funasr/utils/runtime_sdk_download_tool.py
New file
@@ -0,0 +1,38 @@
from pathlib import Path
import os
import argparse
from funasr.utils.types import str2bool
# Command-line tool: resolve a FunASR model (local path or ModelScope model id)
# and export it to ONNX via ModelExport if the expected .onnx file is missing
# from the model directory.
parser = argparse.ArgumentParser()
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--export-dir', type=str, required=True)
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
args = parser.parse_args()

# Use a local directory as-is; otherwise download the model from ModelScope
# into --export-dir.
model_dir = args.model_name
if not Path(args.model_name).exists():
    from modelscope.hub.snapshot_download import snapshot_download
    try:
        model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir)
    except Exception as err:
        # Raising a plain string is itself a TypeError in Python 3, which would
        # hide this message; raise a real exception and keep the cause.
        raise ValueError(
            "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                model_dir
            )
        ) from err

model_file = os.path.join(model_dir, 'model_quant.onnx' if args.quantize else 'model.onnx')
if not os.path.exists(model_file):
    print(".onnx is not exist, begin to export onnx")
    from funasr.export.export_model import ModelExport
    # NOTE(review): export is always run with device="cpu"; the --type,
    # --device, --fallback-num, --audio_in and --calib_num flags are parsed
    # but not used by this script.
    export_model = ModelExport(
        cache_dir=args.export_dir,
        onnx=True,
        device="cpu",
        quant=args.quantize,
    )
    export_model.export(model_dir)
funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,13 @@
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load pcm from wav, and resample
        waveform, audio_sr = torchaudio.load(wav_file)
        try:
            waveform, audio_sr = torchaudio.load(wav_file)
        except:
            waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
            if waveform.ndim == 2:
                waveform = waveform[:, 0]
            waveform = torch.tensor(np.expand_dims(waveform, axis=0))
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +188,11 @@
def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    try:
        waveform, sampling_rate = torchaudio.load(wav_path)
    except:
        waveform, sampling_rate = soundfile.read(wav_path)
        waveform = torch.tensor(np.expand_dims(waveform, axis=0))
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
funasr/version.txt
@@ -1 +1 @@
0.6.3
0.6.5
setup.py
@@ -20,7 +20,7 @@
        "librosa",
        "jamo==0.4.1",  # For kss
        "PyYAML>=5.1.2",
        "soundfile>=0.10.2",
        "soundfile>=0.11.0",
        "h5py>=2.10.0",
        "kaldiio>=2.17.0",
        "torch_complex",