import time
import types
import logging
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.cuda.amp import autocast

from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.ctc.ctc import CTC
from funasr.models.paraformer.search import Hypothesis
from funasr.models.sense_voice.utils.ctc_alignment import ctc_forced_align
from funasr.register import tables

from . import whisper_lib as whisper


class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal position encoding, computed on the fly from the input length."""

    def __init__(self, d_model=80, dropout_rate=0.1):
        super().__init__()

    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
            depth / 2 - 1
        )
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

        return x + position_encoding
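
# Usage sketch (illustrative only; the shapes are assumptions, not part of the API):
#
#     pe = SinusoidalPositionEncoder()
#     x = torch.randn(2, 100, 512)   # (batch, time, feature_dim)
#     y = pe(x)                      # sinusoidal encoding added element-wise
#     assert y.shape == x.shape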


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed-forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
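
# Shape sketch (illustrative): the FFN is applied position-wise, so only the
# last dimension changes inside the block (idim -> hidden_units -> idim):
#
#     ffn = PositionwiseFeedForward(idim=512, hidden_units=2048, dropout_rate=0.1)
#     out = ffn(torch.randn(2, 100, 512))   # -> (2, 100, 512)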


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): Input feature dimension.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shift (int): Shift of the FSMN memory block.

    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shift=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head

        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Asymmetric padding so the depthwise conv preserves the time length.
        left_padding = (kernel_size - 1) // 2
        if sanm_shift > 0:
            left_padding = left_padding + sanm_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shift_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shift_chunk is not None:
                mask = mask * mask_shift_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x
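
    # Padding arithmetic sketch (illustrative): with kernel_size=11 and
    # sanm_shift=0, left_padding = 5 and right_padding = 5, so the depthwise
    # Conv1d sees t + 10 frames and returns exactly t frames -- the residual
    # `x += inputs` therefore always matches in shape. A positive sanm_shift
    # moves the memory window further into the past.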

    def forward_qkv(self, x):
        """Project the input into query, key and value tensors.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Flat value tensor (#batch, time, n_feat) for the FSMN block.

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = -float("inf")
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shift_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory output.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shift_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Streaming variant of `forward` with a key/value cache.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            cache (dict): Cached "k"/"v" tensors from previous chunks.
            chunk_size (list): Chunk layout; chunk_size[1] is the chunk stride
                and chunk_size[2] the look-ahead size.
            look_back (int): Number of chunks kept in the cache (-1 keeps all).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
            dict: Updated cache.

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        if chunk_size is not None and (look_back > 0 or look_back == -1):
            if cache is not None:
                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)

                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
            else:
                cache = {
                    "k": k_h[:, :, : -(chunk_size[2]), :],
                    "v": v_h[:, :, : -(chunk_size[2]), :],
                }
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache


class LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
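
# Example (illustrative): sequence_mask(torch.tensor([2, 4]), maxlen=4) gives
#
#     tensor([[1., 1., 0., 0.],
#             [1., 1., 1., 1.]])
#
# i.e. row i is 1 for positions < lengths[i]; the encoder below turns this
# into a (batch, 1, time) padding mask via the [:, None, :] expansion.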


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shift_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            The remaining inputs are passed through unchanged so layers can be
            chained uniformly.

        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shift_chunk=mask_shift_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shift_chunk=mask_shift_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shift_chunk=mask_shift_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shift_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            cache (dict): Attention key/value cache from previous chunks.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            dict: Updated cache.

        """

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache
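
# Chaining sketch (illustrative; `attn` and `ffn` stand for the modules built
# by the encoder below, not real variables in this file):
#
#     layer = EncoderLayerSANM(512, 512, attn, ffn, dropout_rate=0.1)
#     x, mask = layer(x, mask)[:2]   # first two of the five returned values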


@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
class SenseVoiceEncoderSmall(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shift: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        # The first block projects input_size -> output_size; all later blocks
        # keep the model dimension.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shift,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shift,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(1)
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(num_blocks - 1)
            ]
        )

        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)

        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Embed positions in tensor."""
        maxlen = xs_pad.shape[1]
        masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]

        xs_pad *= self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        # forward encoder2 (timestamp branch)
        olens = masks.squeeze(1).sum(1).int()

        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens
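
# Usage sketch (illustrative; the 560-dim input assumes 80-dim fbank features
# with LFR stacking of 7 frames, and the block counts are assumptions about
# the released SenseVoiceSmall config):
#
#     enc = SenseVoiceEncoderSmall(input_size=560, output_size=512,
#                                  num_blocks=50, tp_blocks=20)
#     feats = torch.randn(2, 100, 560)
#     lens = torch.tensor([100, 80])
#     out, out_lens = enc(feats, lens)   # out: (2, 100, 512), out_lens: (2,)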


@tables.register("model_classes", "SenseVoiceSmall")
class SenseVoiceSmall(nn.Module):
    """Non-autoregressive CTC model with rich-token (language / emotion /
    event / text-norm) prediction."""

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        **kwargs,
    ):

        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.error_calculator = None

        self.ctc = ctc

        self.length_normalized_loss = length_normalized_loss
        self.encoder_output_size = encoder_output_size

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.embed = torch.nn.Embedding(
            7 + len(self.lid_dict) + len(self.textnorm_dict), input_size
        )
        self.emo_dict = {
            "unk": 25009,
            "happy": 25001,
            "sad": 25002,
            "angry": 25003,
            "neutral": 25004,
        }

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

    @staticmethod
    def from_pretrained(model: str = None, **kwargs):
        from funasr import AutoModel

        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)

        return model, kwargs
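
    # Loading sketch (illustrative; "iic/SenseVoiceSmall" is assumed to be the
    # published model id -- substitute your own path if it differs):
    #
    #     m, m_kwargs = SenseVoiceSmall.from_pretrained("iic/SenseVoiceSmall",
    #                                                   device="cuda:0")
    #     m.eval()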

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + CTC / rich-token losses.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, text, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        loss_ctc, cer_ctc = None, None
        loss_rich, acc_rich = None, None
        stats = dict()

        # 2. CTC loss on the speech frames (the first 4 frames are query tokens)
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
        )

        # 3. Rich-token CE loss on the first 4 frames
        loss_rich, acc_rich = self._calc_rich_ce_loss(encoder_out[:, :4, :], text[:, :4])

        loss = loss_ctc + loss_rich
        # Collect total loss stats
        stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
        stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
        stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
        stats["acc_rich"] = acc_rich
        stats["batch_size"] = batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor = None,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Prepend language / event-emotion / textnorm query embeddings. During
        # training the language id is randomly dropped to <auto> (id 0) 20% of
        # the time.
        lids = torch.LongTensor(
            [
                [
                    (
                        self.lid_int_dict[int(lid)]
                        if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
                        else 0
                    )
                ]
                for lid in text[:, 0]
            ]
        ).to(speech.device)
        language_query = self.embed(lids)

        styles = torch.LongTensor(
            [[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]
        ).to(speech.device)
        style_query = self.embed(styles)
        speech = torch.cat((style_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)

        return encoder_out, encoder_out_lens
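
    # Frame layout after encode(), as consumed by forward() and inference():
    #
    #     encoder_out[:, 0, :]   -> language token
    #     encoder_out[:, 1:3, :] -> event / emotion tokens
    #     encoder_out[:, 3, :]   -> text-normalization token
    #     encoder_out[:, 4:, :]  -> speech frames scored by CTC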

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rich_ce_loss(
        self,
        encoder_out: torch.Tensor,
        ys_pad: torch.Tensor,
    ):
        # Project the 4 query frames with the CTC output layer and score them
        # against the rich-token targets with label-smoothed CE.
        decoder_out = self.ctc.ctc_lo(encoder_out)
        loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
        acc_rich = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad.contiguous(),
            ignore_label=self.ignore_id,
        )

        return loss_rich, acc_rich

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = ["wav_file_tmp_name"],
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        language = kwargs.get("language", "auto")
        language_query = self.embed(
            torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(
                speech.device
            )
        ).repeat(speech.size(0), 1, 1)

        use_itn = kwargs.get("use_itn", False)
        textnorm = kwargs.get("text_norm", None)
        output_timestamp = kwargs.get("output_timestamp", False)
        if textnorm is None:
            textnorm = "withitn" if use_itn else "woitn"
        textnorm_query = self.embed(
            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)
        speech = torch.cat((textnorm_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        # Encoder
        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Greedy CTC decoding; optionally ban the <emo_unk> token first
        ctc_logits = self.ctc.log_softmax(encoder_out)
        if kwargs.get("ban_emo_unk", False):
            ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")

        results = []
        b, n, d = encoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b
        for i in range(b):
            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"1best_recog"]

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            # Change integer-ids to tokens
            text = tokenizer.decode(token_int)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

            if output_timestamp:
                from itertools import groupby

                timestamp = []
                tokens = tokenizer.text2tokens(text)[4:]
                logits_speech = self.ctc.softmax(encoder_out)[i, 4 : encoder_out_lens[i].item(), :]
                pred = logits_speech.argmax(-1).cpu()
                logits_speech[pred == self.blank_id, self.blank_id] = 0
                align = ctc_forced_align(
                    logits_speech.unsqueeze(0).float(),
                    torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
                    (encoder_out_lens - 4).long(),
                    torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
                    ignore_id=self.ignore_id,
                )
                pred = groupby(align[0, : encoder_out_lens[0]])
                _start = 0
                token_id = 0
                ts_max = encoder_out_lens[i] - 4
                for pred_token, pred_frame in pred:
                    _end = _start + len(list(pred_frame))
                    if pred_token != 0:
                        # 60 ms per encoder frame, centered with a 30 ms offset
                        ts_left = max((_start * 60 - 30) / 1000, 0)
                        ts_right = min((_end * 60 - 30) / 1000, (ts_max * 60 - 30) / 1000)
                        timestamp.append([tokens[token_id], ts_left, ts_right])
                        token_id += 1
                    _start = _end
                timestamp = self.post(timestamp)
                result_i = {"key": key[i], "text": text, "timestamp": timestamp}
            else:
                result_i = {"key": key[i], "text": text}
            results.append(result_i)

        return results, meta_data

    def post(self, timestamp):
        """Merge sub-word pieces and convert second-level stamps to milliseconds."""
        timestamp_new = []
        for i, t in enumerate(timestamp):
            word, start, end = t
            if word == "▁":
                continue
            if i == 0:
                timestamp_new.append([int(start * 1000), int(end * 1000)])
            elif word.startswith("▁") or len(word) == 1 or not word[1].isalpha():
                word = word[1:]
                timestamp_new.append([int(start * 1000), int(end * 1000)])
            else:
                # continuation piece: extend the previous word's end time
                timestamp_new[-1][1] = int(end * 1000)
        return timestamp_new

    def export(self, **kwargs):
        from .export_meta import export_rebuild_model

        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
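
# Inference sketch for SenseVoiceSmall above (illustrative; assumes a
# tokenizer/frontend pair built elsewhere, e.g. by AutoModel -- in practice
# you would usually call AutoModel.generate instead):
#
#     res, meta = m.inference(
#         "audio.wav", key=["utt1"], tokenizer=tok, frontend=fe,
#         device="cuda:0", language="auto", use_itn=True, ban_emo_unk=False,
#     )
#     print(res[0]["text"])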


@tables.register("model_classes", "SenseVoiceRWKV")
class SenseVoiceRWKV(nn.Module):
    """Whisper-style encoder-decoder model with an RWKV decoder."""

    def __init__(self, *args, **kwargs):
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer,
            **kwargs.get("decoder_conf", {}),
        )
        model.decoder = decoder

        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        target_mask = kwargs.get("target_mask", None)

        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape

        # 1. Encoder (optionally with activation checkpointing)
        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2. Attention decoder loss
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att
        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        stats["batch_size_x_frames"] = frames * batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Forward encoder
            encoder_out, encoder_out_lens = self.model.encoder(
                speech.permute(0, 2, 1), speech_lengths
            )

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        target_mask = kwargs.get("target_mask", None)

        # 1. Forward decoder
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss: positions excluded by target_mask are set
        # to -1 so the label-smoothing criterion ignores them.
        mask = torch.ones_like(ys_pad) * (-1)
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        ys_pad_mask[ys_pad_mask == 0] = -1
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None
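
    # Target-mask sketch (illustrative): with ys_pad = [[5, 7, 9]] and
    # target_mask = [[1, 0, 1]], the masked targets become [[5, -1, 9]] -- the
    # criterion's padding_idx (-1) drops the middle position from the loss,
    # and the usual shift pairs decoder_out[:, :-1] with targets[:, 1:].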

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        DecodingOptions = kwargs.get("DecodingOptions", {})
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True

        options = whisper.DecodingOptions(**DecodingOptions)

        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"

        results = []
        result_i = {"key": key[0], "text": text}
        results.append(result_i)

        return results, meta_data
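
    # Decoding sketch (illustrative; the option names follow the whisper_lib
    # DecodingOptions used above, everything else -- ids, paths, beam size --
    # is an assumption):
    #
    #     res, meta = rwkv_model.inference(
    #         "audio.wav", key=["utt1"], tokenizer=tok,
    #         device="cuda:0", tokenizer_conf={"vocab_path": None},
    #         DecodingOptions={"task": "ASR", "language": "auto", "beam_size": 5},
    #     )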

    def export(self, **kwargs):
        from .export_meta import export_rebuild_model

        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models