From e8f68b44dd65ede9278e89a5277a5cf66d546375 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 01 七月 2024 11:09:01 +0800
Subject: [PATCH] v1.0.29
---
funasr/models/sense_voice/model.py | 26 +++++++++++++++++---------
1 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 97f1b19..9db6539 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -10,12 +10,13 @@
from torch import Tensor
from torch import nn
from torch.cuda.amp import autocast
-from funasr.metrics.compute_acc import compute_accuracy
+from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.train_utils.device_funcs import force_gatherable
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.ctc.ctc import CTC
from funasr.register import tables
@@ -661,9 +662,11 @@
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
- )
+ with autocast(False):
+ loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+ )
+
loss = loss_att
stats = {}
stats["acc"] = acc_att
@@ -1035,6 +1038,7 @@
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+ self.encoder_output_size = encoder_output_size
def forward(
self,
@@ -1256,7 +1260,7 @@
if isinstance(task, str):
task = [task]
task = "".join([f"<|{x}|>" for x in task])
-
+
sos = kwargs.get("model_conf").get("sos")
if isinstance(sos, str):
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
@@ -1270,7 +1274,9 @@
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
initial_prompt = kwargs.get("initial_prompt", f"{task}")
- initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ initial_prompt_lid = (
+ f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ )
initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
sos_int = [sos] + initial_prompt_lid_int
eos = kwargs.get("model_conf").get("eos")
@@ -1303,9 +1309,7 @@
)
self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
- encoder_out, encoder_out_lens = self.encode(
- speech[None, :, :], speech_lengths
- )
+ encoder_out, encoder_out_lens = self.encode(speech[None, :, :], speech_lengths)
if text_token_int is not None:
i = 0
@@ -1384,3 +1388,7 @@
ibest_writer["text"][key[i]] = text
return results, meta_data
+
+
+from funasr.models.paraformer.search import Hypothesis
+from funasr.utils import postprocess_utils
--
Gitblit v1.9.1