aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Stateless decoder definition for Transducer models."""
 
from typing import List, Optional, Tuple
 
import torch
from typeguard import check_argument_types
 
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
 
class StatelessDecoder(AbsDecoder):
    """Stateless Transducer decoder module.
 
    Args:
        vocab_size: Output size.
        embed_size: Embedding size.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embed/Blank symbol ID.
 
    """
 
    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct a StatelessDecoder object."""
        super().__init__()
 
        assert check_argument_types()
 
        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate)
 
        self.output_size = embed_size
        self.vocab_size = vocab_size
 
        self.device = next(self.parameters()).device
        self.score_cache = {}
 
 
 
    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.
 
        Args:
            labels: Label ID sequences. (B, L)
            states: Decoder hidden states. None
 
        Returns:
            dec_embed: Decoder output sequences. (B, U, D_emb)
 
        """
        dec_embed = self.embed_dropout_rate(self.embed(labels))
        return dec_embed
 
    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        state: None,
    ) -> Tuple[torch.Tensor, None]:
        """One-step forward hypothesis.
 
        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            state: Previous decoder hidden states. None
 
        Returns:
            dec_out: Decoder output sequence. (1, D_emb)
            state: Decoder hidden states. None
 
        """
        str_labels = "_".join(map(str, label_sequence))
 
        if str_labels in self.score_cache:
            dec_embed = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
 
            self.score_cache[str_labels] = dec_embed
 
        return dec_embed[0], None
 
    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, None]:
        """One-step forward hypotheses.
 
        Args:
            hyps: Hypotheses.
 
        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. None
 
        """
        labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
        dec_embed = self.embed(labels)
 
        return dec_embed.squeeze(1), None
 
    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.
 
        Args:
            device: Device ID.
 
        """
        self.device = device
 
    def init_state(self, batch_size: int) -> None:
        """Initialize decoder states.
 
        Args:
            batch_size: Batch size.
 
        Returns:
            : Initial decoder hidden states. None
 
        """
        return None
 
    def select_state(self, states: Optional[torch.Tensor], idx: int) -> None:
        """Get specified ID state from decoder hidden states.
 
        Args:
            states: Decoder hidden states. None
            idx: State ID to extract.
 
        Returns:
            : Decoder hidden state for given ID. None
 
        """
        return None