From eaf9dda9e4d970af3d09db695e9e10c83ef94e25 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 17 四月 2024 15:05:37 +0800
Subject: [PATCH] Dev gzf exp (#1624)

---
 funasr/models/sense_voice/model.py |  131 ++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 127 insertions(+), 4 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 4ee2fa5..b5272a1 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1,35 +1,158 @@
 from dataclasses import dataclass
 from typing import Dict
 from typing import Iterable, Optional
+import types
 import time
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch import nn
+from torch.cuda.amp import autocast
+from funasr.metrics.compute_acc import compute_accuracy
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.train_utils.device_funcs import force_gatherable
 from . import whisper_lib as whisper
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 from funasr.register import tables
 
 
+
+
 @tables.register("model_classes", "SenseVoice")
 class SenseVoice(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        hub = kwargs.get("hub", "funasr")
-
+        
         dims = kwargs.get("dims", {})
         dims = whisper.model.ModelDimensions(**dims)
         model = whisper.model.Whisper(dims=dims)
+        
+        # encoder
+        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
+        model.encoder.use_padmask = kwargs.get("use_padmask", True)
+        from .encoder import sense_voice_encode_forward
+        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
+        
+        # decoder
+        model.decoder.use_padmask = kwargs.get("use_padmask", True)
+        from .decoder import sense_voice_decode_forward
+        model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder)
         
         self.model = model
         
         self.encoder_output_size = self.model.dims.n_audio_state
         
-    def forward(self, ):
-        pass
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+        self.ignore_id = kwargs.get("ignore_id", -1)
+        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
+        self.criterion_att = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=kwargs.get("lsm_weight", 0.0),
+            normalize_length=self.length_normalized_loss,
+        )
+        
+        specaug = kwargs.get("specaug", None)
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
+        self.specaug = specaug
+
+ 
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
     
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+    
+        batch_size = speech.shape[0]
+
+        if self.activation_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+            encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+        loss = loss_att
+        stats = {}
+        stats["acc"] = acc_att
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+        
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) :
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+
+        # Forward encoder
+        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
+    
+        return encoder_out, encoder_out_lens
+
+
+    def _calc_att_loss(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+            **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
+        stats = {}
+        
+        # 1. Forward decoder
+        decoder_out = self.model.decoder(
+            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+        )
+        
+        # 2. Compute attention loss
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1-target_mask)).to(torch.int64)
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id)
+
+        return loss_att, acc_att, None, None
+
+
     def inference(self,
                   data_in,
                   data_lengths=None,

--
Gitblit v1.9.1