From 8c7b7e5feb68fda1fc4ddd627bad0f915358149e Mon Sep 17 00:00:00 2001
From: Zhanzhao (Deo) Liang <liangzhanzhao1985@gmail.com>
Date: 星期三, 25 十二月 2024 16:40:29 +0800
Subject: [PATCH] fix export_meta import of sense voice (#2334)

---
 funasr/models/sense_voice/model.py |   95 ++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 76 insertions(+), 19 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 25e9faf..0e3ef5f 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -19,6 +19,7 @@
 
 
 from funasr.models.paraformer.search import Hypothesis
+from .utils.ctc_alignment import ctc_forced_align
 
 
 class SinusoidalPositionEncoder(torch.nn.Module):
@@ -94,7 +95,7 @@
         n_feat,
         dropout_rate,
         kernel_size,
-        sanm_shfit=0,
+        sanm_shift=0,
         lora_list=None,
         lora_rank=8,
         lora_alpha=16,
@@ -120,17 +121,17 @@
         )
         # padding
         left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
+        if sanm_shift > 0:
+            left_padding = left_padding + sanm_shift
         right_padding = kernel_size - 1 - left_padding
         self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
 
-    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
+    def forward_fsmn(self, inputs, mask, mask_shift_chunk=None):
         b, t, d = inputs.size()
         if mask is not None:
             mask = torch.reshape(mask, (b, -1, 1))
-            if mask_shfit_chunk is not None:
-                mask = mask * mask_shfit_chunk
+            if mask_shift_chunk is not None:
+                mask = mask * mask_shift_chunk
             inputs = inputs * mask
 
         x = inputs.transpose(1, 2)
@@ -210,7 +211,7 @@
 
         return self.linear_out(x)  # (batch, time1, d_model)
 
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+    def forward(self, x, mask, mask_shift_chunk=None, mask_att_chunk_encoder=None):
         """Compute scaled dot product attention.
 
         Args:
@@ -225,7 +226,7 @@
 
         """
         q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+        fsmn_memory = self.forward_fsmn(v, mask, mask_shift_chunk)
         q_h = q_h * self.d_k ** (-0.5)
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
@@ -325,7 +326,7 @@
         self.stochastic_depth_rate = stochastic_depth_rate
         self.dropout_rate = dropout_rate
 
-    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+    def forward(self, x, mask, cache=None, mask_shift_chunk=None, mask_att_chunk_encoder=None):
         """Compute encoded features.
 
         Args:
@@ -362,7 +363,7 @@
                     self.self_attn(
                         x,
                         mask,
-                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_shift_chunk=mask_shift_chunk,
                         mask_att_chunk_encoder=mask_att_chunk_encoder,
                     ),
                 ),
@@ -378,7 +379,7 @@
                     self.self_attn(
                         x,
                         mask,
-                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_shift_chunk=mask_shift_chunk,
                         mask_att_chunk_encoder=mask_att_chunk_encoder,
                     )
                 )
@@ -387,7 +388,7 @@
                     self.self_attn(
                         x,
                         mask,
-                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_shift_chunk=mask_shift_chunk,
                         mask_att_chunk_encoder=mask_att_chunk_encoder,
                     )
                 )
@@ -401,7 +402,7 @@
         if not self.normalize_before:
             x = self.norm2(x)
 
-        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+        return x, mask, cache, mask_shift_chunk, mask_att_chunk_encoder
 
     def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
         """Compute encoded features.
@@ -468,7 +469,7 @@
         positionwise_conv_kernel_size: int = 1,
         padding_idx: int = -1,
         kernel_size: int = 11,
-        sanm_shfit: int = 0,
+        sanm_shift: int = 0,
         selfattention_layer_type: str = "sanm",
         **kwargs,
     ):
@@ -493,7 +494,7 @@
             output_size,
             attention_dropout_rate,
             kernel_size,
-            sanm_shfit,
+            sanm_shift,
         )
         encoder_selfattn_layer_args = (
             attention_heads,
@@ -501,7 +502,7 @@
             output_size,
             attention_dropout_rate,
             kernel_size,
-            sanm_shfit,
+            sanm_shift,
         )
 
         self.encoders0 = nn.ModuleList(
@@ -555,7 +556,8 @@
         ilens: torch.Tensor,
     ):
         """Embed positions in tensor."""
-        masks = sequence_mask(ilens, device=ilens.device)[:, None, :]
+        maxlen = xs_pad.shape[1]
+        masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]
 
         xs_pad *= self.output_size() ** 0.5
 
@@ -856,6 +858,8 @@
 
         use_itn = kwargs.get("use_itn", False)
         textnorm = kwargs.get("text_norm", None)
+        output_timestamp = kwargs.get("output_timestamp", False)
+
         if textnorm is None:
             textnorm = "withitn" if use_itn else "woitn"
         textnorm_query = self.embed(
@@ -904,13 +908,64 @@
             # Change integer-ids to tokens
             text = tokenizer.decode(token_int)
 
-            result_i = {"key": key[i], "text": text}
-            results.append(result_i)
+            # result_i = {"key": key[i], "text": text}
+            # results.append(result_i)
 
             if ibest_writer is not None:
                 ibest_writer["text"][key[i]] = text
 
+            if output_timestamp:
+                from itertools import groupby
+
+                timestamp = []
+                tokens = tokenizer.text2tokens(text)[4:]
+                logits_speech = self.ctc.softmax(encoder_out)[i, 4 : encoder_out_lens[i].item(), :]
+                pred = logits_speech.argmax(-1).cpu()
+                logits_speech[pred == self.blank_id, self.blank_id] = 0
+                align = ctc_forced_align(
+                    logits_speech.unsqueeze(0).float(),
+                    torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
+                    (encoder_out_lens - 4).long(),
+                    torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
+                    ignore_id=self.ignore_id,
+                )
+                pred = groupby(align[0, : encoder_out_lens[0]])
+                _start = 0
+                token_id = 0
+                ts_max = encoder_out_lens[i] - 4
+                for pred_token, pred_frame in pred:
+                    _end = _start + len(list(pred_frame))
+                    if pred_token != 0:
+                        ts_left = max((_start * 60 - 30) / 1000, 0)
+                        ts_right = min((_end * 60 - 30) / 1000, (ts_max * 60 - 30) / 1000)
+                        timestamp.append([tokens[token_id], ts_left, ts_right])
+                        token_id += 1
+                    _start = _end
+                timestamp = self.post(timestamp)
+                result_i = {"key": key[i], "text": text, "timestamp": timestamp}
+                results.append(result_i)
+            else:
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
         return results, meta_data
+
+    def post(self, timestamp):
+        timestamp_new = []
+        for i, t in enumerate(timestamp):
+            word, start, end = t
+            if word == "鈻�":
+                continue
+            if i == 0:
+                # timestamp_new.append([word, start, end])
+                timestamp_new.append([int(start * 1000), int(end * 1000)])
+            elif word.startswith("鈻�") or len(word) == 1 or not word[1].isalpha():
+                word = word[1:]
+                # timestamp_new.append([word, start, end])
+                timestamp_new.append([int(start * 1000), int(end * 1000)])
+            else:
+                # timestamp_new[-1][0] += word
+                timestamp_new[-1][1] = int(end * 1000)
+        return timestamp_new
 
     def export(self, **kwargs):
         from .export_meta import export_rebuild_model
@@ -919,3 +974,5 @@
             kwargs["max_seq_len"] = 512
         models = export_rebuild_model(model=self, **kwargs)
         return models
+
+        return results, meta_data

--
Gitblit v1.9.1