R1ckShi
2024-05-30 d097d0ca45472965d4411357d52adda5657691a2
update
5个文件已修改
33 ■■■■ 已修改文件
examples/industrial_data_pretraining/paraformer/export.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/contextual_paraformer/export_meta.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/attention.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/whisper/model.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/export_utils.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/export.py
@@ -10,7 +10,7 @@
from funasr import AutoModel
model = AutoModel(
    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
)
res = model.export(type="torchscript", quantize=False)
funasr/models/contextual_paraformer/export_meta.py
@@ -16,6 +16,21 @@
        self.embedding = model.bias_embed
        model.bias_encoder.batch_first = False
        self.bias_encoder = model.bias_encoder
    def export_dummy_inputs(self):
        hotword = torch.tensor(
            [
                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ],
            dtype=torch.int32,
        )
        # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
        return (hotword)
def export_rebuild_model(model, **kwargs):
funasr/models/sanm/attention.py
@@ -780,7 +780,7 @@
        return q, k, v
    def forward_attention(self, value, scores, mask, ret_attn):
        scores = scores + mask
        scores = scores + mask.to(scores.device)
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
funasr/models/whisper/model.py
@@ -7,7 +7,10 @@
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import whisper
# import whisper_timestamped as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@@ -108,8 +111,12 @@
        # decode the audio
        options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
        result = whisper.decode(self.model, speech, options)
        result = whisper.decode(self.model, speech, language='english')
        # result = whisper.transcribe(self.model, speech)
        import pdb; pdb.set_trace()
        results = []
        result_i = {"key": key[0], "text": result.text}
funasr/utils/export_utils.py
@@ -83,7 +83,10 @@
    if device == 'cuda':
        model = model.cuda()
        dummy_input = tuple([i.cuda() for i in dummy_input])
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple([i.cuda() for i in dummy_input])
    # model_script = torch.jit.script(model)
    model_script = torch.jit.trace(model, dummy_input)