From d19f48e17478be273584853568ac101c994c37e5 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 08 Apr 2024 18:51:53 +0800
Subject: [PATCH] Dev gzf exp (#1593)

---
 funasr/models/sense_voice/model.py                 |    6 ++
 funasr/version.txt                                 |    2 
 funasr/models/llm_asr_nar/model.py                 |   59 +++++++++++++++++++++++++----
 funasr/models/sense_voice/whisper_lib/tokenizer.py |   11 ++++-
 funasr/models/sense_voice/whisper_lib/decoding.py  |   12 ++++-
 funasr/frontends/whisper_frontend.py               |   15 ++++++-
 funasr/models/sense_voice/whisper_lib/audio.py     |    6 +-
 7 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index dd61f8e..acc99af 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -38,7 +38,13 @@
         if whisper_model == "large-v3" or whisper_model == "large":
             self.n_mels = 128
 
-        self.mel_filters = whisper.audio.mel_filters
+        filters_path = kwargs.get("filters_path", None)
+        self.filters_path = filters_path
+        if filters_path is not None:
+            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
+            self.mel_filters = mel_filters
+        else:
+            self.mel_filters = whisper.audio.mel_filters
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
@@ -61,8 +67,10 @@
 
         # whisper deletes the last frame by default (Shih-Lun)
         magnitudes = stft[..., :-1].abs() ** 2
-
-        filters = self.mel_filters(audio.device, self.n_mels)
+        if self.filters_path is not None:
+            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
+        else:
+            filters = self.mel_filters(audio.device, self.n_mels)
         mel_spec = filters @ magnitudes
 
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -86,6 +94,7 @@
         batch_size = input.size(0)
         feats = []
         feats_lens = []
+        input = input.to(torch.float32)
         for i in range(batch_size):
             if self.do_pad_trim:
                 feat = self.pad_or_trim(input[i], self.pad_samples)
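
Note (illustrative, not part of the patch): a minimal sketch of how the new filters_path option would be used from the frontend side. The WhisperFrontend class name, its constructor arguments, and the forward signature are assumed from the context lines above; the .npz path is a placeholder and must contain a "mel_128" (or "mel_80") array, matching the loader in whisper_lib/audio.py.

    import torch
    from funasr.frontends.whisper_frontend import WhisperFrontend  # assumed import path

    # Build the frontend with a custom mel filterbank instead of the bundled one.
    frontend = WhisperFrontend(
        whisper_model="large-v3",               # selects n_mels = 128
        do_pad_trim=True,
        filters_path="/data/mel_filters.npz",   # placeholder; read via np.load in audio.py
    )

    # Inputs are cast to float32 before feature extraction, per the change above.
    waveform = torch.randn(1, 16000)            # one 1 s utterance at 16 kHz
    feats, feats_lens = frontend(waveform, torch.tensor([16000]))
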
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 30537cf..994259a 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -366,7 +366,7 @@
         decoder_conf: dict = None,
         ctc: str = None,
         ctc_conf: dict = None,
-        ctc_weight: float = 0.5,
+        ctc_weight: float = 0.0,
         llm: str = None,
         llm_conf: dict = None,
         adaptor: str = None,
@@ -473,6 +473,15 @@
         
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        if ctc_weight > 0.0:
+            if ctc_conf is None:
+                ctc_conf = {}
+    
+            ctc = CTC(
+                odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf
+            )
+        self.ctc_weight = ctc_weight
+        self.ctc = ctc
     
     def forward(
         self,
@@ -502,9 +511,23 @@
             speech_lengths = speech_lengths[:, 0]
         
         batch_size = speech.shape[0]
-        
+
+        stats = {}
         # audio encoder
-        encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        outs = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        enc, enc_lens = outs[0], outs[1]
+        encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4]
+        
+
+        # decoder: CTC branch
+        
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                enc, enc_lens, text, text_lengths
+            )
+    
+            # Collect CTC branch stats
+            stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
         
         # adaptor
         encoder_out = self.adaptor(encoder_out)
@@ -536,17 +559,19 @@
         # labels_ids[1:] ->  [prompt, input, target, eos] -> [-1, input, target, eos];
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
         loss_llm = model_outputs.loss
+        stats["loss_llm"] = torch.clone(loss_llm.detach())
+        if self.ctc_weight > 0.0:
+            loss_llm = self.ctc_weight * loss_ctc + loss_llm
         loss = loss_llm + loss_pre * self.predictor_weight
-        stats = {}
+        
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
             stats["acc"] = acc_att
         
-        
         stats["loss_pre"] = torch.clone(loss_pre.detach())
-        stats["loss_llm"] = torch.clone(loss_llm.detach())
         stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
         
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -576,7 +601,24 @@
             if audio_token_lengths is not None:
                 loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
         
-        return pre_acoustic_embeds, pre_token_length, loss_pre
+        return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+    
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
     
     def inference(self,
                   data_in,
@@ -648,7 +690,8 @@
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
         
-        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1)  # [prompt, audio]
+        # inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1)  # [prompt, audio, pad]
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
         # model_outputs = self.llm.generate(
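
Note (illustrative, not part of the patch): the loss combination introduced above, condensed into a standalone helper for review. Symbol names follow forward() in this hunk; the predictor_weight default here is a placeholder. When ctc_weight is 0 the previous behaviour is unchanged.

    import torch

    def combined_loss(loss_llm: torch.Tensor,
                      loss_pre: torch.Tensor,
                      loss_ctc: torch.Tensor = None,
                      ctc_weight: float = 0.0,
                      predictor_weight: float = 1.0) -> torch.Tensor:
        # Total loss as assembled in forward(): LLM cross-entropy, optional
        # auxiliary CTC on the raw encoder output, plus the predictor length loss.
        if ctc_weight > 0.0 and loss_ctc is not None:
            loss_llm = ctc_weight * loss_ctc + loss_llm
        return loss_llm + loss_pre * predictor_weight
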
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index d6552a6..521dec8 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -91,7 +91,11 @@
         # decode the audio
         
         # initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
-        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
+        
+        vocab_path = kwargs.get("vocab_path", None)
+        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt, vocab_path=vocab_path)
+
+        
         result = whisper.decode(self.model, speech, options)
 
         results = []
diff --git a/funasr/models/sense_voice/whisper_lib/audio.py b/funasr/models/sense_voice/whisper_lib/audio.py
index cf6c66a..52da32c 100644
--- a/funasr/models/sense_voice/whisper_lib/audio.py
+++ b/funasr/models/sense_voice/whisper_lib/audio.py
@@ -89,7 +89,7 @@
 
 
 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int) -> torch.Tensor:
+def mel_filters(device, n_mels: int, filters_path: str=None) -> torch.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -101,8 +101,8 @@
         )
     """
     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
-
-    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    if filters_path is None:
+        filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     with np.load(filters_path, allow_pickle=False) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
 
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 73b0262..caca114 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -119,6 +119,7 @@
 
     # FIX(funasr): sense vocie
     initial_prompt: str = None
+    vocab_path: str = None
 
 
 @dataclass(frozen=True)
@@ -527,6 +528,7 @@
             num_languages=model.num_languages,
             language=language,
             task=options.task,
+            vocab_path=options.vocab_path
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -616,10 +618,13 @@
                 + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                 + tokens
             )
-        #FIX(gzf): sense vocie
+        #FIX(funasr): sense voice
         if initial_prompt := self.options.initial_prompt:
-            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
-            if self.options.language is None:
+            if self.options.language is not None:
+                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+            else:
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                 tokens += [0]
 
 
@@ -691,6 +696,7 @@
             if self.options.language is None:
                 # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                 languages = "".join([f"<|{language}|>" for language in languages])
+
                 n_audio = audio_features.shape[0]
                 lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
                     audio_features.device)  # [n_audio, 1]
diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index e941fb2..463ce83 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -363,8 +363,10 @@
 
 
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99):
-    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path:str=None):
+    if vocab_path is None:
+        vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+
     ranks = {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in open(vocab_path) if line)
@@ -423,6 +425,7 @@
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
     encoding_path: Optional[str] = None,
+    vocab_path: Optional[str] = None,
 ) -> Tokenizer:
     if language is not None:
         language = language.lower()
@@ -443,7 +446,9 @@
     if encoding_path is not None:
         encoding_name = encoding_path
 
-    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
+
 
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task
diff --git a/funasr/version.txt b/funasr/version.txt
index c2320f5..2fa3901 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-1.0.20
+1.0.22
\ No newline at end of file

--
Gitblit v1.9.1