From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/lm_inference_launch.py |  290 +++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 276 insertions(+), 14 deletions(-)

diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index dc6414f..f12f50a 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,19 +1,289 @@
 #!/usr/bin/env python3
-
-
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import argparse
 import logging
 import os
 import sys
-from typing import Union, Dict, Any
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Union
 
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import float_or_none
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
+
+
+def inference_lm(
+        batch_size: int,
+        dtype: str,
+        ngpu: int,
+        seed: int,
+        num_workers: int,
+        log_level: Union[int, str],
+        key_file: Optional[str],
+        train_config: Optional[str],
+        model_file: Optional[str],
+        log_base: Optional[float] = 10,
+        allow_variable_data_keys: bool = False,
+        split_with_space: Optional[bool] = False,
+        seg_dict_file: Optional[str] = None,
+        output_dir: Optional[str] = None,
+        param_dict: dict = None,
+        **kwargs,
+):
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build Model
+    model, train_args = build_model_from_file(
+        train_config, model_file, None, device, "lm")
+    wrapped_model = ForwardAdaptor(model, "nll")
+    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+    logging.info(f"Model:\n{model}")
+
+    preprocessor = LMPreprocessor(
+        train=False,
+        token_type=train_args.token_type,
+        token_list=train_args.token_list,
+        bpemodel=train_args.bpemodel,
+        text_cleaner=train_args.cleaner,
+        g2p_type=train_args.g2p,
+        text_name="text",
+        non_linguistic_symbols=train_args.non_linguistic_symbols,
+        split_with_space=split_with_space,
+        seg_dict_file=seg_dict_file
+    )
+
+    def _forward(
+            data_path_and_name_and_type,
+            raw_inputs: Union[List[Any], bytes, str] = None,
+            output_dir_v2: Optional[str] = None,
+            param_dict: dict = None,
+    ):
+        results = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+
+        if raw_inputs != None:
+            line = raw_inputs.strip()
+            key = "lm demo"
+            if line == "":
+                item = {'key': key, 'value': ""}
+                results.append(item)
+                return results
+            batch = {}
+            batch['text'] = line
+            if preprocessor != None:
+                batch = preprocessor(key, batch)
+
+            #  Force data-precision
+            for name in batch:
+                value = batch[name]
+                if not isinstance(value, np.ndarray):
+                    raise RuntimeError(
+                        f"All values must be converted to np.ndarray object "
+                        f'by preprocessing, but "{name}" is still {type(value)}.'
+                    )
+                # Cast to desired type
+                if value.dtype.kind == "f":
+                    value = value.astype("float32")
+                elif value.dtype.kind == "i":
+                    value = value.astype("long")
+                else:
+                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+                batch[name] = value
+
+            batch["text_lengths"] = torch.from_numpy(
+                np.array([len(batch["text"])], dtype='int32'))
+            batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                ## compute ppl
+                ppl_out_batch = ""
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for sent_ids, sent_nll in zip(batch['text'], nll):
+                    pre_word = "<s>"
+                    cur_word = None
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        cur_word = word
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=cur_word,
+                            pre=pre_word,
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                        )
+                        pre_word = cur_word
+
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                    )
+                    ppl_out_batch += ppl_out
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key + ":\n"] = ppl_out
+                    results.append(item)
+
+            return results
+
+        # 3. Build data-iterator
+        loader = build_streaming_iterator(
+            task_name="lm",
+            preprocess_args=train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            preprocess_fn=preprocessor,
+            num_workers=num_workers,
+        )
+
+        # 4. Start for-loop
+        total_nll = 0.0
+        total_ntokens = 0
+        ppl_out_all = ""
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+            ppl_out_batch = ""
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    # NOTE(kamo): data_parallel also should work with ngpu=1,
+                    # but for debuggability it's better to keep this block.
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                ## print ppl
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+                    pre_word = "<s>"
+                    cur_word = None
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        cur_word = word
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=cur_word,
+                            pre=pre_word,
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                        )
+                        pre_word = cur_word
+
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                    )
+                    ppl_out_batch += ppl_out
+                    utt2nll = round(-sent_nll_sum.item(), 5)
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key + ":\n"] = ppl_out
+                        writer["utt2nll"][key] = str(utt2nll)
+                    results.append(item)
+
+            ppl_out_all += ppl_out_batch
+
+            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+            # nll: (B, L) -> (B,)
+            nll = nll.detach().cpu().numpy().sum(1)
+            # lengths: (B,)
+            lengths = lengths.detach().cpu().numpy()
+            total_nll += nll.sum()
+            total_ntokens += lengths.sum()
+
+        if log_base is None:
+            ppl = np.exp(total_nll / total_ntokens)
+        else:
+            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+            total_nll=round(-total_nll.item(), 4),
+            total_ppl=round(ppl.item(), 4)
+        )
+        item = {'key': 'AVG PPL', 'value': avg_ppl}
+        ppl_out_all += avg_ppl
+        if writer is not None:
+            writer["ppl"]["AVG PPL : "] = avg_ppl
+        results.append(item)
+
+        return results
+
+    return _forward
+
+
+def inference_launch(mode, **kwargs):
+    if mode == "transformer":
+        return inference_lm(**kwargs)
+    else:
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
 
 
 def get_parser():
@@ -90,14 +360,6 @@
     group.add_argument("--mode", type=str, default="lm")
     return parser
 
-def inference_launch(mode, **kwargs):
-    if mode == "transformer":
-        from funasr.bin.lm_inference import inference_modelscope
-        return inference_modelscope(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
 
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
@@ -122,9 +384,9 @@
 
     kwargs.pop("gpuid_list", None)
     kwargs.pop("njob", None)
-    results = inference_launch(**kwargs)
+    inference_pipeline = inference_launch(**kwargs)
+    return inference_pipeline(kwargs["data_path_and_name_and_type"])
 
 
 if __name__ == "__main__":
     main()
-

--
Gitblit v1.9.1