From 2ff405b2f4ab899eff9bece232969fbb0c8f0555 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 20 Jun 2023 00:26:37 +0800
Subject: [PATCH] Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer
---
funasr/bin/punc_infer.py | 60
funasr/build_utils/build_args.py | 19
funasr/models/e2e_asr_mfcca.py | 2
funasr/bin/lm_inference_launch.py | 127 +-
funasr/build_utils/build_diar_model.py | 22
funasr/build_utils/build_streaming_iterator.py | 67 +
funasr/build_utils/build_model_from_file.py | 193 +++
funasr/bin/vad_infer.py | 40
funasr/bin/diar_infer.py | 49
funasr/bin/train.py | 2
funasr/build_utils/build_vad_model.py | 4
.github/workflows/UnitTest.yml | 7
funasr/models/e2e_uni_asr.py | 4
funasr/bin/tp_inference_launch.py | 116 -
tests/test_vad_inference_pipeline.py | 2
funasr/build_utils/build_lm_model.py | 9
funasr/build_utils/build_asr_model.py | 30
funasr/bin/sv_inference_launch.py | 106 -
funasr/bin/vad_inference_launch.py | 59
funasr/bin/diar_inference_launch.py | 67
funasr/bin/punc_inference_launch.py | 105 -
funasr/build_utils/build_sv_model.py | 258 +++++
funasr/bin/asr_infer.py | 598 +++++------
funasr/build_utils/build_model.py | 5
funasr/bin/tp_infer.py | 65
funasr/models/e2e_asr_contextual_paraformer.py | 4
funasr/models/e2e_vad.py | 3
tests/test_sv_inference_pipeline.py | 2
funasr/bin/sv_infer.py | 28
funasr/bin/asr_inference_launch.py | 863 ++++++++---------
 30 files changed, 1602 insertions(+), 1314 deletions(-)
