From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] sa_asr: update import paths and reformat e2e_sa_asr.py
---
funasr/models/sa_asr/e2e_sa_asr.py | 191 ++++++++++++++++++++++++-----------------------
1 files changed, 97 insertions(+), 94 deletions(-)
diff --git a/funasr/models/sa_asr/e2e_sa_asr.py b/funasr/models/sa_asr/e2e_sa_asr.py
index e0cb69a..f4827b2 100644
--- a/funasr/models/sa_asr/e2e_sa_asr.py
+++ b/funasr/models/sa_asr/e2e_sa_asr.py
@@ -14,19 +14,17 @@
import torch.nn.functional as F
from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, NllLoss # noqa: H301
-)
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, NllLoss # noqa: H301
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
@@ -43,28 +41,28 @@
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
- self,
- vocab_size: int,
- max_spk_num: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- asr_encoder: AbsEncoder,
- spk_encoder: torch.nn.Module,
- decoder: AbsDecoder,
- ctc: CTC,
- spk_weight: float = 0.5,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- ignore_id: int = -1,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- extract_feats_in_collect_stats: bool = True,
+ self,
+ vocab_size: int,
+ max_spk_num: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ asr_encoder: AbsEncoder,
+ spk_encoder: torch.nn.Module,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ spk_weight: float = 0.5,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -75,7 +73,7 @@
self.sos = 1
self.eos = 2
self.vocab_size = vocab_size
- self.max_spk_num=max_spk_num
+ self.max_spk_num = max_spk_num
self.ignore_id = ignore_id
self.spk_weight = spk_weight
self.ctc_weight = ctc_weight
@@ -96,7 +94,6 @@
)
self.error_calculator = None
-
# we set self.decoder = None in the CTC mode since
# self.decoder parameters were never used and PyTorch complained
@@ -133,15 +130,15 @@
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- profile: torch.Tensor,
- profile_lengths: torch.Tensor,
- text_id: torch.Tensor,
- text_id_lengths: torch.Tensor
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lengths: torch.Tensor,
+ text_id: torch.Tensor,
+ text_id_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
@@ -156,10 +153,7 @@
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
+ speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
@@ -183,7 +177,6 @@
asr_encoder_out, encoder_out_lens, text, text_lengths
)
-
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
@@ -204,15 +197,20 @@
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
+ loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
- asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
+ asr_encoder_out,
+ spk_encoder_out,
+ encoder_out_lens,
+ text,
+ text_lengths,
+ profile,
+ profile_lengths,
+ text_id,
+ text_id_lengths,
)
# 3. CTC-Att loss definition
@@ -227,7 +225,6 @@
loss = loss_asr
else:
loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
-
stats = dict(
loss=loss.detach(),
@@ -247,11 +244,11 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
@@ -266,7 +263,7 @@
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
@@ -291,9 +288,7 @@
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.asr_encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.asr_encoder(
- feats, feats_lengths, ctc=self.ctc
- )
+ encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
intermediate_outs = None
@@ -303,10 +298,12 @@
encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
# import ipdb;ipdb.set_trace()
- if encoder_out_spk_ori.size(1)!=encoder_out.size(1):
- encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
+ if encoder_out_spk_ori.size(1) != encoder_out.size(1):
+ encoder_out_spk = F.interpolate(
+ encoder_out_spk_ori.transpose(-2, -1), size=(encoder_out.size(1)), mode="nearest"
+ ).transpose(-2, -1)
else:
- encoder_out_spk=encoder_out_spk_ori
+ encoder_out_spk = encoder_out_spk_ori
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -327,7 +324,7 @@
return encoder_out, encoder_out_lens, encoder_out_spk
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
@@ -346,11 +343,11 @@
return feats, feats_lengths
def nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
@@ -384,12 +381,12 @@
return nll
def batchify_nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- batch_size: int = 100,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
@@ -431,28 +428,34 @@
return nll
def _calc_att_loss(
- self,
- asr_encoder_out: torch.Tensor,
- spk_encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- profile: torch.Tensor,
- profile_lens: torch.Tensor,
- text_id: torch.Tensor,
- text_id_lengths: torch.Tensor
+ self,
+ asr_encoder_out: torch.Tensor,
+ spk_encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lens: torch.Tensor,
+ text_id: torch.Tensor,
+ text_id_lengths: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, weights_no_pad, _ = self.decoder(
- asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
+ asr_encoder_out,
+ spk_encoder_out,
+ encoder_out_lens,
+ ys_in_pad,
+ ys_in_lens,
+ profile,
+ profile_lens,
)
- spk_num_no_pad=weights_no_pad.size(-1)
- pad=(0,self.max_spk_num-spk_num_no_pad)
- weights=F.pad(weights_no_pad, pad, mode='constant', value=0)
+ spk_num_no_pad = weights_no_pad.size(-1)
+ pad = (0, self.max_spk_num - spk_num_no_pad)
+ weights = F.pad(weights_no_pad, pad, mode="constant", value=0)
# pre_id=weights.argmax(-1)
# pre_text=decoder_out.argmax(-1)
@@ -467,7 +470,7 @@
loss_att = self.criterion_att(decoder_out, ys_out_pad)
loss_spk = self.criterion_spk(torch.log(weights), text_id)
- acc_spk= th_accuracy(
+ acc_spk = th_accuracy(
weights.view(-1, self.max_spk_num),
text_id,
ignore_label=self.ignore_id,
@@ -488,11 +491,11 @@
return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
--
Gitblit v1.9.1