From e772c7eb9e5439aaff2f599e79f0b3c8fdca22c2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 21 二月 2024 16:55:02 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge
---
funasr/models/bat/model.py | 706 +++++++++++++++++++++++++++++++---------------------------
1 files changed, 373 insertions(+), 333 deletions(-)
diff --git a/funasr/models/bat/model.py b/funasr/models/bat/model.py
index 3fed9aa..8e76b45 100644
--- a/funasr/models/bat/model.py
+++ b/funasr/models/bat/model.py
@@ -3,137 +3,145 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-
+import time
import torch
import logging
-import torch.nn as nn
+from contextlib import contextmanager
+from typing import Dict, Optional, Tuple
+from distutils.version import LooseVersion
-from typing import Dict, List, Optional, Tuple, Union
-
-
-from torch.cuda.amp import autocast
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-
-from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.register import tables
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.scorers.length_bonus import LengthBonus
+from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
-
-class BATModel(nn.Module):
- """BATModel module definition.
-
- Args:
- vocab_size: Size of complete vocabulary (w/ EOS and blank included).
- token_list: List of token
- frontend: Frontend module.
- specaug: SpecAugment module.
- normalize: Normalization module.
- encoder: Encoder module.
- decoder: Decoder module.
- joint_network: Joint Network module.
- transducer_weight: Weight of the Transducer loss.
- fastemit_lambda: FastEmit lambda value.
- auxiliary_ctc_weight: Weight of auxiliary CTC loss.
- auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
- auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
- auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
- ignore_id: Initial padding ID.
- sym_space: Space symbol.
- sym_blank: Blank Symbol
- report_cer: Whether to report Character Error Rate during validation.
- report_wer: Whether to report Word Error Rate during validation.
- extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
-
- """
-
+@tables.register("model_classes", "BAT") # TODO: BAT training
+class BAT(torch.nn.Module):
def __init__(
self,
-
- cif_weight: float = 1.0,
+ frontend: Optional[str] = None,
+ frontend_conf: Optional[Dict] = None,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ decoder: str = None,
+ decoder_conf: Optional[Dict] = None,
+ joint_network: str = None,
+ joint_network_conf: Optional[Dict] = None,
+ transducer_weight: float = 1.0,
fastemit_lambda: float = 0.0,
auxiliary_ctc_weight: float = 0.0,
auxiliary_ctc_dropout_rate: float = 0.0,
auxiliary_lm_loss_weight: float = 0.0,
auxiliary_lm_loss_smoothing: float = 0.0,
+ input_size: int = 80,
+ vocab_size: int = -1,
ignore_id: int = -1,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- report_cer: bool = True,
- report_wer: bool = True,
- extract_feats_in_collect_stats: bool = True,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
- r_d: int = 5,
- r_u: int = 5,
+ # report_cer: bool = True,
+ # report_wer: bool = True,
+ # sym_space: str = "<space>",
+ # sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
- ) -> None:
- """Construct an BATModel object."""
+ ):
+
super().__init__()
- # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
- self.blank_id = 0
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ decoder_class = tables.decoder_classes.get(decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ **decoder_conf,
+ )
+ decoder_output_size = decoder.output_size
+
+ joint_network_class = tables.joint_network_classes.get(joint_network)
+ joint_network = joint_network_class(
+ vocab_size,
+ encoder_output_size,
+ decoder_output_size,
+ **joint_network_conf,
+ )
+
+ self.criterion_transducer = None
+ self.error_calculator = None
+
+ self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+ self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+ if self.use_auxiliary_ctc:
+ self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
+ self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+ if self.use_auxiliary_lm_loss:
+ self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+ self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+ self.transducer_weight = transducer_weight
+ self.fastemit_lambda = fastemit_lambda
+
+ self.auxiliary_ctc_weight = auxiliary_ctc_weight
+ self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
- self.token_list = token_list.copy()
-
- self.sym_space = sym_space
- self.sym_blank = sym_blank
-
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
-
self.encoder = encoder
self.decoder = decoder
self.joint_network = joint_network
- self.criterion_transducer = None
- self.error_calculator = None
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
- self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
- self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
-
- if self.use_auxiliary_ctc:
- self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
- self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
-
- if self.use_auxiliary_lm_loss:
- self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
- self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
-
- self.transducer_weight = transducer_weight
- self.fastemit_lambda = fastemit_lambda
-
- self.auxiliary_ctc_weight = auxiliary_ctc_weight
- self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
-
- self.report_cer = report_cer
- self.report_wer = report_wer
-
- self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
-
- self.criterion_pre = torch.nn.L1Loss()
- self.predictor_weight = predictor_weight
- self.predictor = predictor
-
- self.cif_weight = cif_weight
- if self.cif_weight > 0:
- self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size)
- self.criterion_cif = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- self.r_d = r_d
- self.r_u = r_u
-
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+ self.ctc = None
+ self.ctc_weight = 0.0
+
def forward(
self,
speech: torch.Tensor,
@@ -142,111 +150,167 @@
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Forward architecture and compute loss(es).
-
+ """Encoder + Decoder + Calc loss
Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
- text: Label ID sequences. (B, L)
- text_lengths: Label ID sequences lengths. (B,)
- kwargs: Contains "utts_id".
-
- Return:
- loss: Main loss value.
- stats: Task statistics.
- weight: Task weights.
-
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
"""
- assert text_lengths.dim() == 1, text_lengths.shape
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
batch_size = speech.shape[0]
- text = text[:, : text_lengths.max()]
-
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
chunk_outs=None)
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
# 2. Transducer-related I/O preparation
decoder_in, target, t_len, u_len = get_transducer_task_io(
text,
encoder_out_lens,
ignore_id=self.ignore_id,
)
-
+
# 3. Decoder
self.decoder.set_device(encoder_out.device)
decoder_out = self.decoder(decoder_in, u_len)
-
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id)
- loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length)
-
- if self.cif_weight > 0.0:
- cif_predict = self.cif_output_layer(pre_acoustic_embeds)
- loss_cif = self.criterion_cif(cif_predict, text)
- else:
- loss_cif = 0.0
-
- # 5. Losses
- boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device)
- boundary[:, 2] = u_len.long().detach()
- boundary[:, 3] = t_len.long().detach()
-
- pre_peak_index = torch.floor(pre_peak_index).long()
- s_begin = pre_peak_index - self.r_d
-
- T = encoder_out.size(1)
- B = encoder_out.size(0)
- U = decoder_out.size(1)
-
- mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T)
- mask = mask <= boundary[:, 3].reshape(B, 1) - 1
-
- s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
- # handle the cases where `len(symbols) < s_range`
- s_begin_padding = torch.clamp(s_begin_padding, min=0)
-
- s_begin = torch.where(mask, s_begin, s_begin_padding)
- mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
-
- s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1)
-
- s_begin = torch.clamp(s_begin, min=0)
-
- ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device)
-
- import fast_rnnt
- am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
- am=self.joint_network.lin_enc(encoder_out),
- lm=self.joint_network.lin_dec(decoder_out),
- ranges=ranges,
+ # 4. Joint Network
+ joint_out = self.joint_network(
+ encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
-
- logits = self.joint_network(am_pruned, lm_pruned, project_input=False)
-
- with torch.cuda.amp.autocast(enabled=False):
- loss_trans = fast_rnnt.rnnt_loss_pruned(
- logits=logits.float(),
- symbols=target.long(),
- ranges=ranges,
- termination_symbol=self.blank_id,
- boundary=boundary,
- reduction="sum",
+
+ # 5. Losses
+ loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
+ encoder_out,
+ joint_out,
+ target,
+ t_len,
+ u_len,
+ )
+
+ loss_ctc, loss_lm = 0.0, 0.0
+
+ if self.use_auxiliary_ctc:
+ loss_ctc = self._calc_ctc_loss(
+ encoder_out,
+ target,
+ t_len,
+ u_len,
)
+
+ if self.use_auxiliary_lm_loss:
+ loss_lm = self._calc_lm_loss(decoder_out, target)
+
+ loss = (
+ self.transducer_weight * loss_trans
+ + self.auxiliary_ctc_weight * loss_ctc
+ + self.auxiliary_lm_loss_weight * loss_lm
+ )
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_transducer=loss_trans.detach(),
+ aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+ aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+ cer_transducer=cer_trans,
+ wer_transducer=wer_trans,
+ )
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+ return loss, stats, weight
- cer_trans, wer_trans = None, None
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def _calc_transducer_loss(
+ self,
+ encoder_out: torch.Tensor,
+ joint_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+ """Compute Transducer loss.
+
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ joint_out: Joint Network output sequences (B, T, U, D_joint)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
+
+ Return:
+ loss_transducer: Transducer loss value.
+ cer_transducer: Character error rate for Transducer.
+ wer_transducer: Word Error Rate for Transducer.
+
+ """
+ if self.criterion_transducer is None:
+ try:
+ from warp_rnnt import rnnt_loss as RNNTLoss
+ self.criterion_transducer = RNNTLoss
+
+ except ImportError:
+ logging.error(
+ "warp-rnnt was not installed."
+ "Please consult the installation documentation."
+ )
+ exit(1)
+
+ log_probs = torch.log_softmax(joint_out, dim=-1)
+
+ loss_transducer = self.criterion_transducer(
+ log_probs,
+ target,
+ t_len,
+ u_len,
+ reduction="mean",
+ blank=self.blank_id,
+ fastemit_lambda=self.fastemit_lambda,
+ gather=True,
+ )
+
if not self.training and (self.report_cer or self.report_wer):
if self.error_calculator is None:
from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
+
self.error_calculator = ErrorCalculator(
self.decoder,
self.joint_network,
@@ -256,149 +320,13 @@
report_cer=self.report_cer,
report_wer=self.report_wer,
)
- cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len)
-
- loss_ctc, loss_lm = 0.0, 0.0
-
- if self.use_auxiliary_ctc:
- loss_ctc = self._calc_ctc_loss(
- encoder_out,
- target,
- t_len,
- u_len,
- )
-
- if self.use_auxiliary_lm_loss:
- loss_lm = self._calc_lm_loss(decoder_out, target)
-
- loss = (
- self.transducer_weight * loss_trans
- + self.auxiliary_ctc_weight * loss_ctc
- + self.auxiliary_lm_loss_weight * loss_lm
- + self.predictor_weight * loss_pre
- + self.cif_weight * loss_cif
- )
-
- stats = dict(
- loss=loss.detach(),
- loss_transducer=loss_trans.detach(),
- loss_pre=loss_pre.detach(),
- loss_cif=loss_cif.detach() if loss_cif > 0.0 else None,
- aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
- aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
- cer_transducer=cer_trans,
- wer_transducer=wer_trans,
- )
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
- return loss, stats, weight
-
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Dict[str, torch.Tensor]:
- """Collect features sequences and features lengths sequences.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
- text: Label ID sequences. (B, L)
- text_lengths: Label ID sequences lengths. (B,)
- kwargs: Contains "utts_id".
-
- Return:
- {}: "feats": Features sequences. (B, T, D_feats),
- "feats_lengths": Features sequences lengths. (B,)
-
- """
- if self.extract_feats_in_collect_stats:
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
- else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
-
- feats, feats_lengths = speech, speech_lengths
-
- return {"feats": feats, "feats_lengths": feats_lengths}
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encoder speech sequences.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
-
- Return:
- encoder_out: Encoder outputs. (B, T, D_enc)
- encoder_out_lens: Encoder outputs lengths. (B,)
-
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(feats, feats_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # 4. Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
-
- assert encoder_out.size(0) == speech.size(0), (
- encoder_out.size(),
- speech.size(0),
- )
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
- encoder_out.size(),
- encoder_out_lens.max(),
- )
-
- return encoder_out, encoder_out_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Extract features sequences and features sequences lengths.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
-
- Return:
- feats: Features sequences. (B, T, D_feats)
- feats_lengths: Features sequences lengths. (B,)
-
- """
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
-
- if self.frontend is not None:
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- feats, feats_lengths = speech, speech_lengths
-
- return feats, feats_lengths
-
+
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
+
+ return loss_transducer, cer_transducer, wer_transducer
+
+ return loss_transducer, None, None
+
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
@@ -422,10 +350,10 @@
torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
)
ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
-
+
target_mask = target != 0
ctc_target = target[target_mask].cpu()
-
+
with torch.backends.cudnn.flags(deterministic=True):
loss_ctc = torch.nn.functional.ctc_loss(
ctc_in,
@@ -436,9 +364,9 @@
reduction="sum",
)
loss_ctc /= target.size(0)
-
+
return loss_ctc
-
+
def _calc_lm_loss(
self,
decoder_out: torch.Tensor,
@@ -456,17 +384,17 @@
"""
lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
lm_target = target.view(-1).type(torch.int64)
-
+
with torch.no_grad():
true_dist = lm_loss_in.clone()
true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
-
+
# Ignore blank ID (0)
ignore = lm_target == 0
lm_target = lm_target.masked_fill(ignore, 0)
-
+
true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
-
+
loss_lm = torch.nn.functional.kl_div(
torch.log_softmax(lm_loss_in, dim=1),
true_dist,
@@ -475,5 +403,117 @@
loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
0
)
-
+
return loss_lm
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+
+ # 1. Build ASR model
+ scorers = {}
+
+ if self.ctc != None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ beam_search = BeamSearchTransducer(
+ self.decoder,
+ self.joint_network,
+ kwargs.get("beam_size", 2),
+ nbest=1,
+ )
+ # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ # for scorer in scorers.values():
+ # if isinstance(scorer, torch.nn.Module):
+ # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ self.beam_search = beam_search
+
+ def inference(self,
+ data_in: list,
+ data_lengths: list=None,
+ key: list=None,
+ tokenizer=None,
+ **kwargs,
+ ):
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+ is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ # if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ b, n, d = encoder_out.size()
+ for i in range(b):
+
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq#[1:last_pos]
+ else:
+ token_int = hyp.yseq#[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+ return results, meta_data
+
--
Gitblit v1.9.1