diff --git a/.github/workflows/UnitTest.yml b/.github/workflows/UnitTest.yml
index 3b0a1ee..8ced9e4 100644
--- a/.github/workflows/UnitTest.yml
+++ b/.github/workflows/UnitTest.yml
@@ -8,6 +8,7 @@
branches:
- dev_wjm
- dev_jy
+ - dev_wjm_infer
jobs:
build:
@@ -18,6 +19,12 @@
python-version: ["3.7"]
steps:
+      - name: Remove unnecessary files
+        run: |
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf /opt/ghc
+          sudo rm -rf "/usr/local/share/boost"
+          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index bed50b4..c722ebc 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1,66 +1,46 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
-import logging
-import sys
-import time
+
+import codecs
import copy
+import logging
import os
import re
-import codecs
import tempfile
-import requests
from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
import numpy as np
+import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
-from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.modules.beam_search.beam_search import BeamSearch
-# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
-from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.vad_infer import Speech2VadSegment
-from funasr.bin.punc_infer import Text2Punc
-from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-from funasr.tasks.asr import frontend_choices
+
class Speech2Text:
"""Speech2Text class
@@ -73,36 +53,36 @@
[(text, token, token_int, hypothesis object), ...]
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
):
assert check_argument_types()
-
+
# 1. Build ASR model
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
@@ -110,16 +90,15 @@
if asr_train_args.frontend == 'wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
- from funasr.tasks.asr import frontend_choices
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
decoder = asr_model.decoder
-
+
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
@@ -127,24 +106,24 @@
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
-
+
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
+ lm, lm_train_args = build_model_from_file(
lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
-
+
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
-
+
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
from funasr.modules.beam_search.beam_search import BeamSearch
-
+
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
@@ -162,13 +141,13 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
+
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -180,7 +159,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
@@ -193,10 +172,10 @@
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
-
+
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
@@ -214,11 +193,11 @@
"""
assert check_argument_types()
-
+
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
+
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
@@ -229,48 +208,49 @@
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
-
+
# a. To device
batch = to_device(batch, device=self.device)
-
+
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
-
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
-
+
nbest_hyps = nbest_hyps[: self.nbest]
-
+
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
-
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
-
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
-
+
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
-
+
assert check_return_type(results)
return results
+
class Speech2TextParaformer:
"""Speech2Text class
@@ -312,9 +292,8 @@
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -336,8 +315,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
@@ -466,18 +445,21 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
+ if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
+ NeatContextualParaformer):
if self.hotword_list:
logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
+ pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
+ pre_token_length, hw_list=self.hotword_list)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
_, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
- pre_token_length) # test no bias cif2
+ pre_token_length) # test no bias cif2
results = []
b, n, d = decoder_out.size()
@@ -527,12 +509,11 @@
text = None
timestamp = []
if isinstance(self.asr_model, BiCifParaformer):
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
- us_peaks[i][:enc_len[i]*3],
- copy.copy(token),
- vad_offset=begin_time)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3],
+ us_peaks[i][:enc_len[i] * 3],
+ copy.copy(token),
+ vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
-
# assert check_return_type(results)
return results
@@ -591,6 +572,7 @@
hotword_list = None
return hotword_list
+
class Speech2TextParaformerOnline:
"""Speech2Text class
@@ -630,9 +612,8 @@
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -654,8 +635,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
@@ -789,7 +770,7 @@
enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
- pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1]
+ pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
@@ -839,11 +820,12 @@
postprocessed_result += item + " "
else:
postprocessed_result += item
-
+
results.append(postprocessed_result)
# assert check_return_type(results)
return results
+
class Speech2TextUniASR:
"""Speech2Text class
@@ -886,9 +868,8 @@
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -914,8 +895,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, device, "lm"
)
scorers["lm"] = lm.lm
@@ -1077,7 +1058,7 @@
assert check_return_type(results)
return results
-
+
class Speech2TextMFCCA:
"""Speech2Text class
@@ -1090,45 +1071,44 @@
[(text, token, token_int, hypothesis object), ...]
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- **kwargs,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ **kwargs,
):
assert check_argument_types()
-
+
# 1. Build ASR model
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
-
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
decoder = asr_model.decoder
-
+
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
@@ -1136,11 +1116,11 @@
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
-
+
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
lm.to(device)
scorers["lm"] = lm.lm
@@ -1148,11 +1128,11 @@
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
-
+
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
-
+
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
@@ -1176,7 +1156,7 @@
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -1188,7 +1168,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
@@ -1200,10 +1180,10 @@
self.device = device
self.dtype = dtype
self.nbest = nbest
-
+
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
@@ -1231,45 +1211,45 @@
# lenghts: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
-
+
# a. To device
batch = to_device(batch, device=self.device)
-
+
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
-
+
assert len(enc) == 1, len(enc)
-
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
-
+
nbest_hyps = nbest_hyps[: self.nbest]
-
+
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
-
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
-
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
-
+
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
-
+
assert check_return_type(results)
return results
@@ -1298,45 +1278,44 @@
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- beam_search_config: Dict[str, Any] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- beam_size: int = 5,
- dtype: str = "float32",
- lm_weight: float = 1.0,
- quantize_asr_model: bool = False,
- quantize_modules: List[str] = None,
- quantize_dtype: str = "qint8",
- nbest: int = 1,
- streaming: bool = False,
- simu_streaming: bool = False,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- display_partial_hypotheses: bool = False,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ beam_search_config: Dict[str, Any] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ beam_size: int = 5,
+ dtype: str = "float32",
+ lm_weight: float = 1.0,
+ quantize_asr_model: bool = False,
+ quantize_modules: List[str] = None,
+ quantize_dtype: str = "qint8",
+ nbest: int = 1,
+ streaming: bool = False,
+ simu_streaming: bool = False,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ display_partial_hypotheses: bool = False,
) -> None:
"""Construct a Speech2Text object."""
super().__init__()
-
+
assert check_argument_types()
- from funasr.tasks.asr import ASRTransducerTask
- asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
-
+
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
+
if quantize_asr_model:
if quantize_modules is not None:
if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
@@ -1344,36 +1323,36 @@
"Only 'Linear' and 'LSTM' modules are currently supported"
" by PyTorch and in --quantize_modules"
)
-
+
q_config = set([getattr(torch.nn, q) for q in quantize_modules])
else:
q_config = {torch.nn.Linear}
-
+
if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
raise ValueError(
"float16 dtype for dynamic quantization is not supported with torch"
" version < 1.5.0. Switching to qint8 dtype instead."
)
q_dtype = getattr(torch, quantize_dtype)
-
+
asr_model = torch.quantization.quantize_dynamic(
asr_model, q_config, dtype=q_dtype
).eval()
else:
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
lm_scorer = lm.lm
else:
lm_scorer = None
-
+
# 4. Build BeamSearch object
if beam_search_config is None:
beam_search_config = {}
-
+
beam_search = BeamSearchTransducer(
asr_model.decoder,
asr_model.joint_network,
@@ -1383,14 +1362,14 @@
nbest=nbest,
**beam_search_config,
)
-
+
token_list = asr_model.token_list
-
+
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -1402,60 +1381,60 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.device = device
self.dtype = dtype
self.nbest = nbest
-
+
self.converter = converter
self.tokenizer = tokenizer
-
+
self.beam_search = beam_search
self.streaming = streaming
self.simu_streaming = simu_streaming
self.chunk_size = max(chunk_size, 0)
self.left_context = left_context
self.right_context = max(right_context, 0)
-
+
if not streaming or chunk_size == 0:
self.streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
-
+
if not simu_streaming or chunk_size == 0:
self.simu_streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
-
+
self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
-
+
if self.streaming:
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
-
+
self.last_chunk_length = (
- self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
)
self.reset_inference_cache()
-
+
def reset_inference_cache(self) -> None:
"""Reset Speech2Text parameters."""
self.frontend_cache = None
-
+
self.asr_model.encoder.reset_streaming_cache(
self.left_context, device=self.device
)
self.beam_search.reset_inference_cache()
-
+
self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
+
@torch.no_grad()
def streaming_decode(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- is_final: bool = True,
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ is_final: bool = True,
) -> List[HypothesisTransducer]:
"""Speech2Text streaming call.
Args:
@@ -1473,13 +1452,13 @@
)
speech = torch.cat([speech, pad],
dim=0) # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
-
+
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
+
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
+
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out = self.asr_model.encoder.chunk_forward(
@@ -1491,14 +1470,14 @@
right_context=self.right_context,
)
nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
-
+
self.num_processed_frames += self.chunk_size
-
+
if is_final:
self.reset_inference_cache()
-
+
return nbest_hyps
-
+
@torch.no_grad()
def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
"""Speech2Text call.
@@ -1508,29 +1487,29 @@
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
-
+
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
+
if self.frontend is not None:
speech = torch.unsqueeze(speech, axis=0)
speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
+ else:
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
+
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
+
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
self.right_context)
nbest_hyps = self.beam_search(enc_out[0])
-
+
return nbest_hyps
-
+
@torch.no_grad()
def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
"""Speech2Text call.
@@ -1540,7 +1519,7 @@
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
-
+
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1548,19 +1527,19 @@
speech = torch.unsqueeze(speech, axis=0)
speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
+ else:
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
+
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
-
+
enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-
+
nbest_hyps = self.beam_search(enc_out[0])
-
+
return nbest_hyps
-
+
def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
"""Build partial or final results from the hypotheses.
Args:
@@ -1569,26 +1548,26 @@
results: Results containing different representation for the hypothesis.
"""
results = []
-
+
for hyp in nbest_hyps:
token_int = list(filter(lambda x: x != 0, hyp.yseq))
-
+
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
-
+
assert check_return_type(results)
-
+
return results
-
+
@staticmethod
def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
) -> Speech2Text:
"""Build Speech2Text instance from the pretrained model.
Args:
@@ -1599,7 +1578,7 @@
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
-
+
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
@@ -1608,7 +1587,7 @@
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
-
+
return Speech2TextTransducer(**kwargs)
@@ -1623,37 +1602,36 @@
[(text, token, token_int, hypothesis object), ...]
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
):
assert check_argument_types()
-
+
# 1. Build ASR model
- from funasr.tasks.asr import ASRTaskSAASR
scorers = {}
- asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
@@ -1665,13 +1643,13 @@
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
decoder = asr_model.decoder
-
+
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
@@ -1679,24 +1657,24 @@
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
-
+
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, None, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
-
+
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
-
+
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
-
+
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
@@ -1714,13 +1692,13 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
+
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -1732,7 +1710,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
@@ -1745,11 +1723,11 @@
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
-
+
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
- profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
+ profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
) -> List[
Tuple[
Optional[str],
@@ -1768,14 +1746,14 @@
"""
assert check_argument_types()
-
+
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
+
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
-
+
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
@@ -1786,10 +1764,10 @@
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
-
+
# a. To device
batch = to_device(batch, device=self.device)
-
+
# b. Forward Encoder
asr_enc, _, spk_enc = self.asr_model.encode(**batch)
if isinstance(asr_enc, tuple):
@@ -1798,30 +1776,30 @@
spk_enc = spk_enc[0]
assert len(asr_enc) == 1, len(asr_enc)
assert len(spk_enc) == 1, len(spk_enc)
-
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
-
+
nbest_hyps = nbest_hyps[: self.nbest]
-
+
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
-
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1: last_pos]
else:
token_int = hyp.yseq[1: last_pos].tolist()
-
+
spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
-
+
token_ori = self.converter.ids2tokens(token_int)
text_ori = self.tokenizer.tokens2text(token_ori)
-
+
text_ori_spklist = text_ori.split('$')
cur_index = 0
spk_choose = []
@@ -1833,32 +1811,32 @@
spk_weights_local = spk_weights_local.mean(dim=0)
spk_choose_local = spk_weights_local.argmax(-1)
spk_choose.append(spk_choose_local.item() + 1)
-
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
-
+
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
-
+
text_spklist = text.split('$')
assert len(spk_choose) == len(text_spklist)
-
+
spk_list = []
for i in range(len(text_spklist)):
text_split = text_spklist[i]
n = len(text_split)
spk_list.append(str(spk_choose[i]) * n)
-
+
text_id = '$'.join(spk_list)
-
+
assert len(text) == len(text_id)
-
+
results.append((text, text_id, token, token_int, hyp))
-
+
assert check_return_type(results)
return results
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 19042d2..656a965 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,109 +7,77 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
from pathlib import Path
+from typing import Dict
+from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-import yaml
+
import numpy as np
import torch
import torchaudio
+import yaml
from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearch
-# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+from funasr.bin.asr_infer import Speech2Text
+from funasr.bin.asr_infer import Speech2TextMFCCA
+from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
+from funasr.bin.asr_infer import Speech2TextSAASR
+from funasr.bin.asr_infer import Speech2TextTransducer
+from funasr.bin.asr_infer import Speech2TextUniASR
+from funasr.bin.punc_infer import Text2Punc
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
+from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import asr_utils, postprocess_utils
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-from funasr.bin.asr_infer import Speech2Text
-from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
-from funasr.bin.asr_infer import Speech2TextUniASR
-from funasr.bin.asr_infer import Speech2TextMFCCA
-from funasr.bin.vad_infer import Speech2VadSegment
-from funasr.bin.punc_infer import Text2Punc
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.asr_infer import Speech2TextTransducer
-from funasr.bin.asr_infer import Speech2TextSAASR
+
def inference_asr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
@@ -120,23 +88,23 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -160,7 +128,7 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -173,20 +141,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -197,14 +163,14 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
@@ -212,19 +178,19 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -233,67 +199,67 @@
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
-
+
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}\n".format(text))
return asr_result_list
-
+
return _forward
def inference_paraformer(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ output_dir: Optional[str] = None,
+ timestamp_infer_config: Union[Path, str] = None,
+ timestamp_model_file: Union[Path, str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
-
+
if kwargs.get("device", None) == "cpu":
ngpu = 0
if ngpu >= 1 and torch.cuda.is_available():
@@ -301,10 +267,10 @@
else:
device = "cpu"
batch_size = 1
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -326,9 +292,9 @@
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
-
+
speech2text = Speech2TextParaformer(**speech2text_kwargs)
-
+
if timestamp_model_file is not None:
speechtext2timestamp = Speech2Timestamp(
timestamp_cmvn_file=cmvn_file,
@@ -337,16 +303,16 @@
)
else:
speechtext2timestamp = None
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
):
-
+
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
@@ -354,30 +320,28 @@
hotword_list_or_file = kwargs['hotword']
if hotword_list_or_file is not None or 'hotword' in kwargs:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
-
+
forward_time_total = 0.0
length_total = 0.0
finish_count = 0
@@ -390,17 +354,17 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
+
logging.info("decoding, utt_id: {}".format(keys))
# N-best list of (text, token, token_int, hyp_object)
-
+
time_beg = time.time()
results = speech2text(**batch)
if len(results) < 1:
@@ -416,10 +380,10 @@
100 * forward_time / (
length * lfr_factor))
logging.info(rtf_cur)
-
+
for batch_id in range(_bs):
result = [results[batch_id][:-2]]
-
+
key = keys[batch_id]
for n, result in zip(range(1, nbest + 1), result):
text, token, token_int, hyp = result[0], result[1], result[2], result[3]
@@ -438,13 +402,13 @@
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["rtf"][key] = rtf_cur
-
+
if text is not None:
if use_timestamp and timestamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
@@ -465,7 +429,7 @@
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
-
+
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
forward_time_total,
@@ -475,74 +439,74 @@
if writer is not None:
ibest_writer["rtf"]["rtf_avf"] = rtf_avg
return asr_result_list
-
+
return _forward
def inference_paraformer_vad_punc(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- vad_infer_config: Optional[str] = None,
- vad_model_file: Optional[str] = None,
- vad_cmvn_file: Optional[str] = None,
- time_stamp_writer: bool = True,
- punc_infer_config: Optional[str] = None,
- punc_model_file: Optional[str] = None,
- outputs_dict: Optional[bool] = True,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ vad_infer_config: Optional[str] = None,
+ vad_model_file: Optional[str] = None,
+ vad_cmvn_file: Optional[str] = None,
+ time_stamp_writer: bool = True,
+ punc_infer_config: Optional[str] = None,
+ punc_model_file: Optional[str] = None,
+ outputs_dict: Optional[bool] = True,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
else:
hotword_list_or_file = None
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
@@ -553,7 +517,7 @@
)
# logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
+
# 3. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -579,12 +543,12 @@
text2punc = None
if punc_model_file is not None:
text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
+
if output_dir is not None:
writer = DatadirWriter(output_dir)
ibest_writer = writer[f"1best_recog"]
ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -592,43 +556,41 @@
param_dict: dict = None,
**kwargs,
):
-
+
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
-
+
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
-
+
batch_size_token = kwargs.get("batch_size_token", 6000)
print("batch_size_token: ", batch_size_token)
-
+
if speech2text.hotword_list is None:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=1,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
-
+
finish_count = 0
file_count = 1
lfr_factor = 6
@@ -639,7 +601,7 @@
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
@@ -648,15 +610,16 @@
beg_vad = time.time()
vad_results = speech2vadsegment(**batch)
end_vad = time.time()
- print("time cost vad: ", end_vad-beg_vad)
+ print("time cost vad: ", end_vad - beg_vad)
_, vadsegments = vad_results[0], vad_results[1][0]
-
+
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
+
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
+
batch_size_token_ms = batch_size_token*60
if speech2text.device == "cpu":
batch_size_token_ms = 0
@@ -666,7 +629,8 @@
beg_idx = 0
for j, _ in enumerate(range(0, n)):
batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
- if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
+ if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
+ 0]) < batch_size_token_ms:
continue
batch_size_token_ms_cum = 0
end_idx = j + 1
@@ -679,11 +643,11 @@
results = speech2text(**batch)
end_asr = time.time()
print("time cost asr: ", end_asr - beg_asr)
-
+
if len(results) < 1:
results = [["", [], [], [], [], [], []]]
results_sorted.extend(results)
-
+
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
@@ -699,12 +663,12 @@
t[1] += vadsegments[j][0]
result[4] += restored_data[j][4]
# result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
+
key = keys[0]
# result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = result[4] if len(result[4]) > 0 else None
-
+
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
@@ -718,23 +682,23 @@
postprocessed_result[2]
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
+
text_postprocessed_punc = text_postprocessed
punc_id_list = []
if len(word_lists) > 0 and text2punc is not None:
beg_punc = time.time()
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
end_punc = time.time()
- print("time cost punc: ", end_punc-beg_punc)
-
+ print("time cost punc: ", end_punc - beg_punc)
+
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
-
+
item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
+
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
@@ -747,11 +711,12 @@
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
+
logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
return asr_result_list
-
+
return _forward
+
def inference_paraformer_online(
maxlenratio: float,
@@ -852,7 +817,7 @@
data = yaml.load(f, Loader=yaml.Loader)
return data
- def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
+ def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
if len(cache) > 0:
return cache
config = _read_yaml(asr_train_config)
@@ -868,14 +833,15 @@
return cache
- def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
+ def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
if len(cache) > 0:
config = _read_yaml(asr_train_config)
enc_output_size = config["encoder_conf"]["output_size"]
feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+ "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
@@ -920,7 +886,7 @@
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
sample_offset = 0
speech_length = raw_inputs.shape[1]
- stride_size = chunk_size[1] * 960
+ stride_size = chunk_size[1] * 960
cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
@@ -949,40 +915,40 @@
def inference_uniasr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- ngram_file: Optional[str] = None,
- cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ ngram_file: Optional[str] = None,
+ cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ token_num_relax: int = 1,
+ decoding_ind: int = 0,
+ decoding_mode: str = "model1",
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
@@ -993,17 +959,17 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
if param_dict is not None and "decoding_model" in param_dict:
if param_dict["decoding_model"] == "fast":
decoding_ind = 0
@@ -1016,10 +982,10 @@
decoding_mode = "model2"
else:
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -1046,7 +1012,7 @@
decoding_mode=decoding_mode,
)
speech2text = Speech2TextUniASR(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1059,19 +1025,17 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -1082,14 +1046,14 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
@@ -1097,7 +1061,7 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
logging.info(f"Utterance: {key}")
@@ -1105,12 +1069,12 @@
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -1120,40 +1084,40 @@
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
-
+
return _forward
def inference_mfcca(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
@@ -1164,20 +1128,20 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -1201,7 +1165,7 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2TextMFCCA(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1214,20 +1178,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
fs=fs,
mc=True,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -1238,14 +1200,14 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
@@ -1253,19 +1215,19 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -1275,42 +1237,43 @@
if writer is not None:
ibest_writer["text"][key] = text
return asr_result_list
-
+
return _forward
+
def inference_transducer(
- output_dir: str,
- batch_size: int,
- dtype: str,
- beam_size: int,
- ngpu: int,
- seed: int,
- lm_weight: float,
- nbest: int,
- num_workers: int,
- log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str],
- beam_search_config: Optional[dict],
- lm_train_config: Optional[str],
- lm_file: Optional[str],
- model_tag: Optional[str],
- token_type: Optional[str],
- bpemodel: Optional[str],
- key_file: Optional[str],
- allow_variable_data_keys: bool,
- quantize_asr_model: Optional[bool],
- quantize_modules: Optional[List[str]],
- quantize_dtype: Optional[str],
- streaming: Optional[bool],
- simu_streaming: Optional[bool],
- chunk_size: Optional[int],
- left_context: Optional[int],
- right_context: Optional[int],
- display_partial_hypotheses: bool,
- **kwargs,
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
) -> None:
"""Transducer model inference.
Args:
@@ -1391,7 +1354,7 @@
model_tag=model_tag,
**speech2text_kwargs,
)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1400,106 +1363,99 @@
**kwargs,
):
# 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(
- speech2text.asr_train_args, False
- ),
- collate_fn=ASRTask.build_collate_fn(
- speech2text.asr_train_args, False
- ),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
# 4 .Start for-loop
with DatadirWriter(output_dir) as writer:
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
-
+
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
assert len(batch.keys()) == 1
-
+
try:
if speech2text.streaming:
speech = batch["speech"]
-
+
_steps = len(speech) // speech2text._ctx
_end = 0
for i in range(_steps):
_end = (i + 1) * speech2text._ctx
-
+
speech2text.streaming_decode(
- speech[i * speech2text._ctx : _end], is_final=False
+ speech[i * speech2text._ctx: _end], is_final=False
)
-
+
final_hyps = speech2text.streaming_decode(
- speech[_end : len(speech)], is_final=True
+ speech[_end: len(speech)], is_final=True
)
elif speech2text.simu_streaming:
final_hyps = speech2text.simu_streaming_decode(**batch)
else:
final_hyps = speech2text(**batch)
-
+
results = speech2text.hypotheses_to_results(final_hyps)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
results = [[" ", ["<space>"], [2], hyp]] * nbest
-
+
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
ibest_writer = writer[f"{n}best_recog"]
-
+
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
ibest_writer["text"][key] = text
-
return _forward
def inference_sa_asr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
if batch_size > 1:
@@ -1508,23 +1464,23 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -1548,7 +1504,7 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2TextSAASR(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1561,20 +1517,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -1585,7 +1539,7 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
@@ -1599,20 +1553,20 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["text_id"][key] = text_id
-
+
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -1621,12 +1575,12 @@
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
-
+
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}".format(text))
logging.info("text_id predictions: {}\n".format(text_id))
return asr_result_list
-
+
return _forward
@@ -1664,7 +1618,7 @@
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-
+
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
@@ -1674,7 +1628,7 @@
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
-
+
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
@@ -1707,7 +1661,7 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -1729,7 +1683,7 @@
default=False,
help="MultiChannel input",
)
-
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@@ -1792,7 +1746,7 @@
default={},
help="The keyword arguments for transducer beam search.",
)
-
+
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
@@ -1839,7 +1793,7 @@
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
)
-
+
group = parser.add_argument_group("Dynamic quantization related")
group.add_argument(
"--quantize_asr_model",
@@ -1864,7 +1818,7 @@
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
)
-
+
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
@@ -1922,7 +1876,6 @@
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
-
if __name__ == "__main__":
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
index 4460e3d..7c41b60 100755
--- a/funasr/bin/diar_infer.py
+++ b/funasr/bin/diar_infer.py
@@ -1,41 +1,28 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
import os
-import sys
+from collections import OrderedDict
from pathlib import Path
from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from collections import OrderedDict
import numpy as np
-import soundfile
import torch
+from scipy.ndimage import median_filter
from torch.nn import functional as F
from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.diar import DiarTask
-from funasr.tasks.diar import EENDOLADiarTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from scipy.ndimage import median_filter
-from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.tasks.diar import DiarTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.torch_utils.device_funcs import to_device
+from funasr.utils.misc import statistic_model_parameters
+
class Speech2DiarizationEEND:
"""Speech2Diarlization class
@@ -61,10 +48,12 @@
assert check_argument_types()
# 1. Build Diarization model
- diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
+ diar_model, diar_train_args = build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
- device=device
+ device=device,
+ task_name="diar",
+ mode="eend-ola",
)
frontend = None
if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
@@ -177,10 +166,12 @@
assert check_argument_types()
# TODO: 1. Build Diarization model
- diar_model, diar_train_args = DiarTask.build_model_from_file(
+ diar_model, diar_train_args = build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
- device=device
+ device=device,
+ task_name="diar",
+ mode="sond",
)
logging.info("diar_model: {}".format(diar_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
@@ -248,7 +239,7 @@
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
logits_idx = F.upsample(
logits_idx.unsqueeze(1).float(),
- size=(ut, ),
+ size=(ut,),
mode="nearest",
).squeeze(1).long()
logits_idx = logits_idx[0].tolist()
@@ -268,7 +259,7 @@
if spk not in results:
results[spk] = []
if dur > self.dur_threshold:
- results[spk].append((st, st+dur))
+ results[spk].append((st, st + dur))
# sort segments in start time ascending
for spk in results:
@@ -344,7 +335,3 @@
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2DiarizationSOND(**kwargs)
-
-
-
-
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index e0d900e..820217b 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -1,5 +1,5 @@
+# !/usr/bin/env python3
# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -8,47 +8,28 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
-from collections import OrderedDict
import numpy as np
import soundfile
import torch
-from torch.nn import functional as F
-from typeguard import check_argument_types
-from typeguard import check_return_type
from scipy.signal import medfilt
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.diar import DiarTask
-from funasr.tasks.diar import EENDOLADiarTask
-from funasr.torch_utils.device_funcs import to_device
+from typeguard import check_argument_types
+
+from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+from funasr.datasets.iterable_dataset import load_bytes
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from scipy.ndimage import median_filter
-from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
-from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+
def inference_sond(
diar_train_config: str,
@@ -94,7 +75,8 @@
set_all_random_seed(seed)
# 2a. Build speech2xvec [Optional]
- if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+ if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
+ "extract_profile"]:
assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
sv_train_config = param_dict["sv_train_config"]
@@ -139,7 +121,7 @@
rst = []
mid = uttid.rsplit("-", 1)[0]
for key in results:
- results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
+ results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
if out_format == "vad":
for spk, segs in results.items():
rst.append("{} {}".format(spk, segs))
@@ -176,7 +158,7 @@
example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
for x in example]
speech = example[0]
- logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
+ logging.info("Extracting profiles for {} waveforms".format(len(example) - 1))
profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
profile = torch.cat(profile, dim=0)
yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
@@ -186,16 +168,15 @@
raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
else:
# 3. Build data-iterator
- loader = DiarTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="diar",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
+ use_collate_fn=False,
)
# 7. Start for-loop
@@ -234,6 +215,7 @@
return result_list
return _forward
+
def inference_eend(
diar_train_config: str,
@@ -306,16 +288,14 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
- loader = EENDOLADiarTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="diar",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
- collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
# 3. Start for-loop
@@ -362,8 +342,6 @@
return _forward
-
-
def inference_launch(mode, **kwargs):
if mode == "sond":
return inference_sond(mode=mode, **kwargs)
@@ -386,6 +364,7 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index 1d99fce..c8482b8 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,40 +7,25 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
from typing import Any
from typing import List
+from typing import Optional
+from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
@@ -48,42 +33,42 @@
def inference_lm(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float] = 10,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build Model
- model, train_args = LMTask.build_model_from_file(
- train_config, model_file, device)
+ model, train_args = build_model_from_file(
+ train_config, model_file, None, device, "lm")
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
logging.info(f"Model:\n{model}")
-
+
preprocessor = LMPreprocessor(
train=False,
token_type=train_args.token_type,
@@ -96,12 +81,12 @@
split_with_space=split_with_space,
seg_dict_file=seg_dict_file
)
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
results = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
@@ -109,7 +94,7 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
if raw_inputs != None:
line = raw_inputs.strip()
key = "lm demo"
@@ -121,7 +106,7 @@
batch['text'] = line
if preprocessor != None:
batch = preprocessor(key, batch)
-
+
# Force data-precision
for name in batch:
value = batch[name]
@@ -138,11 +123,11 @@
else:
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
batch[name] = value
-
+
batch["text_lengths"] = torch.from_numpy(
np.array([len(batch["text"])], dtype='int32'))
batch["text"] = np.expand_dims(batch["text"], axis=0)
-
+
with torch.no_grad():
batch = to_device(batch, device)
if ngpu <= 1:
@@ -173,7 +158,7 @@
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
-
+
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
@@ -189,22 +174,20 @@
if writer is not None:
writer["ppl"][key + ":\n"] = ppl_out
results.append(item)
-
+
return results
-
+
# 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="lm",
+ preprocess_args=train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=preprocessor,
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
# 4. Start for-loop
total_nll = 0.0
total_ntokens = 0
@@ -214,7 +197,7 @@
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+
ppl_out_batch = ""
with torch.no_grad():
batch = to_device(batch, device)
@@ -247,7 +230,7 @@
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
-
+
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
@@ -265,9 +248,9 @@
writer["ppl"][key + ":\n"] = ppl_out
writer["utt2nll"][key] = str(utt2nll)
results.append(item)
-
+
ppl_out_all += ppl_out_batch
-
+
assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
# nll: (B, L) -> (B,)
nll = nll.detach().cpu().numpy().sum(1)
@@ -275,12 +258,12 @@
lengths = lengths.detach().cpu().numpy()
total_nll += nll.sum()
total_ntokens += lengths.sum()
-
+
if log_base is None:
ppl = np.exp(total_nll / total_ntokens)
else:
ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
+
avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
total_nll=round(-total_nll.item(), 4),
total_ppl=round(ppl.item(), 4)
@@ -290,9 +273,9 @@
if writer is not None:
writer["ppl"]["AVG PPL : "] = avg_ppl
results.append(item)
-
+
return results
-
+
return _forward
@@ -302,7 +285,8 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
-
+
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
@@ -407,4 +391,3 @@
if __name__ == "__main__":
main()
-
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 4b6cd27..ac96811 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -1,46 +1,32 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
-import logging
-from pathlib import Path
-import sys
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Any
-from typing import List
import numpy as np
import torch
-from typeguard import check_argument_types
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.punctuation import PunctuationTask
+from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
+ self,
+ train_config: Optional[str],
+ model_file: Optional[str],
+ device: str = "cpu",
+ dtype: str = "float32",
):
# Build Model
- model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -144,16 +130,16 @@
class Text2PuncVADRealtime:
-
+
def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
+ self,
+ train_config: Optional[str],
+ model_file: Optional[str],
+ device: str = "cpu",
+ dtype: str = "float32",
):
# Build Model
- model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -178,7 +164,7 @@
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
-
+
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
if cache is not None and len(cache) > 0:
@@ -215,7 +201,7 @@
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
-
+
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
@@ -226,7 +212,7 @@
break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
last_comma_index = i
-
+
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
@@ -235,11 +221,11 @@
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
-
+
punctuations_np = punctuations.cpu().numpy()
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
-
+
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
@@ -256,7 +242,7 @@
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
-
+
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "？":
@@ -267,5 +253,3 @@
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
return sentence_out, sentence_punc_list_out, cache_out
-
-
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index 7f60f81..8fc15f0 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,55 +7,36 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
-
-import argparse
-import logging
from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
from typing import Any
from typing import List
+from typing import Optional
+from typing import Union
-import numpy as np
import torch
from typeguard import check_argument_types
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.punctuation import PunctuationTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.datasets.preprocessor import split_to_mini_sentence
-from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
+
def inference_punc(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
logging.basicConfig(
@@ -73,11 +54,11 @@
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
):
results = []
split_size = 20
@@ -121,20 +102,21 @@
return _forward
+
def inference_punc_vad_realtime(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- #cache: list,
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ # cache: list,
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
@@ -150,11 +132,11 @@
text2punc = Text2PuncVADRealtime(train_config, model_file, device)
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
):
results = []
split_size = 10
@@ -177,7 +159,6 @@
return _forward
-
def inference_launch(mode, **kwargs):
if mode == "punc":
return inference_punc(**kwargs)
@@ -186,6 +167,7 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -267,7 +249,6 @@
kwargs.pop("njob", None)
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
if __name__ == "__main__":
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 1517bfa..6e861da 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -1,35 +1,24 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-import os
-import sys
from pathlib import Path
from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
-from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
+
class Speech2Xvector:
"""Speech2Xvector class
@@ -56,10 +45,13 @@
assert check_argument_types()
# TODO: 1. Build SV model
- sv_model, sv_train_args = SVTask.build_model_from_file(
+ sv_model, sv_train_args = build_model_from_file(
config_file=sv_train_config,
model_file=sv_model_file,
- device=device
+ cmvn_file=None,
+ device=device,
+ task_name="sv",
+ mode="sv",
)
logging.info("sv_model: {}".format(sv_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
@@ -157,7 +149,3 @@
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Xvector(**kwargs)
-
-
-
-
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index dbddd9f..d165736 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,20 +7,6 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -30,61 +16,59 @@
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
-from funasr.torch_utils.device_funcs import to_device
+from funasr.bin.sv_infer import Speech2Xvector
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils.misc import statistic_model_parameters
-from funasr.bin.sv_infer import Speech2Xvector
+
def inference_sv(
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- sv_train_config: Optional[str] = "sv.yaml",
- sv_model_file: Optional[str] = "sv.pb",
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- param_dict: Optional[dict] = None,
- **kwargs,
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ sv_train_config: Optional[str] = "sv.yaml",
+ sv_model_file: Optional[str] = "sv.pb",
+ model_tag: Optional[str] = None,
+ allow_variable_data_keys: bool = True,
+ streaming: bool = False,
+ embedding_node: str = "resnet1_dense",
+ sv_threshold: float = 0.9465,
+ param_dict: Optional[dict] = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2xvector
speech2xvector_kwargs = dict(
sv_train_config=sv_train_config,
@@ -100,32 +84,31 @@
**speech2xvector_kwargs,
)
speech2xvector.sv_model.eval()
-
+
def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
+
# 3. Build data-iterator
- loader = SVTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="sv",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
+ use_collate_fn=False,
)
-
+
# 7 .Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
embd_writer, ref_embd_writer, score_writer = None, None, None
@@ -139,7 +122,7 @@
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
embedding, ref_embedding, score = speech2xvector(**batch)
# Only supporting batch_size==1
key = keys[0]
@@ -161,18 +144,16 @@
score_writer = open(os.path.join(output_path, "score.txt"), "w")
ref_embd_writer(key, ref_embedding[0].cpu().numpy())
score_writer.write("{} {:.6f}\n".format(key, normalized_score))
-
+
if output_path is not None:
embd_writer.close()
if ref_embd_writer is not None:
ref_embd_writer.close()
score_writer.close()
-
+
return sv_result_list
-
+
return _forward
-
-
def inference_launch(mode, **kwargs):
@@ -182,6 +163,7 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
index 4ddcba4..213c018 100644
--- a/funasr/bin/tp_infer.py
+++ b/funasr/bin/tp_infer.py
@@ -1,57 +1,35 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-from optparse import Option
-import sys
-import json
from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.tasks.asr import ASRTaskAligner as ASRTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
+from funasr.torch_utils.device_funcs import to_device
class Speech2Timestamp:
def __init__(
- self,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- timestamp_cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- dtype: str = "float32",
- **kwargs,
+ self,
+ timestamp_infer_config: Union[Path, str] = None,
+ timestamp_model_file: Union[Path, str] = None,
+ timestamp_cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ dtype: str = "float32",
+ **kwargs,
):
assert check_argument_types()
# 1. Build ASR model
- tp_model, tp_train_args = ASRTask.build_model_from_file(
- timestamp_infer_config, timestamp_model_file, device=device
+ tp_model, tp_train_args = build_model_from_file(
+ timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
)
if 'cuda' in device:
tp_model = tp_model.cuda() # force model to cuda
@@ -59,13 +37,12 @@
frontend = None
if tp_train_args.frontend is not None:
frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
-
+
logging.info("tp_model: {}".format(tp_model))
logging.info("tp_train_args: {}".format(tp_train_args))
tp_model.to(dtype=getattr(torch, dtype)).eval()
logging.info(f"Decoding device={device}, dtype={dtype}")
-
self.tp_model = tp_model
self.tp_train_args = tp_train_args
@@ -79,13 +56,13 @@
self.encoder_downsampling_factor = 1
if tp_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
-
+
@torch.no_grad()
def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- text_lengths: Union[torch.Tensor, np.ndarray] = None
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+ text_lengths: Union[torch.Tensor, np.ndarray] = None
):
assert check_argument_types()
@@ -113,8 +90,6 @@
enc = enc[0]
# c. Forward Predictor
- _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
+ _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len,
+ text_lengths.to(self.device) + 1)
return us_alphas, us_peaks
-
-
-
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index a8d67ef..3f8df0c 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -8,87 +8,66 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-from optparse import Option
-import sys
-import json
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
-from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.tasks.asr import ASRTaskAligner as ASRTask
-from funasr.torch_utils.device_funcs import to_device
+from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_infer import Speech2Timestamp
+
def inference_tp(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- timestamp_infer_config: Optional[str],
- timestamp_model_file: Optional[str],
- timestamp_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- split_with_space: bool = True,
- seg_dict_file: Optional[str] = None,
- **kwargs,
+ batch_size: int,
+ ngpu: int,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ timestamp_infer_config: Optional[str],
+ timestamp_model_file: Optional[str],
+ timestamp_cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ num_workers: int = 1,
+ split_with_space: bool = True,
+ seg_dict_file: Optional[str] = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2vadsegment
speechtext2timestamp_kwargs = dict(
timestamp_infer_config=timestamp_infer_config,
@@ -99,7 +78,7 @@
)
logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs)
-
+
preprocessor = LMPreprocessor(
train=False,
token_type=speechtext2timestamp.tp_train_args.token_type,
@@ -112,21 +91,21 @@
split_with_space=split_with_space,
seg_dict_file=seg_dict_file,
)
-
+
if output_dir is not None:
writer = DatadirWriter(output_dir)
tp_writer = writer[f"timestamp_prediction"]
# ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
else:
tp_writer = None
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs
):
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
writer = None
@@ -140,32 +119,31 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speechtext2timestamp.tp_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=preprocessor,
- collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
tp_result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+
logging.info("timestamp predicting, utt_id: {}".format(keys))
_batch = {'speech': batch['speech'],
'speech_lengths': batch['speech_lengths'],
'text_lengths': batch['text_lengths']}
us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
-
+
for batch_id in range(_bs):
key = keys[batch_id]
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
@@ -178,10 +156,8 @@
tp_writer["tp_time"][key + '#'] = str(ts_list)
tp_result_list.append(item)
return tp_result_list
-
+
return _forward
-
-
def inference_launch(mode, **kwargs):
@@ -190,6 +166,7 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -306,7 +283,6 @@
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
if __name__ == "__main__":
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 961bff9..1dc3fb5 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
index e1698d0..f888bb4 100644
--- a/funasr/bin/vad_infer.py
+++ b/funasr/bin/vad_infer.py
@@ -1,42 +1,23 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-import os
-import sys
-import json
+import math
from pathlib import Path
-from typing import Any
+from typing import Dict
from typing import List
-from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
-import math
import numpy as np
import torch
from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-
+from funasr.torch_utils.device_funcs import to_device
class Speech2VadSegment:
@@ -64,8 +45,8 @@
assert check_argument_types()
# 1. Build vad model
- vad_model, vad_infer_args = VADTask.build_model_from_file(
- vad_infer_config, vad_model_file, device
+ vad_model, vad_infer_args = build_model_from_file(
+ vad_infer_config, vad_model_file, None, device, task_name="vad"
)
frontend = None
if vad_infer_args.frontend is not None:
@@ -128,12 +109,13 @@
"in_cache": in_cache
}
# a. To device
- #batch = to_device(batch, device=self.device)
+ # batch = to_device(batch, device=self.device)
segments_part, in_cache = self.vad_model(**batch)
if segments_part:
for batch_num in range(0, self.batch_size):
segments[batch_num] += segments_part[batch_num]
return fbanks, segments
+
class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class
@@ -146,13 +128,13 @@
[[10, 230], [245, 450], ...]
"""
+
def __init__(self, **kwargs):
super(Speech2VadSegmentOnline, self).__init__(**kwargs)
vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
self.frontend = None
if self.vad_infer_args.frontend is not None:
self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
-
@torch.no_grad()
def __call__(
@@ -198,5 +180,3 @@
# in_cache.update(batch['in_cache'])
# in_cache = {key: value for key, value in batch['in_cache'].items()}
return fbanks, segments, in_cache
-
-
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index b17d058..829f157 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,58 +1,34 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
+
torch.set_num_threads(1)
import argparse
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-import os
-import sys
import json
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Dict
-import math
import numpy as np
import torch
from typeguard import check_argument_types
-from typeguard import check_return_type
-
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
+
def inference_vad(
batch_size: int,
@@ -74,7 +50,6 @@
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
-
logging.basicConfig(
level=log_level,
@@ -112,16 +87,14 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="vad",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
finish_count = 0
@@ -157,6 +130,7 @@
return _forward
+
def inference_vad_online(
batch_size: int,
ngpu: int,
@@ -175,7 +149,6 @@
**kwargs,
):
assert check_argument_types()
-
logging.basicConfig(
level=log_level,
@@ -214,16 +187,14 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="vad",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
finish_count = 0
@@ -273,8 +244,6 @@
return _forward
-
-
def inference_launch(mode, **kwargs):
if mode == "offline":
return inference_vad(**kwargs)
@@ -283,6 +252,7 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -405,5 +375,6 @@
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
+
if __name__ == "__main__":
main()
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
index 517c85b..632c134 100644
--- a/funasr/build_utils/build_args.py
+++ b/funasr/build_utils/build_args.py
@@ -41,7 +41,7 @@
"--cmvn_file",
type=str_or_none,
default=None,
- help="The file path of noise scp file.",
+ help="The path of cmvn file.",
)
elif args.task_name == "pretrain":
@@ -75,12 +75,29 @@
default=None,
help="The number of input dimension of the feature",
)
+ task_parser.add_argument(
+ "--cmvn_file",
+ type=str_or_none,
+ default=None,
+ help="The path of cmvn file.",
+ )
elif args.task_name == "diar":
from funasr.build_utils.build_diar_model import class_choices_list
for class_choices in class_choices_list:
class_choices.add_arguments(task_parser)
+ elif args.task_name == "sv":
+ from funasr.build_utils.build_sv_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index 7aa8111..200395d 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -20,15 +20,18 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
+
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
@@ -42,6 +45,7 @@
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
@@ -89,6 +93,7 @@
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
+ neatcontextual_paraformer=NeatContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
@@ -258,17 +263,22 @@
def build_asr_model(args):
# token_list
- if args.token_list is not None:
- with open(args.token_list) as f:
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
args.token_list = list(token_list)
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
else:
+ token_list = None
vocab_size = None
# frontend
- if args.input_size is None:
+ if hasattr(args, "input_size") and args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -279,7 +289,7 @@
args.frontend = None
args.frontend_conf = {}
frontend = None
- input_size = args.input_size
+ input_size = args.input_size if hasattr(args, "input_size") else None
# data augmentation for spectrogram
if args.specaug is not None:
@@ -291,7 +301,10 @@
# normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
+ if args.model == "mfcca":
+ normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
+ else:
+ normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
@@ -325,7 +338,8 @@
token_list=token_list,
**args.model_conf,
)
- elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+ elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
+ "contextual_paraformer", "neatcontextual_paraformer"]:
# predictor
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
index 6406404..0ea3127 100644
--- a/funasr/build_utils/build_diar_model.py
+++ b/funasr/build_utils/build_diar_model.py
@@ -178,14 +178,18 @@
def build_diar_model(args):
# token_list
- if args.token_list is not None:
- with open(args.token_list) as f:
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
else:
- vocab_size = None
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
# frontend
if args.input_size is None:
@@ -205,7 +209,7 @@
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
- if args.model_name == "sond":
+ if args.model == "sond":
# data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
@@ -243,11 +247,7 @@
# decoder
decoder_class = decoder_choices.get_class(args.decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder.output_size(),
- **args.decoder_conf,
- )
+ decoder = decoder_class(**args.decoder_conf)
# logger aggregator
if getattr(args, "label_aggregator", None) is not None:
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
index 8f4a958..f78a20e 100644
--- a/funasr/build_utils/build_lm_model.py
+++ b/funasr/build_utils/build_lm_model.py
@@ -34,10 +34,14 @@
def build_lm_model(args):
# token_list
- if args.token_list is not None:
- with open(args.token_list) as f:
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
else:
@@ -47,6 +51,7 @@
lm_class = lm_choices.get_class(args.lm)
lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
+ args.model = args.model if hasattr(args, "model") else "lm"
model_class = model_choices.get_class(args.model)
model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index 13a6faa..be8f910 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -1,9 +1,10 @@
from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_diar_model import build_diar_model
from funasr.build_utils.build_lm_model import build_lm_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
-from funasr.build_utils.build_diar_model import build_diar_model
def build_model(args):
@@ -19,6 +20,8 @@
model = build_vad_model(args)
elif args.task_name == "diar":
model = build_diar_model(args)
+ elif args.task_name == "sv":
+ model = build_sv_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
new file mode 100644
index 0000000..8fd4e46
--- /dev/null
+++ b/funasr/build_utils/build_model_from_file.py
@@ -0,0 +1,193 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Union
+
+import torch
+import yaml
+from typeguard import check_argument_types
+
+from funasr.build_utils.build_model import build_model
+from funasr.models.base_model import FunASRModel
+
+
+def build_model_from_file(
+ config_file: Union[Path, str] = None,
+ model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ task_name: str = "asr",
+ mode: str = "paraformer",
+):
+ """Build model from the files.
+
+ This method is used for inference or fine-tuning.
+
+ Args:
+ config_file: The yaml file saved when training.
+ model_file: The model file saved when training.
+ device: Device type, "cpu", "cuda", or "cuda:N".
+
+ """
+ assert check_argument_types()
+ if config_file is None:
+ assert model_file is not None, (
+ "The argument 'model_file' must be provided "
+ "if the argument 'config_file' is not specified."
+ )
+ config_file = Path(model_file).parent / "config.yaml"
+ else:
+ config_file = Path(config_file)
+
+ with config_file.open("r", encoding="utf-8") as f:
+ args = yaml.safe_load(f)
+ if cmvn_file is not None:
+ args["cmvn_file"] = cmvn_file
+ args = argparse.Namespace(**args)
+ args.task_name = task_name
+ model = build_model(args)
+ if not isinstance(model, FunASRModel):
+ raise RuntimeError(
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
+ )
+ model.to(device)
+ model_dict = dict()
+ model_name_pth = None
+ if model_file is not None:
+ logging.info("model_file is {}".format(model_file))
+ if device == "cuda":
+ device = f"cuda:{torch.cuda.current_device()}"
+ model_dir = os.path.dirname(model_file)
+ model_name = os.path.basename(model_file)
+ if "model.ckpt-" in model_name or ".bin" in model_name:
+ model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
+ '.pb')) if ".bin" in model_name else os.path.join(
+ model_dir, "{}.pb".format(model_name))
+ if os.path.exists(model_name_pth):
+ logging.info("model_file is load from pth: {}".format(model_name_pth))
+ model_dict = torch.load(model_name_pth, map_location=device)
+ else:
+ model_dict = convert_tf2torch(model, model_file, mode)
+ model.load_state_dict(model_dict)
+ else:
+ model_dict = torch.load(model_file, map_location=device)
+ if task_name == "diar" and mode == "sond":
+ model_dict = fileter_model_dict(model_dict, model.state_dict())
+ if task_name == "vad":
+ model.encoder.load_state_dict(model_dict)
+ else:
+ model.load_state_dict(model_dict)
+ if model_name_pth is not None and not os.path.exists(model_name_pth):
+ torch.save(model_dict, model_name_pth)
+ logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+ return model, args
+
+
+def convert_tf2torch(
+ model,
+ ckpt,
+ mode,
+):
+ assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
+ logging.info("start convert tf model to torch model")
+ from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+ var_dict_tf = load_tf_dict(ckpt)
+ var_dict_torch = model.state_dict()
+ var_dict_torch_update = dict()
+ if mode == "uniasr":
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # encoder2
+ var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor2
+ var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder2
+ var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # stride_conv
+ var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ elif mode == "paraformer":
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+    elif mode == "sond":
+ if model.encoder is not None:
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # speaker encoder
+ if model.speaker_encoder is not None:
+ var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # cd scorer
+ if model.cd_scorer is not None:
+ var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # ci scorer
+ if model.ci_scorer is not None:
+ var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ if model.decoder is not None:
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+    elif mode == "sv":
+ # speech encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # pooling layer
+ var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ else:
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ return var_dict_torch_update
+
+ return var_dict_torch_update
+
+
+def fileter_model_dict(src_dict: dict, dest_dict: dict):
+ from collections import OrderedDict
+ new_dict = OrderedDict()
+ for key, value in src_dict.items():
+ if key in dest_dict:
+ new_dict[key] = value
+ else:
+ logging.info("{} is no longer needed in this model.".format(key))
+ for key, value in dest_dict.items():
+ if key not in new_dict:
+ logging.warning("{} is missed in checkpoint.".format(key))
+ return new_dict
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
new file mode 100644
index 0000000..1b16cf4
--- /dev/null
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -0,0 +1,67 @@
+import numpy as np
+from torch.utils.data import DataLoader
+from typeguard import check_argument_types
+
+from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+
+
+def build_streaming_iterator(
+ task_name,
+ preprocess_args,
+ data_path_and_name_and_type,
+ key_file: str = None,
+ batch_size: int = 1,
+ fs: dict = None,
+ mc: bool = False,
+ dtype: str = np.float32,
+ num_workers: int = 1,
+ use_collate_fn: bool = True,
+ preprocess_fn=None,
+ ngpu: int = 0,
+ train: bool = False,
+) -> DataLoader:
+ """Build DataLoader using iterable dataset"""
+ assert check_argument_types()
+
+ # preprocess
+ if preprocess_fn is not None:
+ preprocess_fn = preprocess_fn
+ elif preprocess_args is not None:
+ preprocess_args.task_name = task_name
+ preprocess_fn = build_preprocess(preprocess_args, train)
+ else:
+ preprocess_fn = None
+
+ # collate
+ if not use_collate_fn:
+ collate_fn = None
+ elif task_name in ["punc", "lm"]:
+ collate_fn = CommonCollateFn(int_pad_value=0)
+ else:
+ collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+ if collate_fn is not None:
+ kwargs = dict(collate_fn=collate_fn)
+ else:
+ kwargs = {}
+
+ dataset = IterableESPnetDataset(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ fs=fs,
+ mc=mc,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ )
+ if dataset.apply_utt2category:
+ kwargs.update(batch_size=1)
+ else:
+ kwargs.update(batch_size=batch_size)
+
+ return DataLoader(
+ dataset=dataset,
+ pin_memory=ngpu > 0,
+ num_workers=num_workers,
+ **kwargs,
+ )
diff --git a/funasr/build_utils/build_sv_model.py b/funasr/build_utils/build_sv_model.py
new file mode 100644
index 0000000..c0f1ae8
--- /dev/null
+++ b/funasr/build_utils/build_sv_model.py
@@ -0,0 +1,258 @@
+import logging
+
+import torch
+from typeguard import check_return_type
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.sv_decoder import DenseDecoder
+from funasr.models.e2e_sv import ESPnetSVModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.pooling.statistic_pooling import StatisticPooling
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ ),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ ),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ espnet=ESPnetSVModel,
+ ),
+ type_check=FunASRModel,
+ default="espnet",
+)
+preencoder_choices = ClassChoices(
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ linear=LinearProjection,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ resnet34=ResNet34,
+ resnet34_sp_l2reg=ResNet34_SP_L2Reg,
+ rnn=RNNEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="resnet34",
+)
+postencoder_choices = ClassChoices(
+ name="postencoder",
+ classes=dict(
+ hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+ ),
+ type_check=AbsPostEncoder,
+ default=None,
+ optional=True,
+)
+pooling_choices = ClassChoices(
+ name="pooling_type",
+ classes=dict(
+ statistic=StatisticPooling,
+ ),
+ type_check=torch.nn.Module,
+ default="statistic",
+)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ dense=DenseDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="dense",
+)
+
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --postencoder and --postencoder_conf
+ postencoder_choices,
+ # --pooling and --pooling_conf
+ pooling_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+]
+
+
+def build_sv_model(args):
+ # token_list
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Speaker number: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Pre-encoder input block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ if getattr(args, "preencoder", None) is not None:
+ preencoder_class = preencoder_choices.get_class(args.preencoder)
+ preencoder = preencoder_class(**args.preencoder_conf)
+ input_size = preencoder.output_size()
+ else:
+ preencoder = None
+
+ # 5. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # 6. Post-encoder block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ encoder_output_size = encoder.output_size()
+ if getattr(args, "postencoder", None) is not None:
+ postencoder_class = postencoder_choices.get_class(args.postencoder)
+ postencoder = postencoder_class(
+ input_size=encoder_output_size, **args.postencoder_conf
+ )
+ encoder_output_size = postencoder.output_size()
+ else:
+ postencoder = None
+
+ # 7. Pooling layer
+ pooling_class = pooling_choices.get_class(args.pooling_type)
+ pooling_dim = (2, 3)
+ eps = 1e-12
+ if hasattr(args, "pooling_type_conf"):
+ if "pooling_dim" in args.pooling_type_conf:
+ pooling_dim = args.pooling_type_conf["pooling_dim"]
+ if "eps" in args.pooling_type_conf:
+ eps = args.pooling_type_conf["eps"]
+ pooling_layer = pooling_class(
+ pooling_dim=pooling_dim,
+ eps=eps,
+ )
+ if args.pooling_type == "statistic":
+ encoder_output_size *= 2
+
+ # 8. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **args.decoder_conf,
+ )
+
+ # 7. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("espnet")
+ model = model_class(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ pooling_layer=pooling_layer,
+ decoder=decoder,
+ **args.model_conf,
+ )
+
+ # FIXME(kamo): Should be done in model?
+ # 8. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
index 76eb09b..6a840cf 100644
--- a/funasr/build_utils/build_vad_model.py
+++ b/funasr/build_utils/build_vad_model.py
@@ -50,6 +50,10 @@
def build_vad_model(args):
# frontend
+ if not hasattr(args, "cmvn_file"):
+ args.cmvn_file = None
+ if not hasattr(args, "init"):
+ args.init = None
if args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py
index dc820db..cfb5008 100644
--- a/funasr/models/e2e_asr_contextual_paraformer.py
+++ b/funasr/models/e2e_asr_contextual_paraformer.py
@@ -43,9 +43,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -72,6 +70,8 @@
crit_attn_weight: float = 0.0,
crit_attn_smooth: float = 0.0,
bias_encoder_dropout_rate: float = 0.0,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index fbf0d11..3927e38 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -53,7 +53,7 @@
encoder: AbsEncoder,
decoder: AbsDecoder,
ctc: CTC,
- rnnt_decoder: None,
+ rnnt_decoder: None = None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index d08ea37..9ec3a39 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -50,9 +50,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -80,6 +78,8 @@
loss_weight_model1: float = 0.5,
enable_maas_finetune: bool = False,
freeze_encoder2: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
encoder1_encoder2_joint_training: bool = True,
):
assert check_argument_types()
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 14d56a8..7c55b2e 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.base_model import FunASRModel
class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
return int(self.frame_size_ms)
-class E2EVadModel(nn.Module):
+class E2EVadModel(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
diff --git a/tests/test_sv_inference_pipeline.py b/tests/test_sv_inference_pipeline.py
index 60ece2d..c4e427e 100644
--- a/tests/test_sv_inference_pipeline.py
+++ b/tests/test_sv_inference_pipeline.py
@@ -35,4 +35,4 @@
logger.info(f"Similarity {rec_result['scores']}")
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
\ No newline at end of file
diff --git a/tests/test_vad_inference_pipeline.py b/tests/test_vad_inference_pipeline.py
index b6601b1..50b8db3 100644
--- a/tests/test_vad_inference_pipeline.py
+++ b/tests/test_vad_inference_pipeline.py
@@ -37,7 +37,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
logger.info("vad inference result: {0}".format(rec_result))
- assert rec_result["text"] == [[80, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
+ assert rec_result["text"] == [[70, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
[29950, 31430], [31750, 37600], [38210, 46900], [47310, 49630], [49910, 56460],
[56740, 59540], [59820, 70450]]
--
Gitblit v1.9.1