From 0c4fbea66b7c4eddeec5734d4ff43ad85e32d5fa Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 15 六月 2023 15:39:22 +0800
Subject: [PATCH] update repo

---
 funasr/bin/lm_inference_launch.py              |  127 ++++++++++++++++++------------------------
 funasr/build_utils/build_streaming_iterator.py |    1 
 2 files changed, 56 insertions(+), 72 deletions(-)

diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index 1d99fce..c8482b8 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
@@ -7,40 +7,25 @@
 import logging
 import os
 import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
 from typing import Any
 from typing import List
+from typing import Optional
+from typing import Union
 
 import numpy as np
 import torch
 from torch.nn.parallel import data_parallel
 from typeguard import check_argument_types
 
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.forward_adaptor import ForwardAdaptor
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
 from funasr.utils.types import float_or_none
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
@@ -48,42 +33,42 @@
 
 
 def inference_lm(
-    batch_size: int,
-    dtype: str,
-    ngpu: int,
-    seed: int,
-    num_workers: int,
-    log_level: Union[int, str],
-    key_file: Optional[str],
-    train_config: Optional[str],
-    model_file: Optional[str],
-    log_base: Optional[float] = 10,
-    allow_variable_data_keys: bool = False,
-    split_with_space: Optional[bool] = False,
-    seg_dict_file: Optional[str] = None,
-    output_dir: Optional[str] = None,
-    param_dict: dict = None,
-    **kwargs,
+        batch_size: int,
+        dtype: str,
+        ngpu: int,
+        seed: int,
+        num_workers: int,
+        log_level: Union[int, str],
+        key_file: Optional[str],
+        train_config: Optional[str],
+        model_file: Optional[str],
+        log_base: Optional[float] = 10,
+        allow_variable_data_keys: bool = False,
+        split_with_space: Optional[bool] = False,
+        seg_dict_file: Optional[str] = None,
+        output_dir: Optional[str] = None,
+        param_dict: dict = None,
+        **kwargs,
 ):
     assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
-    
+
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
     else:
         device = "cpu"
-    
+
     # 1. Set random-seed
     set_all_random_seed(seed)
-    
+
     # 2. Build Model
-    model, train_args = LMTask.build_model_from_file(
-        train_config, model_file, device)
+    model, train_args = build_model_from_file(
+        train_config, model_file, None, device, "lm")
     wrapped_model = ForwardAdaptor(model, "nll")
     wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
     logging.info(f"Model:\n{model}")
-    
+
     preprocessor = LMPreprocessor(
         train=False,
         token_type=train_args.token_type,
@@ -96,12 +81,12 @@
         split_with_space=split_with_space,
         seg_dict_file=seg_dict_file
     )
-    
+
     def _forward(
-        data_path_and_name_and_type,
-        raw_inputs: Union[List[Any], bytes, str] = None,
-        output_dir_v2: Optional[str] = None,
-        param_dict: dict = None,
+            data_path_and_name_and_type,
+            raw_inputs: Union[List[Any], bytes, str] = None,
+            output_dir_v2: Optional[str] = None,
+            param_dict: dict = None,
     ):
         results = []
         output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
@@ -109,7 +94,7 @@
             writer = DatadirWriter(output_path)
         else:
             writer = None
-        
+
         if raw_inputs != None:
             line = raw_inputs.strip()
             key = "lm demo"
@@ -121,7 +106,7 @@
             batch['text'] = line
             if preprocessor != None:
                 batch = preprocessor(key, batch)
-            
+
             #  Force data-precision
             for name in batch:
                 value = batch[name]
@@ -138,11 +123,11 @@
                 else:
                     raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                 batch[name] = value
-            
+
             batch["text_lengths"] = torch.from_numpy(
                 np.array([len(batch["text"])], dtype='int32'))
             batch["text"] = np.expand_dims(batch["text"], axis=0)
-            
+
             with torch.no_grad():
                 batch = to_device(batch, device)
                 if ngpu <= 1:
@@ -173,7 +158,7 @@
                             word_nll=round(word_nll.item(), 8)
                         )
                         pre_word = cur_word
-                    
+
                     sent_nll_mean = sent_nll.mean().cpu().numpy()
                     sent_nll_sum = sent_nll.sum().cpu().numpy()
                     if log_base is None:
@@ -189,22 +174,20 @@
                     if writer is not None:
                         writer["ppl"][key + ":\n"] = ppl_out
                     results.append(item)
-            
+
             return results
-        
+
         # 3. Build data-iterator
-        loader = LMTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="lm",
+            preprocess_args=train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=preprocessor,
-            collate_fn=LMTask.build_collate_fn(train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
-        
+
         # 4. Start for-loop
         total_nll = 0.0
         total_ntokens = 0
@@ -214,7 +197,7 @@
             assert all(isinstance(s, str) for s in keys), keys
             _bs = len(next(iter(batch.values())))
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            
+
             ppl_out_batch = ""
             with torch.no_grad():
                 batch = to_device(batch, device)
@@ -247,7 +230,7 @@
                             word_nll=round(word_nll.item(), 8)
                         )
                         pre_word = cur_word
-                    
+
                     sent_nll_mean = sent_nll.mean().cpu().numpy()
                     sent_nll_sum = sent_nll.sum().cpu().numpy()
                     if log_base is None:
@@ -265,9 +248,9 @@
                         writer["ppl"][key + ":\n"] = ppl_out
                         writer["utt2nll"][key] = str(utt2nll)
                     results.append(item)
-            
+
             ppl_out_all += ppl_out_batch
-            
+
             assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
             # nll: (B, L) -> (B,)
             nll = nll.detach().cpu().numpy().sum(1)
@@ -275,12 +258,12 @@
             lengths = lengths.detach().cpu().numpy()
             total_nll += nll.sum()
             total_ntokens += lengths.sum()
-        
+
         if log_base is None:
             ppl = np.exp(total_nll / total_ntokens)
         else:
             ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-        
+
         avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
             total_nll=round(-total_nll.item(), 4),
             total_ppl=round(ppl.item(), 4)
@@ -290,9 +273,9 @@
         if writer is not None:
             writer["ppl"]["AVG PPL : "] = avg_ppl
         results.append(item)
-        
+
         return results
-    
+
     return _forward
 
 
@@ -302,7 +285,8 @@
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
-    
+
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="Calc perplexity",
@@ -407,4 +391,3 @@
 
 if __name__ == "__main__":
     main()
-
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
index da42929..ad36b4e 100644
--- a/funasr/build_utils/build_streaming_iterator.py
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -26,6 +26,7 @@
 
     # preprocess
     if preprocess_args is not None:
+        preprocess_args.task_name = task_name
         preprocess_fn = build_preprocess(preprocess_args, train)
     else:
         preprocess_fn = None

--
Gitblit v1.9.1