Yabin Li
2024-03-04 aba47683fd4b2984dbff7fc79b0f532fc2d9f6b7
funasr/models/llm_asr/model.py
@@ -20,8 +20,8 @@
from funasr.register import tables
@tables.register("model_classes", "LLMASRNAR")
class LLMASRNAR(nn.Module):
@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
    """ """
    
    def __init__(
@@ -216,8 +216,8 @@
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    
        audio_mask = kwargs.get("audio_mask")
        audio_token_lengths = audio_mask.sum(-1)
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
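
A minimal sketch of what the guarded mask handling above computes, assuming audio_mask is a [batch, T] 0/1 tensor (the batch values below are illustrative, not from the commit):

import torch

# audio_mask marks the valid audio-token positions per utterance (illustrative batch of 2).
audio_mask = torch.tensor([[1, 1, 1, 0],
                           [1, 1, 0, 0]])
# Per-utterance audio token counts; if no mask was passed, the lengths simply stay None.
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
print(audio_token_lengths)  # tensor([3, 2])
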
@@ -279,7 +279,7 @@
        
    
        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
-        prompt_ids = self.tokenizer.encode(prompt_pre)
+        prompt_ids = tokenizer.encode(prompt_pre)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
@@ -294,24 +294,32 @@
        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
        
-        model_outputs = self.llm.generate(
-            inputs_embeds=inputs_embeds,
-            max_length=kwargs.get("max_length", 200),
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
-            num_beams=kwargs.get("num_beams", 4),
-            do_sample=kwargs.get("do_sample", False),
-            min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
-            attention_mask=attention_mask,
-            bos_token_id=tokenizer.bos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id
-        )
+        # model_outputs = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_length=kwargs.get("max_length", 200),
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 4),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     attention_mask=attention_mask,
+        #     bos_token_id=tokenizer.bos_token_id,
+        #     eos_token_id=tokenizer.eos_token_id,
+        #     pad_token_id=tokenizer.pad_token_id
+        # )
-        text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
        text = text[0].split(': ')[-1]
        text = text.strip()
        # preds = torch.argmax(model_outputs.logits, -1)
        
        ibest_writer = None
        if kwargs.get("output_dir") is not None: