游雁
2023-11-23 dc682db808eb5f425f0dbed4c5e7feb0a334955f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
 
import logging
from pathlib import Path
from typing import Union
 
import numpy as np
import torch
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
 
 
class Speech2Timestamp:
    """Predict token-level timestamps for speech with a trained predictor model.

    Loads a timestamp ("tp") model from config/checkpoint files, optionally
    builds a WavFrontend for feature extraction, and exposes ``__call__`` to
    compute the predictor's alphas/peaks for a batch of utterances.
    """

    def __init__(
            self,
            timestamp_infer_config: Union[Path, str] = None,
            timestamp_model_file: Union[Path, str] = None,
            timestamp_cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
            dtype: str = "float32",
            **kwargs,
    ):
        """Build the timestamp model, optional frontend, and token converter.

        Args:
            timestamp_infer_config: Path to the inference config file.
            timestamp_model_file: Path to the model checkpoint.
            timestamp_cmvn_file: Path to CMVN stats used by the WavFrontend.
            device: Target device, e.g. "cpu", "cuda", or "cuda:1".
            dtype: Torch dtype name used for inference, e.g. "float32".
        """
        # 1. Build the timestamp-prediction (tp) model
        tp_model, tp_train_args = build_model_from_file(
            timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
        )
        if 'cuda' in device:
            # Bug fix: the original used tp_model.cuda(), which always moves
            # the model to the current default GPU and silently ignores a
            # device index such as "cuda:1". to(device) honors the request.
            tp_model = tp_model.to(device)

        # Optional feature frontend configured from the training args.
        frontend = None
        if tp_train_args.frontend is not None:
            frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)

        logging.info("tp_model: {}".format(tp_model))
        logging.info("tp_train_args: {}".format(tp_train_args))
        tp_model.to(dtype=getattr(torch, dtype)).eval()

        logging.info(f"Decoding device={device}, dtype={dtype}")

        self.tp_model = tp_model
        self.tp_train_args = tp_train_args

        token_list = self.tp_model.token_list
        self.converter = TokenIDConverter(token_list=token_list)

        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        # conv2d input layers downsample time by 4x; keep the factor so
        # callers can map encoder frames back to feature frames.
        self.encoder_downsampling_factor = 1
        if tp_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4

    @torch.no_grad()
    def __call__(
            self,
            speech: Union[torch.Tensor, np.ndarray],
            speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            text_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Compute predictor timestamps (alphas and peaks) for a batch.

        Args:
            speech: Raw waveform (when a frontend is configured) or
                pre-computed features.
            speech_lengths: Valid length of each utterance.
            text_lengths: Token count of each utterance.

        Returns:
            Tuple ``(us_alphas, us_peaks)`` from the model's predictor.
        """
        # Bug fix: the signature accepts numpy arrays for all three inputs,
        # but the original converted only `speech`; a numpy `text_lengths`
        # crashed on `.to(self.device)` below. Convert every numpy input.
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(speech_lengths, np.ndarray):
            speech_lengths = torch.tensor(speech_lengths)
        if isinstance(text_lengths, np.ndarray):
            text_lengths = torch.tensor(text_lengths)

        if self.frontend is not None:
            # Features are extracted here, so disable the model's own
            # frontend to avoid re-extracting them inside the model.
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.tp_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths

        # lfr_factor = max(1, (feats.size()[-1]//80)-1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.tp_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]

        # c. Forward Predictor. The +1 presumably accounts for an appended
        # terminal token (e.g. <eos>) — TODO confirm against the predictor API.
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(
            enc, enc_len, text_lengths.to(self.device) + 1
        )
        return us_alphas, us_peaks