From cc2c1d1d53dea5d2c45f858d1baa5bd279f47987 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Wed, 31 May 2023 14:39:25 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/bin/tp_infer.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 120 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
new file mode 100644
index 0000000..4ddcba4
--- /dev/null
+++ b/funasr/bin/tp_infer.py
@@ -0,0 +1,120 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+from optparse import Option
+import sys
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.tasks.asr import ASRTaskAligner as ASRTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+
+
+
+
+class Speech2Timestamp:
+    """Inference wrapper: calling an instance returns ``(us_alphas, us_peaks)``."""
+
+    def __init__(
+        self,
+        timestamp_infer_config: Optional[Union[Path, str]] = None,
+        timestamp_model_file: Optional[Union[Path, str]] = None,
+        timestamp_cmvn_file: Optional[Union[Path, str]] = None,
+        device: str = "cpu",
+        dtype: str = "float32",
+        **kwargs,
+    ):
+        """Build the model (and the frontend, if configured) for inference.
+
+        Args:
+            timestamp_infer_config: Config passed to ``build_model_from_file``.
+            timestamp_model_file: Checkpoint passed to ``build_model_from_file``.
+            timestamp_cmvn_file: CMVN file passed to ``WavFrontend``.
+            device: Torch device string, e.g. ``"cpu"`` or ``"cuda:1"``.
+            dtype: Name of a ``torch`` dtype attribute, e.g. ``"float32"``.
+            **kwargs: Accepted but unused.
+        """
+        assert check_argument_types()
+        tp_model, tp_train_args = ASRTask.build_model_from_file(
+            timestamp_infer_config, timestamp_model_file, device=device
+        )
+        if "cuda" in device:
+            # .to(device) honours an index such as "cuda:1"; .cuda() ignored it.
+            tp_model = tp_model.to(device)
+        frontend = None
+        if tp_train_args.frontend is not None:
+            frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
+        logging.info("tp_model: {}".format(tp_model))
+        logging.info("tp_train_args: {}".format(tp_train_args))
+        tp_model.to(dtype=getattr(torch, dtype)).eval()
+        logging.info(f"Decoding device={device}, dtype={dtype}")
+
+        self.tp_model = tp_model
+        self.tp_train_args = tp_train_args
+        self.converter = TokenIDConverter(token_list=self.tp_model.token_list)
+        self.device = device
+        self.dtype = dtype
+        self.frontend = frontend
+        # Recorded as 4 for a "conv2d" encoder input layer, 1 otherwise.
+        is_conv2d = tp_train_args.encoder_conf["input_layer"] == "conv2d"
+        self.encoder_downsampling_factor = 4 if is_conv2d else 1
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        speech: Union[torch.Tensor, np.ndarray],
+        speech_lengths: Optional[Union[torch.Tensor, np.ndarray]] = None,
+        text_lengths: Optional[Union[torch.Tensor, np.ndarray]] = None,
+    ):
+        """Return ``(us_alphas, us_peaks)``; raise ValueError if ``text_lengths`` is None."""
+        assert check_argument_types()
+        if text_lengths is None:
+            raise ValueError("text_lengths is required for timestamp prediction")
+        if isinstance(text_lengths, np.ndarray):
+            text_lengths = torch.tensor(text_lengths)
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            # NOTE(review): presumably stops encode() re-running a frontend on feats.
+            self.tp_model.frontend = None
+        else:
+            feats, feats_len = speech, speech_lengths
+        batch = to_device({"speech": feats, "speech_lengths": feats_len}, device=self.device)
+        enc, enc_len = self.tp_model.encode(**batch)
+        if isinstance(enc, tuple):
+            enc = enc[0]
+        lengths = text_lengths.to(self.device) + 1
+        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, lengths)
+        return us_alphas, us_peaks
+
+
+
--
Gitblit v1.9.1