# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
 
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, Optional, Tuple
 
import torch
from typeguard import check_argument_types
 
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
 
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no native AMP: fall back to a no-op context manager
    @contextmanager
    def autocast(enabled=True):
        yield
 
 
class Data2VecPretrainModel(AbsESPnetModel):
    """Data2Vec Pretrain model"""
 
    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            preencoder: Optional[AbsPreEncoder],
            encoder: AbsEncoder,
    ):
        assert check_argument_types()
 
        super().__init__()
 
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
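        # Global training step; pushed to the encoder on every forward() call,
        # e.g. so it can schedule the EMA teacher decay ("ema_decay" in stats)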
        self.num_updates = 0
 
    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss
 
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
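
        Returns:
            loss: summed encoder losses divided by sample_size
            stats: logging stats (loss, target_var, pred_var, ema_decay)
            weight: sample_size, used to weight this batch when reducing stats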
        """
        # Check that batch_size is unified
        assert speech.shape[0] == speech_lengths.shape[0], (
            speech.shape,
            speech_lengths.shape,
        )
 
        self.encoder.set_num_updates(self.num_updates)
 
        # 1. Frontend + Encoder
        encoder_out = self.encode(speech, speech_lengths)
 
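        # The encoder returns a dict of loss terms; sum them and normalize by
        # sample_size (typically the number of masked frames)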
        losses = encoder_out["losses"]
        loss = sum(losses.values())
        sample_size = encoder_out["sample_size"]
        loss = loss.sum() / sample_size
 
        target_var = float(encoder_out["target_var"])
        pred_var = float(encoder_out["pred_var"])
        ema_decay = float(encoder_out["ema_decay"])
 
        stats = dict(
            loss=torch.clone(loss.detach()),
            target_var=target_var,
            pred_var=pred_var,
            ema_decay=ema_decay,
        )
 
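        # Make loss/stats/weight gatherable across devices (e.g. DataParallel)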
        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
        return loss, stats, weight
 
    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
 
    def encode(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.
 
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
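
        Returns:
            encoder_out: dict from the data2vec encoder; the caller expects the
                keys "losses", "sample_size", "target_var", "pred_var" and
                "ema_decay"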
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
 
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
 
            # 3. Feature normalization, e.g. global CMVN or utterance CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
 
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
 
        # 4. Forward encoder
        if min(speech_lengths) == max(speech_lengths):
            # every utterance has the same length (e.g. fixed-size crops),
            # so the lengths carry no information and can be dropped
            speech_lengths = None
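        # mask=True applies the data2vec masking; features_only=False makes the
        # encoder also compute and return the pretraining losses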
        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
 
        return encoder_out
 
    def _extract_feats(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape
 
        # for data-parallel: trim padding beyond the longest utterance actually
        # present in this (sub-)batch
        speech = speech[:, : speech_lengths.max()]
 
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
 
    def set_num_updates(self, num_updates):
        """Record the global training step (forwarded to the encoder in forward())."""
        self.num_updates = num_updates

    def get_num_updates(self):
        """Return the last recorded global training step."""
        return self.num_updates
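

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the training recipes).
# ``encoder`` must be a concrete AbsEncoder implementing the data2vec interface
# (set_num_updates() plus a forward() returning the dict consumed above);
# ``build_data2vec_encoder`` below is a hypothetical placeholder for whatever
# the FunASR config machinery actually instantiates.
#
#   encoder = build_data2vec_encoder(...)          # hypothetical factory
#   model = Data2VecPretrainModel(
#       frontend=None,       # raw waveform goes straight to the encoder
#       specaug=None,
#       normalize=None,
#       preencoder=None,
#       encoder=encoder,
#   )
#   speech = torch.randn(4, 16000)                 # (Batch, NSamples)
#   speech_lengths = torch.full((4,), 16000, dtype=torch.long)
#   model.set_num_updates(global_step)
#   loss, stats, weight = model(speech, speech_lengths)
#   loss.backward()
# ---------------------------------------------------------------------------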