From 0c4fbea66b7c4eddeec5734d4ff43ad85e32d5fa Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 15 六月 2023 15:39:22 +0800
Subject: [PATCH] update repo
---
funasr/bin/lm_inference_launch.py | 127 ++++++++++++++++++------------------------
funasr/build_utils/build_streaming_iterator.py | 1
2 files changed, 56 insertions(+), 72 deletions(-)
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index 1d99fce..c8482b8 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,40 +7,25 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
from typing import Any
from typing import List
+from typing import Optional
+from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
@@ -48,42 +33,42 @@
def inference_lm(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float] = 10,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build Model
- model, train_args = LMTask.build_model_from_file(
- train_config, model_file, device)
+ model, train_args = build_model_from_file(
+ train_config, model_file, None, device, "lm")
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
logging.info(f"Model:\n{model}")
-
+
preprocessor = LMPreprocessor(
train=False,
token_type=train_args.token_type,
@@ -96,12 +81,12 @@
split_with_space=split_with_space,
seg_dict_file=seg_dict_file
)
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
results = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
@@ -109,7 +94,7 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
if raw_inputs != None:
line = raw_inputs.strip()
key = "lm demo"
@@ -121,7 +106,7 @@
batch['text'] = line
if preprocessor != None:
batch = preprocessor(key, batch)
-
+
# Force data-precision
for name in batch:
value = batch[name]
@@ -138,11 +123,11 @@
else:
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
batch[name] = value
-
+
batch["text_lengths"] = torch.from_numpy(
np.array([len(batch["text"])], dtype='int32'))
batch["text"] = np.expand_dims(batch["text"], axis=0)
-
+
with torch.no_grad():
batch = to_device(batch, device)
if ngpu <= 1:
@@ -173,7 +158,7 @@
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
-
+
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
@@ -189,22 +174,20 @@
if writer is not None:
writer["ppl"][key + ":\n"] = ppl_out
results.append(item)
-
+
return results
-
+
# 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="lm",
+ preprocess_args=train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=preprocessor,
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
# 4. Start for-loop
total_nll = 0.0
total_ntokens = 0
@@ -214,7 +197,7 @@
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+
ppl_out_batch = ""
with torch.no_grad():
batch = to_device(batch, device)
@@ -247,7 +230,7 @@
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
-
+
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
@@ -265,9 +248,9 @@
writer["ppl"][key + ":\n"] = ppl_out
writer["utt2nll"][key] = str(utt2nll)
results.append(item)
-
+
ppl_out_all += ppl_out_batch
-
+
assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
# nll: (B, L) -> (B,)
nll = nll.detach().cpu().numpy().sum(1)
@@ -275,12 +258,12 @@
lengths = lengths.detach().cpu().numpy()
total_nll += nll.sum()
total_ntokens += lengths.sum()
-
+
if log_base is None:
ppl = np.exp(total_nll / total_ntokens)
else:
ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
+
avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
total_nll=round(-total_nll.item(), 4),
total_ppl=round(ppl.item(), 4)
@@ -290,9 +273,9 @@
if writer is not None:
writer["ppl"]["AVG PPL : "] = avg_ppl
results.append(item)
-
+
return results
-
+
return _forward
@@ -302,7 +285,8 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
-
+
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
@@ -407,4 +391,3 @@
if __name__ == "__main__":
main()
-
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
index da42929..ad36b4e 100644
--- a/funasr/build_utils/build_streaming_iterator.py
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -26,6 +26,7 @@
# preprocess
if preprocess_args is not None:
+ preprocess_args.task_name = task_name
preprocess_fn = build_preprocess(preprocess_args, train)
else:
preprocess_fn = None
--
Gitblit v1.9.1