From aba47683fd4b2984dbff7fc79b0f532fc2d9f6b7 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 04 三月 2024 16:44:49 +0800
Subject: [PATCH] Update SDK_advanced_guide_offline_zh.md

---
 funasr/models/llm_asr/model.py |  198 +++++++++++++++++++++++--------------------------
 1 files changed, 93 insertions(+), 105 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index fcb301d..4139d8c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 from torch.cuda.amp import autocast
 
+from funasr.models.scama.utils import sequence_mask
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.ctc.ctc import CTC
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
@@ -72,15 +73,13 @@
         hub = encoder_conf.get("hub", None)
         if hub == "funasr":
             from funasr import AutoModel
-            from funasr.models.scama.utils import sequence_mask
-            init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
             model = AutoModel(model=init_param_path, model_revision="v2.0.4")
-            frontend = model.kwargs.get("frontend")
+            # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
             
-            self.model = model.model
-            self.frontend = frontend
-            self.mask_fn = sequence_mask
+            self.audio_encoder = model.model
+            # self.frontend = frontend
             
         elif hub == "hf":
             pass
@@ -102,8 +101,8 @@
                 device_map=None,
                 use_cache=None,
             )
-            freeze_llm = llm_conf.get("freeze_llm", True)
-            if freeze_llm:
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
                 for name, param in model.named_parameters():
                     param.requires_grad = False
                 model.eval()
@@ -151,9 +150,9 @@
         text_lengths: torch.Tensor,
         input_ids: torch.Tensor,
         attention_mask:torch.Tensor,
-        labels_ids:torch.Tensor,
+        labels_ids: torch.Tensor,
         label_mask: torch.Tensor,
-        audio_mask:torch.Tensor,
+        audio_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
@@ -173,13 +172,14 @@
         batch_size = speech.shape[0]
         
         # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
         
         # adaptor
         encoder_out = self.adaptor(encoder_out)
 
         if input_ids is not None:
             input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
             if hasattr(self.llm.model, "embed_tokens"):
                 inputs_embeds = self.llm.model.embed_tokens(input_ids)
             elif hasattr(self.llm.model.model, "embed_tokens"):
@@ -191,21 +191,20 @@
                 batch_size, token_num, dims = inputs_embeds.shape
                 _, l, _ = encoder_out.shape
                 encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
-                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
                 inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
 
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
         loss = model_outputs.loss
 
-        acc_att = -1
-        if self.metric:
-            with torch.no_grad():
-                preds = torch.argmax(model_outputs.logits, -1)
-                acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
 
         stats = {}
-        # Collect Attn branch stats
-        stats["acc"] = acc_att.detach()
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
+
+        stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -217,51 +216,20 @@
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
     
-        audio_mask = kwargs.get("audio_mask")
-        audio_token_lengths = audio_mask.sum(-1)
+        audio_mask = kwargs.get("audio_mask", None)
+        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
-        enc, enc_lens = self.model.encode(**batch)
-        enc_mask = self.mask_fn(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, _, _ = self.model.predictor(enc,
-                                                                           mask=enc_mask,
-                                                                           target_label_length=audio_token_lengths,
-                                                                           )
+        enc, enc_lens = self.audio_encoder.encode(**batch)
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+                                                                               mask=enc_mask,
+                                                                               target_label_length=audio_token_lengths,
+                                                                               )
 
         return pre_acoustic_embeds, pre_token_length
-    
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-        
-        # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-        
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_out_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_out_pad,
-            ignore_label=self.ignore_id,
-        )
-        
-        # Compute cer/wer using attention-decoder
-        if self.training or self.error_calculator is None:
-            cer_att, wer_att = None, None
-        else:
-            ys_hat = decoder_out.argmax(dim=-1)
-            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-        
-        return loss_att, acc_att, cer_att, wer_att
-    
+
 
     def inference(self,
                   data_in,
@@ -272,14 +240,12 @@
                   **kwargs,
                   ):
         
+        prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
-        
-        # init beamsearch
-        if self.beam_search is None:
-            logging.info("enable beam_search")
-            self.init_beam_search(**kwargs)
-            self.nbest = kwargs.get("nbest", 1)
+
+
         
         meta_data = {}
         if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
@@ -304,50 +270,72 @@
         
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
+        
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        if isinstance(encoder_out, tuple):
-            encoder_out = encoder_out[0]
+
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
         
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
-        )
+    
+        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+        prompt_ids = tokenizer.encode(prompt_pre)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
-        nbest_hyps = nbest_hyps[: self.nbest]
+        # model_outputs = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_length=kwargs.get("max_length", 200),
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 4),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     attention_mask=attention_mask,
+        #     bos_token_id=tokenizer.bos_token_id,
+        #     eos_token_id=tokenizer.eos_token_id,
+        #     pad_token_id=tokenizer.pad_token_id
+        # )
+
+
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+        text = text[0].split(': ')[-1]
+        text = text.strip()
         
+        # preds = torch.argmax(model_outputs.logits, -1)
+        
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
         results = []
-        b, n, d = encoder_out.size()
-        for i in range(b):
-            
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-                ibest_writer = None
-                if kwargs.get("output_dir") is not None:
-                    if not hasattr(self, "writer"):
-                        self.writer = DatadirWriter(kwargs.get("output_dir"))
-                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-                
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-                
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-                
-                # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
-                
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
-                results.append(result_i)
-                
-                if ibest_writer is not None:
-                    ibest_writer["token"][key[i]] = " ".join(token)
-                    ibest_writer["text"][key[i]] = text_postprocessed
+        result_i = {"key": key[0], "text": text}
+        results.append(result_i)
+
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = text
+        
+        
+        
         
         return results, meta_data
 

--
Gitblit v1.9.1