Merge pull request #91 from alibaba-damo-academy/dev_lyb
add language model inference pipeline
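As a usage sketch (the data and model paths below are placeholders; the module path funasr.bin.lm_inference is taken from the import in the launch script added in this PR), the new pipeline can be run from the command line with the flags defined in its get_parser():

    python -m funasr.bin.lm_inference \
        --train_config exp/lm/config.yaml \
        --model_file exp/lm/valid.loss.ave.pth \
        --data_path_and_name_and_type "data/test/text,text,text" \
        --output_dir exp/lm/perplexity_test \
        --log_base 10 \
        --ngpu 1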
| | |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build LM |
| | | model, train_args = LMTask.build_model_from_file(train_config, model_file, device) |
| | | model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device) |
| | | # Wrap model to make model.nll() data-parallel
| | | wrapped_model = ForwardAdaptor(model, "nll") |
| | | wrapped_model.to(dtype=getattr(torch, dtype)).eval() |
| | |
| | | utt_ppl = log_base ** (_nll / ntoken / np.log(log_base)) |
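| | | # Note: log_base ** (x / np.log(log_base)) == np.exp(x), so utt_ppl == np.exp(_nll / ntoken) for any base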
| | | |
| | | # Write the PPL of each utterance for debugging or analysis
| | | writer["utt2nll"][key] = str(-_nll) |
| | | writer["utt2ppl"][key] = str(utt_ppl) |
| | | writer["utt2ntokens"][key] = str(ntoken) |
| | | |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import argparse |
| | | import logging |
| | | from pathlib import Path |
| | | import sys |
| | | import os |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | from typing import Any |
| | | from typing import List |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from torch.nn.parallel import data_parallel |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.datasets.preprocessor import LMPreprocessor |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.forward_adaptor import ForwardAdaptor |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import float_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | def inference( |
| | | output_dir: str, |
| | | batch_size: int, |
| | | dtype: str, |
| | | ngpu: int, |
| | | seed: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | train_config: Optional[str], |
| | | model_file: Optional[str], |
| | | log_base: Optional[float], |
| | | key_file: Optional[str] = None, |
| | | allow_variable_data_keys: bool = False, |
| | | split_with_space: Optional[bool] = False, |
| | | seg_dict_file: Optional[str] = None, |
| | | data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, |
| | | raw_inputs: Union[List[Any], bytes, str] = None, |
| | | **kwargs, |
| | | ): |
| | | inference_pipeline = inference_modelscope( |
| | | output_dir=output_dir, |
| | | raw_inputs=raw_inputs, |
| | | batch_size=batch_size, |
| | | dtype=dtype, |
| | | ngpu=ngpu, |
| | | seed=seed, |
| | | num_workers=num_workers, |
| | | log_level=log_level, |
| | | key_file=key_file, |
| | | train_config=train_config, |
| | | model_file=model_file, |
| | | log_base=log_base,
| | | allow_variable_data_keys=allow_variable_data_keys,
| | | split_with_space=split_with_space, |
| | | seg_dict_file=seg_dict_file, |
| | | **kwargs, |
| | | ) |
| | | return inference_pipeline(data_path_and_name_and_type, raw_inputs) |
| | | |
| | | |
| | | def inference_modelscope( |
| | | batch_size: int, |
| | | dtype: str, |
| | | ngpu: int, |
| | | seed: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | key_file: Optional[str], |
| | | train_config: Optional[str], |
| | | model_file: Optional[str], |
| | | log_base: Optional[float] = 10, |
| | | allow_variable_data_keys: bool = False, |
| | | split_with_space: Optional[bool] = False, |
| | | seg_dict_file: Optional[str] = None, |
| | | output_dir: Optional[str] = None, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | logging.basicConfig( |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if ngpu >= 1 and torch.cuda.is_available(): |
| | | device = "cuda" |
| | | else: |
| | | device = "cpu" |
| | | |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build Model |
| | | model, train_args = LMTask.build_model_from_file( |
| | | train_config, model_file, device) |
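| | | # Wrap model so that model.nll() is exposed as forward(), which lets data_parallel dispatch the NLL computation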
| | | wrapped_model = ForwardAdaptor(model, "nll") |
| | | wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval() |
| | | logging.info(f"Model:\n{model}") |
| | | |
| | | preprocessor = LMPreprocessor( |
| | | train=False, |
| | | token_type=train_args.token_type, |
| | | token_list=train_args.token_list, |
| | | bpemodel=train_args.bpemodel, |
| | | text_cleaner=train_args.cleaner, |
| | | g2p_type=train_args.g2p, |
| | | text_name="text", |
| | | non_linguistic_symbols=train_args.non_linguistic_symbols, |
| | | split_with_space=split_with_space, |
| | | seg_dict_file=seg_dict_file |
| | | ) |
| | | |
| | | def _forward( |
| | | data_path_and_name_and_type, |
| | | raw_inputs: Union[List[Any], bytes, str] = None, |
| | | output_dir_v2: Optional[str] = None, |
| | | param_dict: dict = None, |
| | | ): |
| | | results = [] |
| | | if output_dir_v2 is not None: |
| | | writer = DatadirWriter(output_dir_v2) |
| | | else: |
| | | writer = None |
| | | |
| | | if raw_inputs is not None:
| | | line = raw_inputs.strip() |
| | | key = "lm demo" |
| | | if line == "":
| | | item = {'key': key, 'value': ""} |
| | | results.append(item) |
| | | return results |
| | | batch = {} |
| | | batch['text'] = line |
| | | if preprocessor is not None:
| | | batch = preprocessor(key, batch) |
| | | |
| | | # Force data-precision |
| | | for name in batch: |
| | | value = batch[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f"All values must be converted to np.ndarray object " |
| | | f'by preprocessing, but "{name}" is still {type(value)}.' |
| | | ) |
| | | # Cast to desired type |
| | | if value.dtype.kind == "f": |
| | | value = value.astype("float32") |
| | | elif value.dtype.kind == "i": |
| | | value = value.astype("long") |
| | | else: |
| | | raise NotImplementedError(f"Not supported dtype: {value.dtype}") |
| | | batch[name] = value |
| | | |
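| | | # Wrap the single raw input into a batch of size 1: text -> (1, L), text_lengths -> (1,)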
| | | batch["text_lengths"] = torch.from_numpy( |
| | | np.array([len(batch["text"])], dtype='int32')) |
| | | batch["text"] = np.expand_dims(batch["text"], axis=0) |
| | | |
| | | with torch.no_grad(): |
| | | batch = to_device(batch, device) |
| | | if ngpu <= 1: |
| | | nll, lengths = wrapped_model(**batch) |
| | | else: |
| | | nll, lengths = data_parallel( |
| | | wrapped_model, (), range(ngpu), module_kwargs=batch |
| | | ) |
| | | ## compute ppl |
| | | ppl_out_batch = "" |
| | | ids2tokens = preprocessor.token_id_converter.ids2tokens |
| | | for sent_ids, sent_nll in zip(batch['text'], nll): |
| | | pre_word = "<s>" |
| | | cur_word = None |
| | | sent_lst = ids2tokens(sent_ids) + ['</s>'] |
| | | ppl_out = " ".join(sent_lst) + "\n" |
| | | for word, word_nll in zip(sent_lst, sent_nll): |
| | | cur_word = word |
| | | word_nll = -word_nll.cpu() |
| | | if log_base is None: |
| | | word_prob = np.exp(word_nll) |
| | | else: |
| | | word_prob = log_base ** (word_nll / np.log(log_base)) |
| | | ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format( |
| | | cur=cur_word, |
| | | pre=pre_word, |
| | | prob=round(word_prob.item(), 8), |
| | | word_nll=round(word_nll.item(), 8) |
| | | ) |
| | | pre_word = cur_word |
| | | |
| | | sent_nll_mean = sent_nll.mean().cpu().numpy() |
| | | sent_nll_sum = sent_nll.sum().cpu().numpy() |
| | | if log_base is None: |
| | | sent_ppl = np.exp(sent_nll_mean) |
| | | else: |
| | | sent_ppl = log_base ** (sent_nll_mean / np.log(log_base)) |
| | | ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format( |
| | | sent_nll=round(-sent_nll_sum.item(), 4), |
| | | sent_ppl=round(sent_ppl.item(), 4) |
| | | ) |
| | | ppl_out_batch += ppl_out |
| | | item = {'key': key, 'value': ppl_out} |
| | | if writer is not None: |
| | | writer["ppl"][key+":\n"] = ppl_out |
| | | results.append(item) |
| | | |
| | | return results |
| | | |
| | | # 3. Build data-iterator |
| | | loader = LMTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=preprocessor, |
| | | collate_fn=LMTask.build_collate_fn(train_args, False), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | # 4. Start for-loop |
| | | total_nll = 0.0 |
| | | total_ntokens = 0 |
| | | ppl_out_all = "" |
| | | for keys, batch in loader: |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | |
| | | ppl_out_batch = "" |
| | | with torch.no_grad(): |
| | | batch = to_device(batch, device) |
| | | if ngpu <= 1: |
| | | # NOTE(kamo): data_parallel also should work with ngpu=1, |
| | | # but for debuggability it's better to keep this block. |
| | | nll, lengths = wrapped_model(**batch) |
| | | else: |
| | | nll, lengths = data_parallel( |
| | | wrapped_model, (), range(ngpu), module_kwargs=batch |
| | | ) |
| | | ## print ppl |
| | | ids2tokens = preprocessor.token_id_converter.ids2tokens |
| | | for key, sent_ids, sent_nll in zip(keys, batch['text'], nll): |
| | | pre_word = "<s>" |
| | | cur_word = None |
| | | sent_lst = ids2tokens(sent_ids) + ['</s>'] |
| | | ppl_out = " ".join(sent_lst) + "\n" |
| | | for word, word_nll in zip(sent_lst, sent_nll): |
| | | cur_word = word |
| | | word_nll = -word_nll.cpu() |
| | | if log_base is None: |
| | | word_prob = np.exp(word_nll) |
| | | else: |
| | | word_prob = log_base ** (word_nll / np.log(log_base)) |
| | | ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format( |
| | | cur=cur_word, |
| | | pre=pre_word, |
| | | prob=round(word_prob.item(), 8), |
| | | word_nll=round(word_nll.item(), 8) |
| | | ) |
| | | pre_word = cur_word |
| | | |
| | | sent_nll_mean = sent_nll.mean().cpu().numpy() |
| | | sent_nll_sum = sent_nll.sum().cpu().numpy() |
| | | if log_base is None: |
| | | sent_ppl = np.exp(sent_nll_mean) |
| | | else: |
| | | sent_ppl = log_base ** (sent_nll_mean / np.log(log_base)) |
| | | ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format( |
| | | sent_nll=round(-sent_nll_sum.item(), 4), |
| | | sent_ppl=round(sent_ppl.item(), 4) |
| | | ) |
| | | ppl_out_batch += ppl_out |
| | | utt2nll = round(-sent_nll_sum.item(), 5) |
| | | item = {'key': key, 'value': ppl_out} |
| | | if writer is not None: |
| | | writer["ppl"][key+":\n"] = ppl_out |
| | | writer["utt2nll"][key] = str(utt2nll) |
| | | results.append(item) |
| | | |
| | | ppl_out_all += ppl_out_batch |
| | | |
| | | assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths)) |
| | | # nll: (B, L) -> (B,) |
| | | nll = nll.detach().cpu().numpy().sum(1) |
| | | # lengths: (B,) |
| | | lengths = lengths.detach().cpu().numpy() |
| | | total_nll += nll.sum() |
| | | total_ntokens += lengths.sum() |
| | | |
| | | if log_base is None: |
| | | ppl = np.exp(total_nll / total_ntokens) |
| | | else: |
| | | ppl = log_base ** (total_nll / total_ntokens / np.log(log_base)) |
| | | |
| | | avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format( |
| | | total_nll=round(-total_nll.item(), 4), |
| | | total_ppl=round(ppl.item(), 4) |
| | | ) |
| | | item = {'key': 'AVG PPL', 'value': avg_ppl} |
| | | ppl_out_all += avg_ppl |
| | | if writer is not None: |
| | | writer["ppl"]["AVG PPL : "] = avg_ppl |
| | | results.append(item) |
| | | |
| | | return results |
| | | |
| | | return _forward |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Calc perplexity", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | |
| | | parser.add_argument( |
| | | "--log_level", |
| | | type=lambda x: x.upper(), |
| | | default="INFO", |
| | | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
| | | help="The verbose level of logging", |
| | | ) |
| | | |
| | | parser.add_argument("--output_dir", type=str, required=False) |
| | | parser.add_argument( |
| | | "--ngpu", |
| | | type=int, |
| | | default=0, |
| | | help="The number of gpus. 0 indicates CPU mode", |
| | | ) |
| | | parser.add_argument("--seed", type=int, default=0, help="Random seed") |
| | | parser.add_argument( |
| | | "--dtype", |
| | | default="float32", |
| | | choices=["float16", "float32", "float64"], |
| | | help="Data type", |
| | | ) |
| | | parser.add_argument( |
| | | "--num_workers", |
| | | type=int, |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_size", |
| | | type=int, |
| | | default=1, |
| | | help="The batch size for inference", |
| | | ) |
| | | parser.add_argument( |
| | | "--log_base", |
| | | type=float_or_none, |
| | | default=10, |
| | | help="The base of logarithm for Perplexity. " |
| | | "If None, napier's constant is used.", |
| | | required=False |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | action="append", |
| | | required=False |
| | | ) |
| | | group.add_argument( |
| | | "--raw_inputs", |
| | | type=str, |
| | | required=False |
| | | ) |
| | | group.add_argument("--key_file", type=str_or_none) |
| | | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) |
| | | |
| | | group.add_argument("--split_with_space", type=str2bool, default=False) |
| | | group.add_argument("--seg_dict_file", type=str_or_none) |
| | | |
| | | group = parser.add_argument_group("The model configuration related") |
| | | group.add_argument("--train_config", type=str) |
| | | group.add_argument("--model_file", type=str) |
| | | |
| | | return parser |
| | | |
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | parser = get_parser() |
| | | args = parser.parse_args(cmd) |
| | | kwargs = vars(args) |
| | | inference(**kwargs) |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| | | |
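For programmatic (ModelScope-style) use, a minimal sketch assuming a trained LM config and checkpoint on disk (paths are placeholders): inference_modelscope builds the model once and returns the _forward closure, which can then be called repeatedly with either data_path_and_name_and_type or raw_inputs.

    from funasr.bin.lm_inference import inference_modelscope

    pipeline = inference_modelscope(
        batch_size=1,
        dtype="float32",
        ngpu=0,
        seed=0,
        num_workers=0,
        log_level="INFO",
        key_file=None,
        train_config="exp/lm/config.yaml",        # placeholder
        model_file="exp/lm/valid.loss.ave.pth",   # placeholder
        log_base=10.0,
    )
    # Score a single sentence; results is a list of {'key', 'value'} dicts,
    # where 'value' holds the per-word probabilities and the sentence PPL.
    results = pipeline(None, raw_inputs="hello world")
    print(results[0]["value"])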
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from typing import Union, Dict, Any |
| | | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.types import float_or_none |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Calc perplexity", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | |
| | | parser.add_argument( |
| | | "--log_level", |
| | | type=lambda x: x.upper(), |
| | | default="INFO", |
| | | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
| | | help="The verbose level of logging", |
| | | ) |
| | | parser.add_argument("--output_dir", type=str, required=True) |
| | | parser.add_argument("--gpuid_list", type=str, required=True) |
| | | parser.add_argument( |
| | | "--ngpu", |
| | | type=int, |
| | | default=0, |
| | | help="The number of gpus. 0 indicates CPU mode", |
| | | ) |
| | | parser.add_argument("--seed", type=int, default=0, help="Random seed") |
| | | parser.add_argument("--njob", type=int, default=1, help="Random seed") |
| | | parser.add_argument( |
| | | "--dtype", |
| | | default="float32", |
| | | choices=["float16", "float32", "float64"], |
| | | help="Data type", |
| | | ) |
| | | parser.add_argument( |
| | | "--num_workers", |
| | | type=int, |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_size", |
| | | type=int, |
| | | default=1, |
| | | help="The batch size for inference", |
| | | ) |
| | | parser.add_argument( |
| | | "--log_base", |
| | | type=float_or_none, |
| | | default=10, |
| | | help="The base of logarithm for Perplexity. " |
| | | "If None, napier's constant is used.", |
| | | required=False |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | action="append", |
| | | required=False |
| | | ) |
| | | group.add_argument( |
| | | "--raw_inputs", |
| | | type=str, |
| | | required=False |
| | | ) |
| | | group.add_argument("--key_file", type=str_or_none) |
| | | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) |
| | | |
| | | group.add_argument("--split_with_space", type=str2bool, default=False) |
| | | group.add_argument("--seg_dict_file", type=str_or_none) |
| | | |
| | | group = parser.add_argument_group("The model configuration related") |
| | | group.add_argument("--train_config", type=str) |
| | | group.add_argument("--model_file", type=str) |
| | | group.add_argument("--mode", type=str, default="lm") |
| | | return parser |
| | | |
| | | def inference_launch(mode, **kwargs): |
| | | if mode == "transformer": |
| | | from funasr.bin.lm_inference import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | else: |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| | | |
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | parser = get_parser() |
| | | args = parser.parse_args(cmd) |
| | | kwargs = vars(args) |
| | | kwargs.pop("config", None) |
| | | |
| | | # set logging messages |
| | | logging.basicConfig( |
| | | level=args.log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | logging.info("Decoding args: {}".format(kwargs)) |
| | | |
| | | # gpu setting |
| | | if args.ngpu > 0: |
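| | | # output_dir is expected to end with ".<jobid>" (e.g. .../output.3); jobs are
| | | # packed njob-per-device onto the GPUs listed in gpuid_list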
| | | jobid = int(args.output_dir.split(".")[-1]) |
| | | gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob] |
| | | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| | | os.environ["CUDA_VISIBLE_DEVICES"] = gpuid |
| | | |
| | | kwargs.pop("gpuid_list", None) |
| | | kwargs.pop("njob", None) |
| | | results = inference_launch(**kwargs) |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| | | |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | import os |
| | | |
| | | from funasr.tasks.lm import LMTask |
| | | |
| | | |
| | | def parse_args():
| | | # for LM Training
| | | parser = LMTask.get_parser()
| | | parser.add_argument(
| | | "--gpu_id",
| | | type=int,
| | | default=0,
| | | help="local gpu id.",
| | | )
| | | args = parser.parse_args()
| | | return args
| | | |
| | | |
| | | def main(args=None, cmd=None):
| | | """LM training.
| | |
| | | Example:
| | |
| | | % python lm_train.py asr --print_config --optim adadelta
| | | % python lm_train.py --config conf/train_asr.yaml
| | | """
| | | LMTask.main(args=args, cmd=cmd)
| | | |
| | | |
| | | if __name__ == '__main__':
| | | args = parse_args()
| | | |
| | | # setup local gpu_id |
| | | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) |
| | | |
| | | # DDP settings |
| | | if args.ngpu > 1: |
| | | args.distributed = True |
| | | else: |
| | | args.distributed = False |
| | | assert args.num_worker_count == 1 |
| | | |
| | | # re-compute batch size: when dataset type is small |
| | | if args.dataset_type == "small" and args.ngpu != 0: |
| | | if args.batch_size is not None: |
| | | args.batch_size = args.batch_size * args.ngpu |
| | | if args.batch_bins is not None: |
| | | args.batch_bins = args.batch_bins * args.ngpu |
| | | |
| | | main(args=args) |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import argparse |
| | | from collections import Counter |
| | | import logging |
| | | from pathlib import Path |
| | | import sys |
| | | from typing import List |
| | | from typing import Optional |
| | | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.text.build_tokenizer import build_tokenizer |
| | | from funasr.text.cleaner import TextCleaner |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | |
| | | def field2slice(field: Optional[str]) -> slice: |
| | | """Convert field string to slice |
| | | |
| | | Note that field string accepts 1-based integer. |
| | | |
| | | Examples: |
| | | >>> field2slice("1-") |
| | | slice(0, None, None) |
| | | >>> field2slice("1-3") |
| | | slice(0, 3, None) |
| | | >>> field2slice("-3") |
| | | slice(None, 3, None) |
| | | """ |
| | | field = field.strip() |
| | | try: |
| | | if "-" in field: |
| | | # e.g. "2-" or "2-5" or "-7" |
| | | s1, s2 = field.split("-", maxsplit=1) |
| | | if s1.strip() == "": |
| | | s1 = None |
| | | else: |
| | | s1 = int(s1) |
| | | if s1 == 0: |
| | | raise ValueError("1-based string") |
| | | if s2.strip() == "": |
| | | s2 = None |
| | | else: |
| | | s2 = int(s2) |
| | | else: |
| | | # e.g. "2" |
| | | s1 = int(field) |
| | | s2 = s1 + 1 |
| | | if s1 == 0: |
| | | raise ValueError("must be 1 or more value") |
| | | except ValueError: |
| | | raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}") |
| | | |
| | | if s1 is None: |
| | | slic = slice(None, s2) |
| | | else: |
| | | # -1 because of 1-based integer following "cut" command |
| | | # e.g "1-3" -> slice(0, 3) |
| | | slic = slice(s1 - 1, s2) |
| | | return slic |
| | | |
| | | |
| | | def tokenize( |
| | | input: str, |
| | | output: str, |
| | | field: Optional[str], |
| | | delimiter: Optional[str], |
| | | token_type: str, |
| | | space_symbol: str, |
| | | non_linguistic_symbols: Optional[str], |
| | | bpemodel: Optional[str], |
| | | log_level: str, |
| | | write_vocabulary: bool, |
| | | vocabulary_size: int, |
| | | remove_non_linguistic_symbols: bool, |
| | | cutoff: int, |
| | | add_symbol: List[str], |
| | | cleaner: Optional[str], |
| | | g2p: Optional[str], |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | logging.basicConfig( |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | if input == "-": |
| | | fin = sys.stdin |
| | | else: |
| | | fin = Path(input).open("r", encoding="utf-8") |
| | | if output == "-": |
| | | fout = sys.stdout |
| | | else: |
| | | p = Path(output) |
| | | p.parent.mkdir(parents=True, exist_ok=True) |
| | | fout = p.open("w", encoding="utf-8") |
| | | |
| | | cleaner = TextCleaner(cleaner) |
| | | tokenizer = build_tokenizer( |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | delimiter=delimiter, |
| | | space_symbol=space_symbol, |
| | | non_linguistic_symbols=non_linguistic_symbols, |
| | | remove_non_linguistic_symbols=remove_non_linguistic_symbols, |
| | | g2p_type=g2p, |
| | | ) |
| | | |
| | | counter = Counter() |
| | | if field is not None: |
| | | field = field2slice(field) |
| | | |
| | | for line in fin: |
| | | line = line.rstrip() |
| | | if field is not None: |
| | | # e.g. field="2-" |
| | | # uttidA hello world!! -> hello world!! |
| | | tokens = line.split(delimiter) |
| | | tokens = tokens[field] |
| | | if delimiter is None: |
| | | line = " ".join(tokens) |
| | | else: |
| | | line = delimiter.join(tokens) |
| | | |
| | | line = cleaner(line) |
| | | tokens = tokenizer.text2tokens(line) |
| | | if not write_vocabulary: |
| | | fout.write(" ".join(tokens) + "\n") |
| | | else: |
| | | for t in tokens: |
| | | counter[t] += 1 |
| | | |
| | | if not write_vocabulary: |
| | | return |
| | | |
| | | ## FIXME
| | | ## remove --add_symbol entries from the counter so they are not duplicated below
| | | for symbol_and_id in add_symbol: |
| | | # e.g symbol="<blank>:0" |
| | | try: |
| | | symbol, idx = symbol_and_id.split(":") |
| | | except ValueError: |
| | | raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}") |
| | | symbol = symbol.strip() |
| | | if symbol in counter: |
| | | del counter[symbol] |
| | | |
| | | # ======= write_vocabulary mode from here ======= |
| | | # Sort by the number of occurrences in descending order |
| | | # and filter lower frequency words than cutoff value |
| | | words_and_counts = list( |
| | | filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1])) |
| | | ) |
| | | # Restrict the vocabulary size |
| | | if vocabulary_size > 0: |
| | | if vocabulary_size < len(add_symbol): |
| | | raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}") |
| | | words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)] |
| | | |
| | | # Parse the values of --add_symbol |
| | | for symbol_and_id in add_symbol: |
| | | # e.g symbol="<blank>:0" |
| | | try: |
| | | symbol, idx = symbol_and_id.split(":") |
| | | idx = int(idx) |
| | | except ValueError: |
| | | raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}") |
| | | symbol = symbol.strip() |
| | | |
| | | # e.g. idx=0 -> append as the first symbol |
| | | # e.g. idx=-1 -> append as the last symbol |
| | | if idx < 0: |
| | | idx = len(words_and_counts) + 1 + idx |
| | | words_and_counts.insert(idx, (symbol, None)) |
| | | |
| | | # Write words |
| | | for w, c in words_and_counts: |
| | | fout.write(w + "\n") |
| | | |
| | | # Logging |
| | | total_count = sum(counter.values()) |
| | | invocab_count = sum(c for w, c in words_and_counts if c is not None) |
| | | logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %") |
| | | |
| | | |
| | | def get_parser() -> argparse.ArgumentParser: |
| | | parser = argparse.ArgumentParser( |
| | | description="Tokenize texts", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | parser.add_argument( |
| | | "--log_level", |
| | | type=lambda x: x.upper(), |
| | | default="INFO", |
| | | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
| | | help="The verbose level of logging", |
| | | ) |
| | | |
| | | parser.add_argument( |
| | | "--input", "-i", required=True, help="Input text. - indicates sys.stdin" |
| | | ) |
| | | parser.add_argument( |
| | | "--output", "-o", required=True, help="Output text. - indicates sys.stdout" |
| | | ) |
| | | parser.add_argument( |
| | | "--field", |
| | | "-f", |
| | | help="The target columns of the input text as 1-based integer. e.g 2-", |
| | | ) |
| | | parser.add_argument( |
| | | "--token_type", |
| | | "-t", |
| | | default="char", |
| | | choices=["char", "bpe", "word", "phn"], |
| | | help="Token type", |
| | | ) |
| | | parser.add_argument("--delimiter", "-d", default=None, help="The delimiter") |
| | | parser.add_argument("--space_symbol", default="<space>", help="The space symbol") |
| | | parser.add_argument("--bpemodel", default=None, help="The bpemodel file path") |
| | | parser.add_argument( |
| | | "--non_linguistic_symbols", |
| | | type=str_or_none, |
| | | help="non_linguistic_symbols file path", |
| | | ) |
| | | parser.add_argument( |
| | | "--remove_non_linguistic_symbols", |
| | | type=str2bool, |
| | | default=False, |
| | | help="Remove non-language-symbols from tokens", |
| | | ) |
| | | parser.add_argument( |
| | | "--cleaner", |
| | | type=str_or_none, |
| | | choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"], |
| | | default=None, |
| | | help="Apply text cleaning", |
| | | ) |
| | | parser.add_argument( |
| | | "--g2p", |
| | | type=str_or_none, |
| | | choices=g2p_choices, |
| | | default=None, |
| | | help="Specify g2p method if --token_type=phn", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("write_vocabulary mode related") |
| | | group.add_argument( |
| | | "--write_vocabulary", |
| | | type=str2bool, |
| | | default=False, |
| | | help="Write tokens list instead of tokenized text per line", |
| | | ) |
| | | group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size") |
| | | group.add_argument( |
| | | "--cutoff", |
| | | default=0, |
| | | type=int, |
| | | help="cut-off frequency used for write-vocabulary mode", |
| | | ) |
| | | group.add_argument( |
| | | "--add_symbol", |
| | | type=str, |
| | | default=[], |
| | | action="append", |
| | | help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'", |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | parser = get_parser() |
| | | args = parser.parse_args(cmd) |
| | | kwargs = vars(args) |
| | | tokenize(**kwargs) |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| | |
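As an illustration (the script's file name is not shown in this diff, so <tokenize_script>.py is a placeholder, as are the data paths), the same tool can either emit tokenized text or build a vocabulary:

    # Tokenize columns 2- of a Kaldi-style "uttid text" file into characters
    python <tokenize_script>.py --input data/train/text --output data/train/tokens.txt \
        --token_type char --field 2-

    # Same input, but write a 5000-entry vocabulary with reserved symbols instead
    python <tokenize_script>.py --input data/train/text --output data/lang/tokens.txt \
        --token_type char --field 2- \
        --write_vocabulary true --vocabulary_size 5000 \
        --add_symbol '<blank>:0' --add_symbol '<unk>:1' --add_symbol '<sos/eos>:-1'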
| | | continue |
| | | return out_txt.strip().split() |
| | | |
| | | def seg_tokenize_wo_pattern(txt, seg_dict): |
| | | out_txt = "" |
| | | for word in txt: |
| | | if word in seg_dict: |
| | | out_txt += seg_dict[word] + " " |
| | | else: |
| | | out_txt += "<unk>" + " " |
| | | return out_txt.strip().split() |
| | | |
| | | |
| | | def framing( |
| | | x, |
| | |
| | | data = self._text_process(data) |
| | | return data |
| | | |
| | | ## FIXME |
| | | class LMPreprocessor(CommonPreprocessor): |
| | | def __init__( |
| | | self, |
| | | train: bool, |
| | | token_type: str = None, |
| | | token_list: Union[Path, str, Iterable[str]] = None, |
| | | bpemodel: Union[Path, str, Iterable[str]] = None, |
| | | text_cleaner: Collection[str] = None, |
| | | g2p_type: str = None, |
| | | unk_symbol: str = "<unk>", |
| | | space_symbol: str = "<space>", |
| | | non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, |
| | | delimiter: str = None, |
| | | rir_scp: str = None, |
| | | rir_apply_prob: float = 1.0, |
| | | noise_scp: str = None, |
| | | noise_apply_prob: float = 1.0, |
| | | noise_db_range: str = "3_10", |
| | | speech_volume_normalize: float = None, |
| | | speech_name: str = "speech", |
| | | text_name: str = "text", |
| | | split_with_space: bool = False, |
| | | seg_dict_file: str = None, |
| | | ): |
| | | super().__init__(train, |
| | | token_type, |
| | | token_list, |
| | | bpemodel, |
| | | text_cleaner, |
| | | g2p_type, |
| | | unk_symbol, |
| | | space_symbol, |
| | | non_linguistic_symbols, |
| | | delimiter, |
| | | rir_scp, |
| | | rir_apply_prob, |
| | | noise_scp, |
| | | noise_apply_prob, |
| | | noise_db_range, |
| | | speech_volume_normalize, |
| | | speech_name, |
| | | text_name, |
| | | split_with_space, |
| | | seg_dict_file, |
| | | ) |
| | | |
| | | def _text_process( |
| | | self, data: Dict[str, Union[str, np.ndarray]] |
| | | ) -> Dict[str, np.ndarray]: |
| | | if self.text_name in data and self.tokenizer is not None: |
| | | text = data[self.text_name] |
| | | text = self.text_cleaner(text) |
| | | if self.split_with_space: |
| | | tokens = text.strip().split(" ") |
| | | if self.seg_dict is not None: |
| | | tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict) |
| | | else: |
| | | tokens = self.tokenizer.text2tokens(text) |
| | | text_ints = self.token_id_converter.tokens2ids(tokens) |
| | | data[self.text_name] = np.array(text_ints, dtype=np.int64) |
| | | assert check_return_type(data) |
| | | return data |
| | | |
| | | |
| | | class CommonPreprocessor_multi(AbsPreprocessor): |
| | | def __init__( |
| | |
| | | |
| | | # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>' |
| | | # text: (Batch, Length) -> x, y: (Batch, Length + 1) |
| | | x = F.pad(text, [1, 0], "constant", self.eos) |
| | | x = F.pad(text, [1, 0], "constant", self.sos) |
| | | t = F.pad(text, [0, 1], "constant", self.ignore_id) |
| | | for i, l in enumerate(text_lengths): |
| | | t[i, l] = self.sos |
| | | t[i, l] = self.eos |
| | | x_lengths = text_lengths + 1 |
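| | | # e.g. text=[w1, w2, w3], text_lengths=[3] -> x=[<sos>, w1, w2, w3], t=[w1, w2, w3, <eos>] (per the comment above)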
| | | |
| | | # 2. Forward Language model |
| | |
| | | from funasr.iterators.chunk_iter_factory import ChunkIterFactory |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | | from funasr.iterators.sequence_iter_factory import SequenceIterFactory |
| | | from funasr.main_funcs.collect_stats import collect_stats |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.samplers.build_batch_sampler import BATCH_TYPES |
| | |
| | | |
| | | if args.dry_run: |
| | | pass |
| | | elif args.collect_stats: |
| | | # Perform on collect_stats mode. This mode has two roles |
| | | # - Derive the length and dimension of all input data |
| | | # - Accumulate feats, square values, and the length for whitening |
| | | |
| | | if args.valid_batch_size is None: |
| | | args.valid_batch_size = args.batch_size |
| | | |
| | | if len(args.train_shape_file) != 0: |
| | | train_key_file = args.train_shape_file[0] |
| | | else: |
| | | train_key_file = None |
| | | if len(args.valid_shape_file) != 0: |
| | | valid_key_file = args.valid_shape_file[0] |
| | | else: |
| | | valid_key_file = None |
| | | |
| | | collect_stats( |
| | | model=model, |
| | | train_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.train_data_path_and_name_and_type, |
| | | key_file=train_key_file, |
| | | batch_size=args.batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | valid_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.valid_data_path_and_name_and_type, |
| | | key_file=valid_key_file, |
| | | batch_size=args.valid_batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | output_dir=output_dir, |
| | | ngpu=args.ngpu, |
| | | log_interval=args.log_interval, |
| | | write_collected_feats=args.write_collected_feats, |
| | | ) |
| | | else: |
| | | logging.info("Training args: {}".format(args)) |
| | | # 6. Loads pre-trained model |
| | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | required = parser.get_default("required") |
| | | required += ["token_list"] |
| | | # required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |