From 607073619cedf2c114e1589aa6d5953d171f33bf Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 27 Apr 2023 19:27:49 +0800
Subject: [PATCH] Tighten model type annotations, add ParaformerOnline/InterCTC support, and fix VAD final-chunk flushing
---
funasr/models/e2e_diar_sond.py | 28 +
funasr/models/e2e_tp.py | 26
funasr/models/e2e_asr.py | 21
funasr/models/e2e_asr_paraformer.py | 511 +++++++++++++++++++++++++--------
funasr/models/e2e_asr_mfcca.py | 148 +++++----
funasr/models/e2e_vad.py | 45 ++
funasr/models/e2e_sv.py | 30 +
funasr/models/e2e_uni_asr.py | 22
funasr/models/data2vec.py | 14
funasr/models/e2e_diar_eend_ola.py | 3
10 files changed, 579 insertions(+), 269 deletions(-)
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index 380c137..e5bd640 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -12,7 +12,11 @@
import torch
from typeguard import check_argument_types
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
@@ -30,11 +34,11 @@
def __init__(
self,
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
+ encoder: AbsEncoder,
):
assert check_argument_types()
@@ -53,7 +57,6 @@
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -102,7 +105,6 @@
speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
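
A note on the typing changes above: annotating components as AbsFrontend/AbsSpecAug/AbsNormalize/AbsEncoder instead of bare torch.nn.Module lets typeguard's check_argument_types() reject a mis-wired component at construction time instead of deep inside forward(). A minimal sketch of the contract such an interface encodes; the real classes live under funasr/models/ and funasr/layers/ and may differ in detail:

from abc import ABC, abstractmethod
from typing import Tuple

import torch


class AbsEncoder(torch.nn.Module, ABC):
    # Illustrative stand-in for funasr.models.encoder.abs_encoder.AbsEncoder.

    @abstractmethod
    def output_size(self) -> int:
        # Dimensionality of the encoder output frames.
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, xs_pad: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Return (encoder_out, encoder_out_lens, extra), matching the
        # "encoder_out, encoder_out_lens, _ = self.encoder(...)" call sites.
        raise NotImplementedError
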
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index 779d703..8410ede 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -13,18 +13,22 @@
import torch
from typeguard import check_argument_types
+from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -43,9 +47,11 @@
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -127,7 +133,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -243,7 +248,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -325,9 +329,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -364,7 +366,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
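
The batchify_nll docstring above describes a simple slicing pattern. A hedged sketch of that pattern in isolation (argument names are assumptions; the actual method in e2e_asr.py may differ in detail):

import torch


def batchify_nll(model, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                 batch_size: int = 100) -> torch.Tensor:
    # Score utterances in slices of batch_size to bound peak memory.
    nlls = []
    for start in range(0, encoder_out.size(0), batch_size):
        end = start + batch_size
        nlls.append(model.nll(
            encoder_out[start:end],
            encoder_out_lens[start:end],
            ys_pad[start:end],
            ys_pad_lens[start:end],
        ))
    return torch.cat(nlls)  # (Batch,) per-utterance negative log likelihood
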
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index efdd90d..44679ef 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -17,10 +17,13 @@
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -32,30 +35,36 @@
import pdb
import random
import math
+
+
class MFCCA(FunASRModel):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """
+ Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
+    MFCCA: Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
+ https://arxiv.org/abs/2210.05265
+ """
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
- preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
- decoder: AbsDecoder,
- ctc: CTC,
- rnnt_decoder: None,
- ctc_weight: float = 0.5,
- ignore_id: int = -1,
- lsm_weight: float = 0.0,
- mask_ratio: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ rnnt_decoder: None,
+ ctc_weight: float = 0.5,
+ ignore_id: int = -1,
+ lsm_weight: float = 0.0,
+ mask_ratio: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -69,10 +78,9 @@
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.token_list = token_list.copy()
-
+
self.mask_ratio = mask_ratio
-
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
@@ -106,14 +114,13 @@
self.error_calculator = None
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -123,22 +130,22 @@
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- #pdb.set_trace()
- if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+ # pdb.set_trace()
+ if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
rate_num = random.random()
- #rate_num = 0.1
- if(rate_num<=self.mask_ratio):
- retain_channel = math.ceil(random.random() *8)
- if(retain_channel>1):
- speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+ # rate_num = 0.1
+ if (rate_num <= self.mask_ratio):
+ retain_channel = math.ceil(random.random() * 8)
+ if (retain_channel > 1):
+ speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
else:
- speech = speech[:,:,torch.randperm(8)[0]]
- #pdb.set_trace()
+ speech = speech[:, :, torch.randperm(8)[0]]
+ # pdb.set_trace()
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
@@ -188,20 +195,19 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -220,14 +226,14 @@
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
- #pdb.set_trace()
+ # pdb.set_trace()
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
- if(encoder_out.dim()==4):
+ if (encoder_out.dim() == 4):
assert encoder_out.size(2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
@@ -241,7 +247,7 @@
return encoder_out, encoder_out_lens
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
@@ -259,11 +265,11 @@
return feats, feats_lengths, channel_size
def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
@@ -291,14 +297,14 @@
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
- if(encoder_out.dim()==4):
+ if (encoder_out.dim() == 4):
encoder_out = encoder_out.mean(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
@@ -310,10 +316,10 @@
return loss_ctc, cer_ctc
def _calc_rnnt_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
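
The masking block reformatted in MFCCA.forward implements a channel-dropout augmentation for 8-channel input: with probability mask_ratio, a random subset of channels is kept. The same logic as a standalone sketch (behavior read off the hunk above; note that the single-channel branch indexes with a scalar and so collapses the channel dimension):

import math
import random

import torch


def random_channel_retain(speech: torch.Tensor, mask_ratio: float) -> torch.Tensor:
    # speech: (Batch, Time, Channel) with Channel == 8.
    if speech.dim() == 3 and speech.size(2) == 8 and mask_ratio != 0:
        if random.random() <= mask_ratio:
            retain_channel = math.ceil(random.random() * 8)  # 1..8
            if retain_channel > 1:
                # Keep a sorted random subset of channels.
                keep = torch.randperm(8)[:retain_channel].sort().values
                speech = speech[:, :, keep]                  # (B, T, retain_channel)
            else:
                speech = speech[:, :, torch.randperm(8)[0]]  # (B, T)
    return speech
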
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index f414e4f..9d4f106 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -12,22 +12,25 @@
import numpy as np
from typeguard import check_argument_types
+from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.predictor.cif import mae_loss
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -40,7 +43,7 @@
class Paraformer(FunASRModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@@ -49,10 +52,12 @@
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
- encoder: torch.nn.Module,
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -92,7 +97,16 @@
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
+ self.preencoder = preencoder
+ self.postencoder = postencoder
self.encoder = encoder
+
+ if not hasattr(self.encoder, "interctc_use_conditioning"):
+ self.encoder.interctc_use_conditioning = False
+ if self.encoder.interctc_use_conditioning:
+ self.encoder.conditioning_layer = torch.nn.Linear(
+ vocab_size, self.encoder.output_size()
+ )
self.error_calculator = None
@@ -138,7 +152,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -161,7 +174,9 @@
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
@@ -178,6 +193,30 @@
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+            # Collect Intermediate CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
@@ -229,7 +268,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -246,8 +284,29 @@
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
# 4. Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ feats, feats_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ # Post-encoder, e.g. NLU
+ if self.postencoder is not None:
+ encoder_out, encoder_out_lens = self.postencoder(
+ encoder_out, encoder_out_lens
+ )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -258,45 +317,18 @@
encoder_out_lens.max(),
)
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
return encoder_out, encoder_out_lens
-
- def encode_chunk(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
-
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(feats, feats_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # 4. Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
-
- return encoder_out, torch.tensor([encoder_out.size(1)])
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
- def calc_predictor_chunk(self, encoder_out, cache=None):
-
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -307,14 +339,6 @@
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
-
- def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
- decoder_outs = self.decoder.forward_chunk(
- encoder_out, sematic_embeds, cache["decoder"]
- )
- decoder_out = decoder_outs
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
@@ -342,9 +366,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -381,7 +403,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -521,9 +542,186 @@
return loss_ctc, cer_ctc
-class ParaformerBert(Paraformer):
+class ParaformerOnline(Paraformer):
"""
Author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self, *args, **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ self.step_cur += 1
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+                # Collect Intermediate CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+ # 2b. Attention decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode_chunk(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+ feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ # Post-encoder, e.g. NLU
+ if self.postencoder is not None:
+ encoder_out, encoder_out_lens = self.postencoder(
+ encoder_out, encoder_out_lens
+ )
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, torch.tensor([encoder_out.size(1)])
+
+ def calc_predictor_chunk(self, encoder_out, cache=None):
+
+ pre_acoustic_embeds, pre_token_length = \
+ self.predictor.forward_chunk(encoder_out, cache["encoder"])
+ return pre_acoustic_embeds, pre_token_length
+
+ def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+ decoder_outs = self.decoder.forward_chunk(
+ encoder_out, sematic_embeds, cache["decoder"]
+ )
+ decoder_out = decoder_outs
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out
+
+
+class ParaformerBert(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer2: advanced Paraformer with LF-MMI and BERT for non-autoregressive end-to-end speech recognition
"""
@@ -531,11 +729,11 @@
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
+ encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -690,7 +888,6 @@
embed_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -799,74 +996,73 @@
class BiCifParaformer(Paraformer):
-
"""
Paraformer model with an extra cif predictor
to conduct accurate timestamp prediction
"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
- preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
- postencoder: Optional[AbsPostEncoder],
- decoder: AbsDecoder,
- ctc: CTC,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- extract_feats_in_collect_stats: bool = True,
- predictor = None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- preencoder=preencoder,
- encoder=encoder,
- postencoder=postencoder,
- decoder=decoder,
- ctc=ctc,
- ctc_weight=ctc_weight,
- interctc_weight=interctc_weight,
- ignore_id=ignore_id,
- blank_id=blank_id,
- sos=sos,
- eos=eos,
- lsm_weight=lsm_weight,
- length_normalized_loss=length_normalized_loss,
- report_cer=report_cer,
- report_wer=report_wer,
- sym_space=sym_space,
- sym_blank=sym_blank,
- extract_feats_in_collect_stats=extract_feats_in_collect_stats,
- predictor=predictor,
- predictor_weight=predictor_weight,
- predictor_bias=predictor_bias,
- sampling_ratio=sampling_ratio,
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CifPredictorV3"
@@ -888,21 +1084,77 @@
loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
return loss_pre2
-
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
+
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
+
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
@@ -913,7 +1165,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -996,7 +1247,8 @@
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1022,11 +1274,11 @@
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
+ encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -1120,7 +1372,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1504,4 +1755,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
+ return var_dict_torch_update
\ No newline at end of file
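
The InterCTC branch added to Paraformer.forward (and mirrored in ParaformerOnline.forward) averages a CTC loss over the tapped intermediate layers and mixes it with the final-layer CTC loss. The combination in isolation, as a hedged sketch:

import torch


def combine_interctc_loss(loss_ctc: torch.Tensor,
                          inter_losses: list,
                          interctc_weight: float) -> torch.Tensor:
    # loss = (1 - w) * loss_ctc + w * mean_k(loss_interctc_k)
    loss_interctc = torch.stack(inter_losses).mean()
    return (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc

The hasattr guard in __init__ defaults interctc_use_conditioning to False, so encoders that predate self-conditioned CTC keep working unchanged.
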
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index b4a3fa2..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -15,8 +15,8 @@
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
-from funasr.models.base_model import FunASRModel
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -91,7 +91,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index dc7135f..9c3fb92 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -14,9 +14,15 @@
from torch.nn import functional as F
from typeguard import check_argument_types
+from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
-from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -30,16 +36,20 @@
class DiarSondModel(FunASRModel):
- """Speaker overlap-aware neural diarization model
- reference: https://arxiv.org/abs/2211.10243
+ """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+ SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+ https://arxiv.org/abs/2211.10243
+ TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+ https://arxiv.org/abs/2303.05397
"""
def __init__(
self,
vocab_size: int,
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
encoder: torch.nn.Module,
speaker_encoder: Optional[torch.nn.Module],
ci_scorer: torch.nn.Module,
@@ -105,7 +115,6 @@
binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
-
Args:
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,) default None for chunk iterator,
@@ -342,7 +351,7 @@
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
- if isinstance(self.ci_scorer, torch.nn.Module):
+ if isinstance(self.ci_scorer, AbsEncoder):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
else:
@@ -381,7 +390,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
@@ -481,4 +489,4 @@
speaker_miss,
speaker_falarm,
speaker_error,
- )
+ )
\ No newline at end of file
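
On the isinstance change in DiarSondModel: every scorer is a torch.nn.Module, so the old isinstance(self.ci_scorer, torch.nn.Module) test was always true and the else branch was unreachable. Dispatching on the narrower AbsEncoder restores the second path. A minimal sketch of the dispatch (the non-encoder branch is illustrative; the real else branch is not shown in this hunk):

import torch

from funasr.models.encoder.abs_encoder import AbsEncoder


def run_ci_scorer(ci_scorer: torch.nn.Module, ge_in: torch.Tensor, ge_len: torch.Tensor):
    if isinstance(ci_scorer, AbsEncoder):
        # Encoder-style scorers return (output, lengths, ...); keep the output.
        return ci_scorer(ge_in, ge_len)[0]
    # Any other nn.Module scorer is called directly (assumed interface).
    return ci_scorer(ge_in, ge_len)
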
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index 582c25d..bd82c7c 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -1,3 +1,8 @@
+
+"""
+Author: Speech Lab of DAMO Academy, Alibaba Group
+"""
+
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
@@ -10,11 +15,22 @@
import torch
from typeguard import check_argument_types
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -32,11 +48,11 @@
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
+ encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,
@@ -65,7 +81,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -206,7 +221,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -256,4 +270,4 @@
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
+ return feats, feats_lengths
\ No newline at end of file
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index c5dc63c..39419c8 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -2,19 +2,23 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
+from typing import List
from typing import Optional
from typing import Tuple
+from typing import Union
import torch
+import numpy as np
from typeguard import check_argument_types
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
-from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -25,15 +29,15 @@
yield
class TimestampPredictor(FunASRModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
- frontend: Optional[torch.nn.Module],
- encoder: torch.nn.Module,
+ frontend: Optional[AbsFrontend],
+ encoder: AbsEncoder,
predictor: CifPredictorV3,
predictor_bias: int = 0,
token_list=None,
@@ -51,7 +55,7 @@
self.predictor_bias = predictor_bias
self.criterion_pre = mae_loss()
self.token_list = token_list
-
+
def forward(
self,
speech: torch.Tensor,
@@ -60,7 +64,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -108,7 +111,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -123,7 +125,7 @@
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out, encoder_out_lens
-
+
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +148,8 @@
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def collect_feats(
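
calc_predictor_timestamp above exposes the CIF predictor's upsampled firing sequence (us_alphas, us_peaks) for timestamp prediction. A hedged sketch of how firing peaks could map to token times; the frame shift and peak threshold here are illustrative assumptions, and the real conversion lives in the inference pipeline rather than in this file:

import torch


def peaks_to_times(us_peaks: torch.Tensor, frame_shift_ms: float = 10.0) -> list:
    # us_peaks: (T_upsampled,) per-frame firing indicator for one utterance.
    fire_idx = torch.nonzero(us_peaks > 0.5).squeeze(-1)  # assumed threshold
    return [float(i) * frame_shift_ms for i in fire_idx]
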
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index 0c53389..d08ea37 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -17,10 +17,13 @@
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
@@ -37,18 +40,18 @@
class UniASR(FunASRModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
- encoder: torch.nn.Module,
+ encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -176,7 +179,6 @@
decoding_ind: int = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -466,7 +468,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -530,7 +531,6 @@
ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -624,9 +624,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -663,7 +661,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -1069,4 +1066,3 @@
ys_hat = self.ctc2.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
-
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index ff37429..e477750 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -35,6 +35,12 @@
class VADXOptions:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
def __init__(
self,
sample_rate: int = 16000,
@@ -99,6 +105,12 @@
class E2EVadSpeechBufWithDoa(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
def __init__(self):
self.start_ms = 0
self.end_ms = 0
@@ -117,6 +129,12 @@
class E2EVadFrameProb(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
@@ -126,6 +144,12 @@
class WindowDetector(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
@@ -192,6 +216,12 @@
class E2EVadModel(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
@@ -286,7 +316,7 @@
0.000001))
def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
- scores = self.encoder(feats, in_cache) # return B * T * D
+ scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shapes of feats and scores do not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
@@ -444,7 +474,7 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
+
def forward(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -460,8 +490,9 @@
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
- i].contain_seg_end_point:
+ if not is_final and (
+ not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
@@ -474,11 +505,11 @@
return segments, in_cache
def forward_online(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
- is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+ is_final: bool = False, max_end_sil: int = 800
+ ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
-
+
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
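
The forward change above ties segment emission to is_final: while streaming, a buffered segment that is missing its start or end point is held back, but on the last chunk it is flushed even if still open. The rule in isolation (a sketch, assuming the buffer flags shown in the hunk):

def ready_to_emit(buf, is_final: bool) -> bool:
    # Emit complete segments always; emit open ones only on the final chunk.
    complete = buf.contain_seg_start_point and buf.contain_seg_end_point
    return complete or is_final
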
--
Gitblit v1.9.1