From 4cd79db451786548d8a100f25c3b03da0eb30f4b Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 14:08:57 +0800
Subject: [PATCH] inference
---
/dev/null | 50 --------------------------------------------------
1 files changed, 0 insertions(+), 50 deletions(-)
diff --git a/funasr/bin/asr_train_paraformer.py b/funasr/bin/asr_train_paraformer.py
deleted file mode 100755
index 223be14..0000000
--- a/funasr/bin/asr_train_paraformer.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTask.get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="asr",
- help="mode",
- )
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
deleted file mode 100755
index fe418db..0000000
--- a/funasr/bin/asr_train_transducer.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.asr import ASRTransducerTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTransducerTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTransducerTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/asr_train_uniasr.py b/funasr/bin/asr_train_uniasr.py
deleted file mode 100755
index a40b503..0000000
--- a/funasr/bin/asr_train_uniasr.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.asr import ASRTaskUniASR
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTaskUniASR.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTaskUniASR.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
deleted file mode 100644
index 7a5ba83..0000000
--- a/funasr/bin/sa_asr_inference.py
+++ /dev/null
@@ -1,692 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
-from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
-from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.sa_asr import ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tasks.asr import frontend_choices
-
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend=='wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
- else:
- frontend_class=frontend_choices.get_class(asr_train_args.frontend)
- frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, None, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
- ) -> List[
- Tuple[
- Optional[str],
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, text_id, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if isinstance(profile, np.ndarray):
- profile = torch.tensor(profile)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- asr_enc, _, spk_enc = self.asr_model.encode(**batch)
- if isinstance(asr_enc, tuple):
- asr_enc = asr_enc[0]
- if isinstance(spk_enc, tuple):
- spk_enc = spk_enc[0]
- assert len(asr_enc) == 1, len(asr_enc)
- assert len(spk_enc) == 1, len(spk_enc)
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1: last_pos]
- else:
- token_int = hyp.yseq[1: last_pos].tolist()
-
- spk_weigths=torch.stack(hyp.spk_weigths, dim=0)
-
- token_ori = self.converter.ids2tokens(token_int)
- text_ori = self.tokenizer.tokens2text(token_ori)
-
- text_ori_spklist = text_ori.split('$')
- cur_index = 0
- spk_choose = []
- for i in range(len(text_ori_spklist)):
- text_ori_split = text_ori_spklist[i]
- n = len(text_ori_split)
- spk_weights_local = spk_weigths[cur_index: cur_index + n]
- cur_index = cur_index + n + 1
- spk_weights_local = spk_weights_local.mean(dim=0)
- spk_choose_local = spk_weights_local.argmax(-1)
- spk_choose.append(spk_choose_local.item() + 1)
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- text_spklist = text.split('$')
- assert len(spk_choose) == len(text_spklist)
-
- spk_list=[]
- for i in range(len(text_spklist)):
- text_split = text_spklist[i]
- n = len(text_split)
- spk_list.append(str(spk_choose[i]) * n)
-
- text_id = '$'.join(spk_list)
-
- assert len(text) == len(text_id)
-
- results.append((text, text_id, token, token_int, hyp))
-
- assert check_return_type(results)
- return results
-
-def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
- mc=mc,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- mc=mc,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["text_id"][key] = text_id
-
- if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
-
- logging.info("uttid: {}".format(key))
- logging.info("text predictions: {}".format(text))
- logging.info("text_id predictions: {}\n".format(text_id))
- return asr_result_list
-
- return _forward
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
deleted file mode 100755
index 67106cf..0000000
--- a/funasr/bin/sa_asr_train.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.sa_asr import ASRTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- if args.ngpu > 0:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None and args.ngpu > 0:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None and args.ngpu > 0:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
--
Gitblit v1.9.1