From 44a6b59468c552e5e554d1e7234efb5dcab0e0b4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 04 Mar 2024 19:09:21 +0800
Subject: [PATCH] Dev gzf (#1421)
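
Add a minimal Whisper wrapper under funasr/models/qwen_audio/: model.py
registers a "WhisperWarp" class that loads weights either from the OpenAI
whisper hub or from an explicit whisper_dims config, and template.yaml
documents the corresponding configuration.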
---
funasr/models/qwen_audio/model.py | 85 ++++++++++++++++++++++++++++
funasr/models/qwen_audio/template.yaml | 46 +++++++++++++++
funasr/models/qwen_audio/__init__.py | 0
3 files changed, 131 insertions(+), 0 deletions(-)
diff --git a/funasr/models/qwen_audio/__init__.py b/funasr/models/qwen_audio/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/qwen_audio/__init__.py
diff --git a/funasr/models/qwen_audio/model.py b/funasr/models/qwen_audio/model.py
new file mode 100644
index 0000000..f09405a
--- /dev/null
+++ b/funasr/models/qwen_audio/model.py
@@ -0,0 +1,85 @@
+from dataclasses import dataclass
+from typing import Dict
+from typing import Iterable, Optional
+import time
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch import nn
+import whisper
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+from funasr.register import tables
+
+
+
+@tables.register("model_classes", "WhisperWarp")
+class WhisperWarp(nn.Module):
+    def __init__(self, whisper_dims: dict, **kwargs):
+        super().__init__()
+        hub = kwargs.get("hub", "funasr")
+        if hub == "openai":
+            init_param_path = kwargs.get("init_param_path", "large-v3")
+            model = whisper.load_model(init_param_path)
+        else:
+            dims = whisper.model.ModelDimensions(**whisper_dims)
+            model = whisper.model.Whisper(dims=dims)
+
+        self.model = model
+
+    def forward(self):  # training is not implemented for this inference-only wrapper
+        pass
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = torch.tensor([speech.shape[1]])  # keep a tensor for .to(device) below
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])[0, :, :]  # batch_size == 1: drop the batch dim
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # detect the spoken language
+        _, probs = self.model.detect_language(speech)
+        print(f"Detected language: {max(probs, key=probs.get)}")
+
+        # decode the audio
+        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
+        result = whisper.decode(self.model, speech, options)
+
+        results = []
+        result_i = {"key": key[0], "text": result.text}
+
+        results.append(result_i)
+
+        return results, meta_data
+
\ No newline at end of file
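
A minimal sketch of how this wrapper could be driven end to end once merged,
assuming FunASR's standard AutoModel entry point and a local directory holding
a config derived from template.yaml below (the directory and wav paths here
are illustrative placeholders, not part of this patch):

    # hypothetical usage sketch; paths are placeholders
    from funasr import AutoModel

    # directory containing a config based on template.yaml below
    model = AutoModel(model="./whisper_warp_example")
    res = model.generate(input="example.wav", batch_size=1)  # inference() rejects batch_size > 1
    print(res[0]["text"])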
diff --git a/funasr/models/qwen_audio/template.yaml b/funasr/models/qwen_audio/template.yaml
new file mode 100644
index 0000000..40b902c
--- /dev/null
+++ b/funasr/models/qwen_audio/template.yaml
@@ -0,0 +1,46 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register table:
+#   from funasr.register import tables
+#   tables.print()
+
+# network architecture
+model: WhisperWarp
+model_conf:
+    lsm_weight: 0.1
+    length_normalized_loss: true
+    hub: funasr # or "openai"
+    init_param_path: null # large-v2 or large-v3 if hub == "openai"
+
+
+
+# only used when hub == "funasr";
+# if hub == "openai", whisper_dims is downloaded automatically
+whisper_dims:
+    'n_mels': 80
+    'n_vocab': 51865
+    'n_audio_ctx': 1500
+    'n_audio_state': 1280
+    'n_audio_head': 20
+    'n_audio_layer': 32
+    'n_text_ctx': 448
+    'n_text_state': 1280
+    'n_text_head': 20
+    'n_text_layer': 32
+
+# frontend related
+frontend: WhisperFrontend
+frontend_conf:
+    fs: 16000
+    n_mels: 80
+    do_pad_trim: true
+
+tokenizer: WhisperTokenizer
+tokenizer_conf:
+    language: null
+    task: transcribe
+    is_multilingual: true
+    num_languages: 99
+
+scope_map: ["none", "model."]
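
Note on whisper_dims: when hub == "funasr", model.py passes this mapping
directly into whisper.model.ModelDimensions, so the keys must match that
dataclass's field names exactly. The values above correspond to Whisper
large-v2; large-v3 uses n_mels: 128 and one extra vocabulary entry. A short
sketch of what the constructor path does with this block, using the values
from this template:

    # sketch of the hub == "funasr" branch in model.py
    import whisper

    whisper_dims = {
        "n_mels": 80, "n_vocab": 51865, "n_audio_ctx": 1500,
        "n_audio_state": 1280, "n_audio_head": 20, "n_audio_layer": 32,
        "n_text_ctx": 448, "n_text_state": 1280, "n_text_head": 20,
        "n_text_layer": 32,
    }
    dims = whisper.model.ModelDimensions(**whisper_dims)
    model = whisper.model.Whisper(dims=dims)  # randomly initialized; weights come from init_param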
--
Gitblit v1.9.1