From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 12 Apr 2023 16:49:56 +0800
Subject: [PATCH] Reorganize RNN-T (transducer) modules into funasr/models and funasr/modules
---
funasr/modules/nets_utils.py | 195 ++++++++++++++
funasr/bin/asr_inference_rnnt.py | 58 ----
funasr/models/rnnt_decoder/__init__.py | 0
funasr/models/encoder/chunk_encoder_utils/validation.py | 2
funasr/models/encoder/chunk_encoder_blocks/conv1d.py | 0
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 32 +-
funasr/models/encoder/chunk_encoder_blocks/conformer.py | 0
funasr/modules/activation.py | 0
funasr/models/e2e_transducer.py | 10
funasr/models/joint_network.py | 2
funasr/modules/e2e_asr_common.py | 151 +++++++++++
funasr/models/encoder/chunk_encoder_blocks/linear_input.py | 0
funasr/models/encoder/chunk_encoder_modules/positional_encoding.py | 0
funasr/models/encoder/chunk_encoder_blocks/conv_input.py | 2
funasr/models/rnnt_decoder/abs_decoder.py | 0
funasr/models/encoder/chunk_encoder.py | 26 -
funasr/models/encoder/chunk_encoder_modules/__init__.py | 0
funasr/models/e2e_transducer_unified.py | 13
funasr/models/encoder/chunk_encoder_modules/attention.py | 0
funasr/models/encoder/chunk_encoder_modules/normalization.py | 0
funasr/models/rnnt_decoder/rnn_decoder.py | 4
funasr/models/encoder/chunk_encoder_modules/multi_blocks.py | 0
funasr/models/encoder/chunk_encoder_blocks/branchformer.py | 0
funasr/models/encoder/chunk_encoder_blocks/__init__.py | 0
funasr/models/encoder/chunk_encoder_modules/convolution.py | 0
/dev/null | 200 ---------------
funasr/models/rnnt_decoder/stateless_decoder.py | 16 -
funasr/modules/beam_search/beam_search_transducer.py | 4
funasr/tasks/asr_transducer.py | 41 --
funasr/models/encoder/chunk_encoder_utils/building.py | 22
30 files changed, 418 insertions(+), 360 deletions(-)
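
# Note: summary of the module relocation performed by this patch (derived
# from the rename entries in the diffs below); handy when updating
# downstream imports. Representative, not exhaustive:
OLD_TO_NEW_MODULES = {
    "funasr.models_transducer.encoder.encoder": "funasr.models.encoder.chunk_encoder",
    "funasr.models_transducer.encoder.blocks": "funasr.models.encoder.chunk_encoder_blocks",
    "funasr.models_transducer.encoder.modules": "funasr.models.encoder.chunk_encoder_modules",
    "funasr.models_transducer.decoder": "funasr.models.rnnt_decoder",
    "funasr.models_transducer.joint_network": "funasr.models.joint_network",
    "funasr.models_transducer.utils": "funasr.modules.nets_utils",
    "funasr.models_transducer.activation": "funasr.modules.activation",
    "funasr.models_transducer.beam_search_transducer": "funasr.modules.beam_search.beam_search_transducer",
    "funasr.models_transducer.error_calculator": "funasr.modules.e2e_asr_common",
}
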
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index ef37b97..60f796c 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -1,13 +1,13 @@
encoder_conf:
main_conf:
pos_wise_act_type: swish
- pos_enc_dropout_rate: 0.3
+ pos_enc_dropout_rate: 0.5
conv_mod_act_type: swish
time_reduction_factor: 2
unified_model_training: true
default_chunk_size: 16
jitter_range: 4
- left_chunk_size: 1
+ left_chunk_size: 0
input_conf:
block_type: conv2d
conv_size: 512
@@ -18,9 +18,9 @@
linear_size: 2048
hidden_size: 512
heads: 8
- dropout_rate: 0.3
- pos_wise_dropout_rate: 0.3
- att_dropout_rate: 0.3
+ dropout_rate: 0.5
+ pos_wise_dropout_rate: 0.5
+ att_dropout_rate: 0.5
conv_mod_kernel_size: 15
num_blocks: 12
@@ -29,8 +29,8 @@
decoder_conf:
embed_size: 512
hidden_size: 512
- embed_dropout_rate: 0.2
- dropout_rate: 0.1
+ embed_dropout_rate: 0.5
+ dropout_rate: 0.5
joint_network_conf:
joint_space_size: 512
@@ -41,14 +41,14 @@
# minibatch related
use_amp: true
-batch_type: numel
-batch_bins: 1600000
+batch_type: unsorted
+batch_size: 16
num_workers: 16
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 80
+max_epoch: 200
val_scheduler_criterion:
- valid
- loss
@@ -56,11 +56,11 @@
- - valid
- cer_transducer_chunk
- min
-keep_nbest_models: 5
+keep_nbest_models: 10
optim: adam
optim_conf:
- lr: 0.0003
+ lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
@@ -75,10 +75,12 @@
apply_freq_mask: true
freq_mask_width_range:
- 0
- - 30
+ - 40
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- - 40
- num_time_mask: 2
+ - 50
+ num_time_mask: 5
+
+log_interval: 50
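
# Note: the unified-training knobs above (default_chunk_size: 16,
# jitter_range: 4) are consumed inside ChunkEncoder.forward. A minimal
# sketch of chunk-size jittering, assuming a uniform draw around the
# default (the exact sampling scheme is not shown in this patch):
import torch

def sample_training_chunk_size(default_chunk_size: int = 16, jitter_range: int = 4) -> int:
    # Vary the chunk size per batch so the streaming branch does not
    # overfit to a single latency setting.
    offset = int(torch.randint(-jitter_range, jitter_range + 1, (1,)).item())
    return max(1, default_chunk_size + offset)
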
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 768bf72..465f882 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -16,11 +16,11 @@
from packaging.version import parse as V
from typeguard import check_argument_types, check_return_type
-from funasr.models_transducer.beam_search_transducer import (
+from funasr.modules.beam_search.beam_search_transducer import (
BeamSearchTransducer,
Hypothesis,
)
-from funasr.models_transducer.utils import TooShortUttError
+from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.asr_transducer import ASRTransducerTask
from funasr.tasks.lm import LMTask
@@ -500,7 +500,6 @@
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-<<<<<<< HEAD
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
assert len(batch.keys()) == 1
@@ -541,59 +540,6 @@
if text is not None:
ibest_writer["text"][key] = text
-=======
- # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
- logging.info("decoding, utt_id: {}".format(keys))
- # N-best list of (text, token, token_int, hyp_object)
-
- time_beg = time.time()
- results = speech2text(cache=cache, **batch)
- if len(results) < 1:
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
- time_end = time.time()
- forward_time = time_end - time_beg
- lfr_factor = results[0][-1]
- length = results[0][-2]
- forward_time_total += forward_time
- length_total += length
- rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
- logging.info(rtf_cur)
-
- for batch_id in range(_bs):
- result = [results[batch_id][:-2]]
-
- key = keys[batch_id]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["rtf"][key] = rtf_cur
-
- if text is not None:
- text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text))
- rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
- logging.info(rtf_avg)
- if writer is not None:
- ibest_writer["rtf"]["rtf_avf"] = rtf_avg
- return asr_result_list
-
- return _forward
->>>>>>> main
def get_parser():
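
# Note: the branch deleted above logged real-time factor (RTF) as
# 100 * forward_time / (length * lfr_factor). An equivalent sketch,
# assuming 10 ms base frames (100 frames/second) stacked by an LFR
# factor; the default of 6 mirrors the placeholder result tuple in the
# removed code:
def real_time_factor(forward_time: float, num_frames: int, lfr_factor: int = 6) -> float:
    audio_seconds = (num_frames * lfr_factor) / 100.0  # 10 ms per base frame
    return forward_time / audio_seconds
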
diff --git a/funasr/models_transducer/espnet_transducer_model.py b/funasr/models/e2e_transducer.py
similarity index 97%
rename from funasr/models_transducer/espnet_transducer_model.py
rename to funasr/models/e2e_transducer.py
index e32f6e3..b669c9d 100644
--- a/funasr/models_transducer/espnet_transducer_model.py
+++ b/funasr/models/e2e_transducer.py
@@ -10,11 +10,11 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -28,7 +28,7 @@
yield
-class ESPnetASRTransducerModel(AbsESPnetModel):
+class TransducerModel(AbsESPnetModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models/e2e_transducer_unified.py
similarity index 97%
rename from funasr/models_transducer/espnet_transducer_model_unified.py
rename to funasr/models/e2e_transducer_unified.py
index be61e83..6003542 100644
--- a/funasr/models_transducer/espnet_transducer_model_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -10,10 +10,10 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -23,7 +23,7 @@
from funasr.losses.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
-from funasr.models_transducer.error_calculator import ErrorCalculator
+from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -33,7 +33,7 @@
yield
-class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
+class UnifiedTransducerModel(AbsESPnetModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
@@ -289,7 +289,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
return loss, stats, weight
def collect_feats(
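
# Note: both transducer models now import get_transducer_task_io from
# funasr.modules.nets_utils. A hedged sketch of what it prepares for the
# RNN-T loss (the blank-prefix convention follows the usual ESPnet-style
# implementation; the funasr version may differ in detail):
import torch

def transducer_task_io_sketch(labels, enc_out_lens, ignore_id=-1, blank_id=0):
    labels_unpad = [y[y != ignore_id] for y in labels]  # strip padding ids
    blank = labels.new_full((1,), blank_id)
    # Decoder consumes <blank>-prefixed labels; loss targets are unprefixed.
    decoder_in = torch.nn.utils.rnn.pad_sequence(
        [torch.cat([blank, y]) for y in labels_unpad],
        batch_first=True, padding_value=blank_id)
    target = torch.nn.utils.rnn.pad_sequence(
        labels_unpad, batch_first=True, padding_value=blank_id).int()
    t_len = enc_out_lens.int()                                     # encoder frames
    u_len = torch.tensor([y.size(0) for y in labels_unpad]).int()  # label lengths
    return decoder_in, target, t_len, u_len
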
diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models/encoder/chunk_encoder.py
similarity index 95%
rename from funasr/models_transducer/encoder/encoder.py
rename to funasr/models/encoder/chunk_encoder.py
index b486a11..c6fc292 100644
--- a/funasr/models_transducer/encoder/encoder.py
+++ b/funasr/models/encoder/chunk_encoder.py
@@ -1,26 +1,23 @@
-"""Encoder for Transducer model."""
-
from typing import Any, Dict, List, Tuple
import torch
from typeguard import check_argument_types
-from funasr.models_transducer.encoder.building import (
+from funasr.models.encoder.chunk_encoder_utils.building import (
build_body_blocks,
build_input_block,
build_main_parameters,
build_positional_encoding,
)
-from funasr.models_transducer.encoder.validation import validate_architecture
-from funasr.models_transducer.utils import (
+from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
+from funasr.modules.nets_utils import (
TooShortUttError,
check_short_utt,
make_chunk_mask,
make_source_mask,
)
-
-class Encoder(torch.nn.Module):
+class ChunkEncoder(torch.nn.Module):
"""Encoder module definition.
Args:
@@ -61,10 +58,9 @@
self.unified_model_training = main_params["unified_model_training"]
self.default_chunk_size = main_params["default_chunk_size"]
- self.jitter_range = main_params["jitter_range"]
+ self.jitter_range = main_params["jitter_range"]
- self.time_reduction_factor = main_params["time_reduction_factor"]
-
+ self.time_reduction_factor = main_params["time_reduction_factor"]
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@@ -79,7 +75,7 @@
"""
return self.embed.get_size_before_subsampling(size) * hop_length
-
+
def get_encoder_input_size(self, size: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@@ -157,7 +153,7 @@
mask,
chunk_mask=chunk_mask,
)
-
+
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
@@ -194,14 +190,14 @@
mask,
chunk_mask=chunk_mask,
)
-
+
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x, olens
-
+
def simu_chunk_forward(
self,
x: torch.Tensor,
@@ -290,7 +286,7 @@
if right_context > 0:
x = x[:, 0:-right_context, :]
-
+
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x
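
# Note: the time_reduction_factor logic repeated in this file subsamples
# encoder frames and recomputes the valid lengths. A self-contained sketch
# of the same arithmetic (torch.div with floor rounding matches the
# torch.floor_divide call used above):
import torch

def time_reduce(x, olens, factor):
    if factor > 1:
        x = x[:, ::factor, :]  # keep every `factor`-th frame
        olens = torch.div(olens - 1, factor, rounding_mode="floor") + 1
    return x, olens

# e.g. 10 valid frames, factor 2 -> frames 0,2,4,6,8 kept, olens == 5
x, olens = time_reduce(torch.zeros(1, 10, 4), torch.tensor([10]), 2)
assert x.size(1) == 5 and olens.item() == 5
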
diff --git a/funasr/models_transducer/encoder/blocks/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/__init__.py
rename to funasr/models/encoder/chunk_encoder_blocks/__init__.py
diff --git a/funasr/models_transducer/encoder/blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/branchformer.py
rename to funasr/models/encoder/chunk_encoder_blocks/branchformer.py
diff --git a/funasr/models_transducer/encoder/blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/conformer.py
rename to funasr/models/encoder/chunk_encoder_blocks/conformer.py
diff --git a/funasr/models_transducer/encoder/blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/conv1d.py
rename to funasr/models/encoder/chunk_encoder_blocks/conv1d.py
diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
similarity index 98%
rename from funasr/models_transducer/encoder/blocks/conv_input.py
rename to funasr/models/encoder/chunk_encoder_blocks/conv_input.py
index ffec93e..b9bd2fd 100644
--- a/funasr/models_transducer/encoder/blocks/conv_input.py
+++ b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
@@ -5,7 +5,7 @@
import torch
import math
-from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
class ConvInput(torch.nn.Module):
diff --git a/funasr/models_transducer/encoder/blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/linear_input.py
rename to funasr/models/encoder/chunk_encoder_blocks/linear_input.py
diff --git a/funasr/models_transducer/encoder/modules/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/__init__.py
rename to funasr/models/encoder/chunk_encoder_modules/__init__.py
diff --git a/funasr/models_transducer/encoder/modules/attention.py b/funasr/models/encoder/chunk_encoder_modules/attention.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/attention.py
rename to funasr/models/encoder/chunk_encoder_modules/attention.py
diff --git a/funasr/models_transducer/encoder/modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/convolution.py
rename to funasr/models/encoder/chunk_encoder_modules/convolution.py
diff --git a/funasr/models_transducer/encoder/modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/multi_blocks.py
rename to funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
diff --git a/funasr/models_transducer/encoder/modules/normalization.py b/funasr/models/encoder/chunk_encoder_modules/normalization.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/normalization.py
rename to funasr/models/encoder/chunk_encoder_modules/normalization.py
diff --git a/funasr/models_transducer/encoder/modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/positional_encoding.py
rename to funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
diff --git a/funasr/models_transducer/encoder/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py
similarity index 91%
rename from funasr/models_transducer/encoder/building.py
rename to funasr/models/encoder/chunk_encoder_utils/building.py
index a19943b..21611aa 100644
--- a/funasr/models_transducer/encoder/building.py
+++ b/funasr/models/encoder/chunk_encoder_utils/building.py
@@ -2,22 +2,22 @@
from typing import Any, Dict, List, Optional, Union
-from funasr.models_transducer.activation import get_activation
-from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
-from funasr.models_transducer.encoder.blocks.conformer import Conformer
-from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
-from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
-from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
-from funasr.models_transducer.encoder.modules.attention import ( # noqa: H301
+from funasr.modules.activation import get_activation
+from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
+from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
+from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
+from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
+from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
+from funasr.models.encoder.chunk_encoder_modules.attention import ( # noqa: H301
RelPositionMultiHeadedAttention,
)
-from funasr.models_transducer.encoder.modules.convolution import ( # noqa: H301
+from funasr.models.encoder.chunk_encoder_modules.convolution import ( # noqa: H301
ConformerConvolution,
ConvolutionalSpatialGatingUnit,
)
-from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
-from funasr.models_transducer.encoder.modules.normalization import get_normalization
-from funasr.models_transducer.encoder.modules.positional_encoding import ( # noqa: H301
+from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
+from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
+from funasr.models.encoder.chunk_encoder_modules.positional_encoding import ( # noqa: H301
RelPositionalEncoding,
)
from funasr.modules.positionwise_feed_forward import (
diff --git a/funasr/models_transducer/encoder/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py
similarity index 98%
rename from funasr/models_transducer/encoder/validation.py
rename to funasr/models/encoder/chunk_encoder_utils/validation.py
index 0003536..1103cb9 100644
--- a/funasr/models_transducer/encoder/validation.py
+++ b/funasr/models/encoder/chunk_encoder_utils/validation.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple
-from funasr.models_transducer.utils import sub_factor_to_params
+from funasr.modules.nets_utils import sub_factor_to_params
def validate_block_arguments(
diff --git a/funasr/models_transducer/joint_network.py b/funasr/models/joint_network.py
similarity index 96%
rename from funasr/models_transducer/joint_network.py
rename to funasr/models/joint_network.py
index 119dd84..5cabdb4 100644
--- a/funasr/models_transducer/joint_network.py
+++ b/funasr/models/joint_network.py
@@ -2,7 +2,7 @@
import torch
-from funasr.models_transducer.activation import get_activation
+from funasr.modules.activation import get_activation
class JointNetwork(torch.nn.Module):
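
# Note: JointNetwork fuses encoder frames with prediction-network states.
# A hedged sketch of the standard RNN-T joint (the 512-dim joint space
# mirrors joint_space_size in the config above; the funasr implementation
# may differ, e.g. in its configurable activation via get_activation):
import torch

class JointNetworkSketch(torch.nn.Module):
    def __init__(self, vocab_size, enc_size, dec_size, joint_size=512):
        super().__init__()
        self.lin_enc = torch.nn.Linear(enc_size, joint_size)
        self.lin_dec = torch.nn.Linear(dec_size, joint_size)
        self.lin_out = torch.nn.Linear(joint_size, vocab_size)

    def forward(self, enc_out, dec_out):
        # (B, T, 1, J) + (B, 1, U, J) -> (B, T, U, vocab) by broadcasting
        joint = torch.tanh(self.lin_enc(enc_out).unsqueeze(2)
                           + self.lin_dec(dec_out).unsqueeze(1))
        return self.lin_out(joint)
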
diff --git a/funasr/models_transducer/decoder/__init__.py b/funasr/models/rnnt_decoder/__init__.py
similarity index 100%
rename from funasr/models_transducer/decoder/__init__.py
rename to funasr/models/rnnt_decoder/__init__.py
diff --git a/funasr/models_transducer/decoder/abs_decoder.py b/funasr/models/rnnt_decoder/abs_decoder.py
similarity index 100%
rename from funasr/models_transducer/decoder/abs_decoder.py
rename to funasr/models/rnnt_decoder/abs_decoder.py
diff --git a/funasr/models_transducer/decoder/rnn_decoder.py b/funasr/models/rnnt_decoder/rnn_decoder.py
similarity index 97%
rename from funasr/models_transducer/decoder/rnn_decoder.py
rename to funasr/models/rnnt_decoder/rnn_decoder.py
index 04c3228..c4e7951 100644
--- a/funasr/models_transducer/decoder/rnn_decoder.py
+++ b/funasr/models/rnnt_decoder/rnn_decoder.py
@@ -5,8 +5,8 @@
import torch
from typeguard import check_argument_types
-from funasr.models_transducer.beam_search_transducer import Hypothesis
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class RNNDecoder(AbsDecoder):
diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models/rnnt_decoder/stateless_decoder.py
similarity index 86%
rename from funasr/models_transducer/decoder/stateless_decoder.py
rename to funasr/models/rnnt_decoder/stateless_decoder.py
index 07c8f51..a2e1fc1 100644
--- a/funasr/models_transducer/decoder/stateless_decoder.py
+++ b/funasr/models/rnnt_decoder/stateless_decoder.py
@@ -5,8 +5,8 @@
import torch
from typeguard import check_argument_types
-from funasr.models_transducer.beam_search_transducer import Hypothesis
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@
embed_size: int = 256,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
- use_embed_mask: bool = False,
) -> None:
"""Construct a StatelessDecoder object."""
super().__init__()
@@ -42,14 +41,6 @@
self.device = next(self.parameters()).device
self.score_cache = {}
- self.use_embed_mask = use_embed_mask
- if self.use_embed_mask:
- self._embed_mask = SpecAug(
- time_mask_width_range=3,
- num_time_mask=1,
- apply_freq_mask=False,
- apply_time_warp=False
- )
def forward(
@@ -69,9 +60,6 @@
"""
dec_embed = self.embed_dropout_rate(self.embed(labels))
- if self.use_embed_mask and self.training:
- dec_embed = self._embed_mask(dec_embed, label_lens)[0]
-
return dec_embed
def score(
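
# Note: after this patch StatelessDecoder is a pure embedding lookup with
# dropout (no recurrent state, and the optional embedding mask is removed).
# A minimal sketch of the forward pass that remains:
import torch

class StatelessDecoderSketch(torch.nn.Module):
    def __init__(self, vocab_size, embed_size=256, embed_dropout_rate=0.0, embed_pad=0):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout = torch.nn.Dropout(p=embed_dropout_rate)

    def forward(self, labels):
        # (B, U) label ids -> (B, U, embed_size) prediction-network states
        return self.dropout(self.embed(labels))
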
diff --git a/funasr/models_transducer/__init__.py b/funasr/models_transducer/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models_transducer/__init__.py
+++ /dev/null
diff --git a/funasr/models_transducer/encoder/__init__.py b/funasr/models_transducer/encoder/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models_transducer/encoder/__init__.py
+++ /dev/null
diff --git a/funasr/models_transducer/encoder/sanm_encoder.py b/funasr/models_transducer/encoder/sanm_encoder.py
deleted file mode 100644
index 9e74bdf..0000000
--- a/funasr/models_transducer/encoder/sanm_encoder.py
+++ /dev/null
@@ -1,835 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
-import numpy as np
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.multi_layer_conv import Conv1dLinear
-from funasr.modules.multi_layer_conv import MultiLayeredConv1d
-from funasr.modules.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.subsampling import Conv2dSubsampling
-from funasr.modules.subsampling import Conv2dSubsampling2
-from funasr.modules.subsampling import Conv2dSubsampling6
-from funasr.modules.subsampling import Conv2dSubsampling8
-from funasr.modules.subsampling import TooShortUttError
-from funasr.modules.subsampling import check_short_utt
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-
-class EncoderLayerSANM(nn.Module):
- def __init__(
- self,
- in_size,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
- ):
- """Construct an EncoderLayer object."""
- super(EncoderLayerSANM, self).__init__()
- self.self_attn = self_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(in_size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.in_size = in_size
- self.size = size
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
- self.stochastic_depth_rate = stochastic_depth_rate
- self.dropout_rate = dropout_rate
-
- def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
- """Compute encoded features.
- Args:
- x_input (torch.Tensor): Input tensor (#batch, time, size).
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
- cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
- Returns:
- torch.Tensor: Output tensor (#batch, time, size).
- torch.Tensor: Mask tensor (#batch, time).
- """
- skip_layer = False
- # with stochastic depth, residual connection `x + f(x)` becomes
- # `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
- if self.training and self.stochastic_depth_rate > 0:
- skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
- stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
- if skip_layer:
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
- return x, mask
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- x = stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- else:
- x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
-
- return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-class SANMEncoder(AbsEncoder):
- """
- author: Speech Lab, Alibaba Group, China
- San-m: Memory equipped self-attention for end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
- ):
- assert check_argument_types()
- super().__init__()
-
- self.embed = SinusoidalPositionEncoder()
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- encoder_selfattn_layer = MultiHeadedAttentionSANM
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad = xs_pad * self.output_size**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- encoder_outs = self.encoders0(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens
-
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## encoder
- # cicd
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- }
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "encoders0":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- name_q = name_q.replace("encoders0", "encoders")
- layeridx_bias = 0
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "encoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 1
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
-
-class SANMEncoderChunkOpt(AbsEncoder):
- """
- author: Speech Lab, Alibaba Group, China
- SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size: int = 11,
- sanm_shfit: int = 0,
- chunk_size: Union[int, Sequence[int]] = (16,),
- stride: Union[int, Sequence[int]] = (10,),
- pad_left: Union[int, Sequence[int]] = (0,),
- time_reduction_factor: int = 1,
- encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
- ):
- assert check_argument_types()
- super().__init__()
- self.output_size = output_size
-
- self.embed = SinusoidalPositionEncoder()
-
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- encoder_selfattn_layer = MultiHeadedAttentionSANM
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks - 1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- shfit_fsmn = (kernel_size - 1) // 2
- self.overlap_chunk_cls = overlap_chunk(
- chunk_size=chunk_size,
- stride=stride,
- pad_left=pad_left,
- shfit_fsmn=shfit_fsmn,
- encoder_att_look_back_factor=encoder_att_look_back_factor,
- decoder_att_look_back_factor=decoder_att_look_back_factor,
- )
- self.time_reduction_factor = time_reduction_factor
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ind: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad *= self.output_size ** 0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- mask_shfit_chunk, mask_att_chunk_encoder = None, None
- if self.overlap_chunk_cls is not None:
- ilens = masks.squeeze(1).sum(1)
- chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
- xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
- dtype=xs_pad.dtype)
- mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
- xs_pad.size(0),
- dtype=xs_pad.dtype)
-
- encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
-
- xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None)
-
- if self.time_reduction_factor > 1:
- xs_pad = xs_pad[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens
-
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## encoder
- # cicd
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- }
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "encoders0":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- name_q = name_q.replace("encoders0", "encoders")
- layeridx_bias = 0
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "encoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 1
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
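
# Note: the deleted EncoderLayerSANM implemented stochastic depth: at
# training time each layer is skipped with probability p, and when kept its
# residual branch is rescaled by 1 / (1 - p) so the expected output is
# unchanged. A minimal sketch of that rule (cache handling omitted):
import torch

def stochastic_depth_residual(x, branch, p, training):
    if training and p > 0:
        if torch.rand(1).item() < p:       # skip the layer entirely
            return x
        return x + branch(x) / (1.0 - p)   # rescale the kept branch
    return x + branch(x)
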
diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py
deleted file mode 100644
index 34b1dc7..0000000
--- a/funasr/models_transducer/error_calculator.py
+++ /dev/null
@@ -1,169 +0,0 @@
-"""Error Calculator module for Transducer."""
-
-from typing import List, Optional, Tuple
-
-import torch
-
-from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.joint_network import JointNetwork
-
-
-class ErrorCalculator:
- """Calculate CER and WER for transducer models.
-
- Args:
- decoder: Decoder module.
- joint_network: Joint Network module.
- token_list: List of token units.
- sym_space: Space symbol.
- sym_blank: Blank symbol.
- report_cer: Whether to compute CER.
- report_wer: Whether to compute WER.
-
- """
-
- def __init__(
- self,
- decoder: AbsDecoder,
- joint_network: JointNetwork,
- token_list: List[int],
- sym_space: str,
- sym_blank: str,
- report_cer: bool = False,
- report_wer: bool = False,
- ) -> None:
- """Construct an ErrorCalculatorTransducer object."""
- super().__init__()
-
- self.beam_search = BeamSearchTransducer(
- decoder=decoder,
- joint_network=joint_network,
- beam_size=1,
- search_type="default",
- score_norm=False,
- )
-
- self.decoder = decoder
-
- self.token_list = token_list
- self.space = sym_space
- self.blank = sym_blank
-
- self.report_cer = report_cer
- self.report_wer = report_wer
-
- def __call__(
- self, encoder_out: torch.Tensor, target: torch.Tensor
- ) -> Tuple[Optional[float], Optional[float]]:
- """Calculate sentence-level WER or/and CER score for Transducer model.
-
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- target: Target label ID sequences. (B, L)
-
- Returns:
- : Sentence-level CER score.
- : Sentence-level WER score.
-
- """
- cer, wer = None, None
-
- batchsize = int(encoder_out.size(0))
-
- encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
-
- batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
- pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
-
- char_pred, char_target = self.convert_to_char(pred, target)
-
- if self.report_cer:
- cer = self.calculate_cer(char_pred, char_target)
-
- if self.report_wer:
- wer = self.calculate_wer(char_pred, char_target)
-
- return cer, wer
-
- def convert_to_char(
- self, pred: torch.Tensor, target: torch.Tensor
- ) -> Tuple[List, List]:
- """Convert label ID sequences to character sequences.
-
- Args:
- pred: Prediction label ID sequences. (B, U)
- target: Target label ID sequences. (B, L)
-
- Returns:
- char_pred: Prediction character sequences. (B, ?)
- char_target: Target character sequences. (B, ?)
-
- """
- char_pred, char_target = [], []
-
- for i, pred_i in enumerate(pred):
- char_pred_i = [self.token_list[int(h)] for h in pred_i]
- char_target_i = [self.token_list[int(r)] for r in target[i]]
-
- char_pred_i = "".join(char_pred_i).replace(self.space, " ")
- char_pred_i = char_pred_i.replace(self.blank, "")
-
- char_target_i = "".join(char_target_i).replace(self.space, " ")
- char_target_i = char_target_i.replace(self.blank, "")
-
- char_pred.append(char_pred_i)
- char_target.append(char_target_i)
-
- return char_pred, char_target
-
- def calculate_cer(
- self, char_pred: torch.Tensor, char_target: torch.Tensor
- ) -> float:
- """Calculate sentence-level CER score.
-
- Args:
- char_pred: Prediction character sequences. (B, ?)
- char_target: Target character sequences. (B, ?)
-
- Returns:
- : Average sentence-level CER score.
-
- """
- import editdistance
-
- distances, lens = [], []
-
- for i, char_pred_i in enumerate(char_pred):
- pred = char_pred_i.replace(" ", "")
- target = char_target[i].replace(" ", "")
- distances.append(editdistance.eval(pred, target))
- lens.append(len(target))
-
- return float(sum(distances)) / sum(lens)
-
- def calculate_wer(
- self, char_pred: torch.Tensor, char_target: torch.Tensor
- ) -> float:
- """Calculate sentence-level WER score.
-
- Args:
- char_pred: Prediction character sequences. (B, ?)
- char_target: Target character sequences. (B, ?)
-
- Returns:
- : Average sentence-level WER score
-
- """
- import editdistance
-
- distances, lens = [], []
-
- for i, char_pred_i in enumerate(char_pred):
- pred = char_pred_i.replace("鈻�", " ").split()
- target = char_target[i].replace("鈻�", " ").split()
-
- distances.append(editdistance.eval(pred, target))
- lens.append(len(target))
-
- return float(sum(distances)) / sum(lens)
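
# Note: the deleted ErrorCalculator is superseded by
# ErrorCalculatorTransducer from funasr.modules.e2e_asr_common (see the
# import change in e2e_transducer_unified.py above). Its corpus CER was
# total edit distance over total reference length, as in this sketch:
import editdistance

def corpus_cer(hyps, refs):
    dist = sum(editdistance.eval(h.replace(" ", ""), r.replace(" ", ""))
               for h, r in zip(hyps, refs))
    ref_len = sum(len(r.replace(" ", "")) for r in refs)
    return dist / ref_len
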
diff --git a/funasr/models_transducer/espnet_transducer_model_uni_asr.py b/funasr/models_transducer/espnet_transducer_model_uni_asr.py
deleted file mode 100644
index 2add3fa..0000000
--- a/funasr/models_transducer/espnet_transducer_model_uni_asr.py
+++ /dev/null
@@ -1,485 +0,0 @@
-"""ESPnet2 ASR Transducer model."""
-
-import logging
-from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from packaging.version import parse as V
-from typeguard import check_argument_types
-
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-if V(torch.__version__) >= V("1.6.0"):
- from torch.cuda.amp import autocast
-else:
-
- @contextmanager
- def autocast(enabled=True):
- yield
-
-
-class UniASRTransducerModel(AbsESPnetModel):
- """ESPnet2ASRTransducerModel module definition.
-
- Args:
- vocab_size: Size of complete vocabulary (w/ EOS and blank included).
- token_list: List of token
- frontend: Frontend module.
- specaug: SpecAugment module.
- normalize: Normalization module.
- encoder: Encoder module.
- decoder: Decoder module.
- joint_network: Joint Network module.
- transducer_weight: Weight of the Transducer loss.
- fastemit_lambda: FastEmit lambda value.
- auxiliary_ctc_weight: Weight of auxiliary CTC loss.
- auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
- auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
- auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
- ignore_id: Initial padding ID.
- sym_space: Space symbol.
- sym_blank: Blank Symbol
- report_cer: Whether to report Character Error Rate during validation.
- report_wer: Whether to report Word Error Rate during validation.
- extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
-
- """
-
- def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- encoder,
- decoder: AbsDecoder,
- att_decoder: Optional[AbsAttDecoder],
- joint_network: JointNetwork,
- transducer_weight: float = 1.0,
- fastemit_lambda: float = 0.0,
- auxiliary_ctc_weight: float = 0.0,
- auxiliary_ctc_dropout_rate: float = 0.0,
- auxiliary_lm_loss_weight: float = 0.0,
- auxiliary_lm_loss_smoothing: float = 0.0,
- ignore_id: int = -1,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- report_cer: bool = True,
- report_wer: bool = True,
- extract_feats_in_collect_stats: bool = True,
- ) -> None:
- """Construct an ESPnetASRTransducerModel object."""
- super().__init__()
-
- assert check_argument_types()
-
- # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
- self.blank_id = 0
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.token_list = token_list.copy()
-
- self.sym_space = sym_space
- self.sym_blank = sym_blank
-
- self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
-
- self.encoder = encoder
- self.decoder = decoder
- self.joint_network = joint_network
-
- self.criterion_transducer = None
- self.error_calculator = None
-
- self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
- self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
-
- if self.use_auxiliary_ctc:
- self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
- self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
-
- if self.use_auxiliary_lm_loss:
- self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
- self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
-
- self.transducer_weight = transducer_weight
- self.fastemit_lambda = fastemit_lambda
-
- self.auxiliary_ctc_weight = auxiliary_ctc_weight
- self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
-
- self.report_cer = report_cer
- self.report_wer = report_wer
-
- self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- decoding_ind: int = None,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Forward architecture and compute loss(es).
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
- text: Label ID sequences. (B, L)
- text_lengths: Label ID sequences lengths. (B,)
- kwargs: Contains "utts_id".
-
- Return:
- loss: Main loss value.
- stats: Task statistics.
- weight: Task weights.
-
- """
- assert text_lengths.dim() == 1, text_lengths.shape
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-
- batch_size = speech.shape[0]
- text = text[:, : text_lengths.max()]
-
- # 1. Encoder
- ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
- # 2. Transducer-related I/O preparation
- decoder_in, target, t_len, u_len = get_transducer_task_io(
- text,
- encoder_out_lens,
- ignore_id=self.ignore_id,
- )
-
- # 3. Decoder
- self.decoder.set_device(encoder_out.device)
- decoder_out = self.decoder(decoder_in, u_len)
-
- # 4. Joint Network
- joint_out = self.joint_network(
- encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
- )
-
- # 5. Losses
- loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
- encoder_out,
- joint_out,
- target,
- t_len,
- u_len,
- )
-
- loss_ctc, loss_lm = 0.0, 0.0
-
- if self.use_auxiliary_ctc:
- loss_ctc = self._calc_ctc_loss(
- encoder_out,
- target,
- t_len,
- u_len,
- )
-
- if self.use_auxiliary_lm_loss:
- loss_lm = self._calc_lm_loss(decoder_out, target)
-
- loss = (
- self.transducer_weight * loss_trans
- + self.auxiliary_ctc_weight * loss_ctc
- + self.auxiliary_lm_loss_weight * loss_lm
- )
-
- stats = dict(
- loss=loss.detach(),
- loss_transducer=loss_trans.detach(),
- aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
- aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
- cer_transducer=cer_trans,
- wer_transducer=wer_trans,
- )
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
- return loss, stats, weight
-
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Dict[str, torch.Tensor]:
- """Collect features sequences and features lengths sequences.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
- text: Label ID sequences. (B, L)
- text_lengths: Label ID sequences lengths. (B,)
- kwargs: Contains "utts_id".
-
- Return:
- {}: "feats": Features sequences. (B, T, D_feats),
- "feats_lengths": Features sequences lengths. (B,)
-
- """
- if self.extract_feats_in_collect_stats:
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
- else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
-
- feats, feats_lengths = speech, speech_lengths
-
- return {"feats": feats, "feats_lengths": feats_lengths}
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ind: int,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encoder speech sequences.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
-
- Return:
- encoder_out: Encoder outputs. (B, T, D_enc)
- encoder_out_lens: Encoder outputs lengths. (B,)
-
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(feats, feats_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # 4. Forward encoder
- encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind)
-
- assert encoder_out.size(0) == speech.size(0), (
- encoder_out.size(),
- speech.size(0),
- )
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
- encoder_out.size(),
- encoder_out_lens.max(),
- )
-
- return encoder_out, encoder_out_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Extract features sequences and features sequences lengths.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
-
- Return:
- feats: Features sequences. (B, T, D_feats)
- feats_lengths: Features sequences lengths. (B,)
-
- """
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
-
- if self.frontend is not None:
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- feats, feats_lengths = speech, speech_lengths
-
- return feats, feats_lengths
-
- def _calc_transducer_loss(
- self,
- encoder_out: torch.Tensor,
- joint_out: torch.Tensor,
- target: torch.Tensor,
- t_len: torch.Tensor,
- u_len: torch.Tensor,
- ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
- """Compute Transducer loss.
-
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- joint_out: Joint Network output sequences (B, T, U, D_joint)
- target: Target label ID sequences. (B, L)
- t_len: Encoder output sequences lengths. (B,)
- u_len: Target label ID sequences lengths. (B,)
-
- Return:
- loss_transducer: Transducer loss value.
- cer_transducer: Character error rate for Transducer.
- wer_transducer: Word Error Rate for Transducer.
-
- """
- if self.criterion_transducer is None:
- try:
- # from warprnnt_pytorch import RNNTLoss
- # self.criterion_transducer = RNNTLoss(
- # reduction="mean",
- # fastemit_lambda=self.fastemit_lambda,
- # )
- from warp_rnnt import rnnt_loss as RNNTLoss
- self.criterion_transducer = RNNTLoss
-
- except ImportError:
- logging.error(
- "warp-rnnt was not installed."
- "Please consult the installation documentation."
- )
- exit(1)
-
- # loss_transducer = self.criterion_transducer(
- # joint_out,
- # target,
- # t_len,
- # u_len,
- # )
- log_probs = torch.log_softmax(joint_out, dim=-1)
-
- loss_transducer = self.criterion_transducer(
- log_probs,
- target,
- t_len,
- u_len,
- reduction="mean",
- blank=self.blank_id,
- gather=True,
- )
-
- if not self.training and (self.report_cer or self.report_wer):
- if self.error_calculator is None:
- from espnet2.asr_transducer.error_calculator import ErrorCalculator
-
- self.error_calculator = ErrorCalculator(
- self.decoder,
- self.joint_network,
- self.token_list,
- self.sym_space,
- self.sym_blank,
- report_cer=self.report_cer,
- report_wer=self.report_wer,
- )
-
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
-
- return loss_transducer, cer_transducer, wer_transducer
-
- return loss_transducer, None, None
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- target: torch.Tensor,
- t_len: torch.Tensor,
- u_len: torch.Tensor,
- ) -> torch.Tensor:
- """Compute CTC loss.
-
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- target: Target label ID sequences. (B, L)
- t_len: Encoder output sequences lengths. (B,)
- u_len: Target label ID sequences lengths. (B,)
-
- Return:
- loss_ctc: CTC loss value.
-
- """
- ctc_in = self.ctc_lin(
- torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
- )
- ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
-
- target_mask = target != 0
- ctc_target = target[target_mask].cpu()
-
- with torch.backends.cudnn.flags(deterministic=True):
- loss_ctc = torch.nn.functional.ctc_loss(
- ctc_in,
- ctc_target,
- t_len,
- u_len,
- zero_infinity=True,
- reduction="sum",
- )
- loss_ctc /= target.size(0)
-
- return loss_ctc
-
- def _calc_lm_loss(
- self,
- decoder_out: torch.Tensor,
- target: torch.Tensor,
- ) -> torch.Tensor:
- """Compute LM loss.
-
- Args:
- decoder_out: Decoder output sequences. (B, U, D_dec)
- target: Target label ID sequences. (B, L)
-
- Return:
- loss_lm: LM loss value.
-
- """
- lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
- lm_target = target.view(-1).type(torch.int64)
-
- with torch.no_grad():
- true_dist = lm_loss_in.clone()
- true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
-
- # Ignore blank ID (0)
- ignore = lm_target == 0
- lm_target = lm_target.masked_fill(ignore, 0)
-
- true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
-
- loss_lm = torch.nn.functional.kl_div(
- torch.log_softmax(lm_loss_in, dim=1),
- true_dist,
- reduction="none",
- )
- loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
- 0
- )
-
- return loss_lm
diff --git a/funasr/models_transducer/utils.py b/funasr/models_transducer/utils.py
deleted file mode 100644
index fd3c531..0000000
--- a/funasr/models_transducer/utils.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""Utility functions for Transducer models."""
-
-from typing import List, Tuple
-
-import torch
-
-
-class TooShortUttError(Exception):
- """Raised when the utt is too short for subsampling.
-
- Args:
- message: Error message to display.
- actual_size: The size that cannot pass the subsampling.
- limit: The size limit for subsampling.
-
- """
-
- def __init__(self, message: str, actual_size: int, limit: int) -> None:
- """Construct a TooShortUttError module."""
- super().__init__(message)
-
- self.actual_size = actual_size
- self.limit = limit
-
-
-def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
- """Check if the input is too short for subsampling.
-
- Args:
- sub_factor: Subsampling factor for Conv2DSubsampling.
- size: Input size.
-
- Returns:
- : Whether an error should be sent.
- : Size limit for specified subsampling factor.
-
- """
- if sub_factor == 2 and size < 3:
- return True, 7
- elif sub_factor == 4 and size < 7:
- return True, 7
- elif sub_factor == 6 and size < 11:
- return True, 11
-
- return False, -1
-
-
-def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
- """Get conv2D second layer parameters for given subsampling factor.
-
- Args:
- sub_factor: Subsampling factor (1/X).
- input_size: Input size.
-
- Returns:
- : Kernel size for second convolution.
- : Stride for second convolution.
- : Conv2DSubsampling output size.
-
- """
- if sub_factor == 2:
- return 3, 1, (((input_size - 1) // 2 - 2))
- elif sub_factor == 4:
- return 3, 2, (((input_size - 1) // 2 - 1) // 2)
- elif sub_factor == 6:
- return 5, 3, (((input_size - 1) // 2 - 2) // 3)
- else:
- raise ValueError(
- "subsampling_factor parameter should be set to either 2, 4 or 6."
- )
-
-
-def make_chunk_mask(
- size: int,
- chunk_size: int,
- left_chunk_size: int = 0,
- device: torch.device = None,
-) -> torch.Tensor:
- """Create chunk mask for the subsequent steps (size, size).
-
- Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
-
- Args:
- size: Size of the source mask.
- chunk_size: Number of frames in chunk.
- left_chunk_size: Size of the left context in chunks (0 means full context).
- device: Device for the mask tensor.
-
- Returns:
- mask: Chunk mask. (size, size)
-
- """
- mask = torch.zeros(size, size, device=device, dtype=torch.bool)
-
- for i in range(size):
- if left_chunk_size <= 0:
- start = 0
- else:
- start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
-
- end = min((i // chunk_size + 1) * chunk_size, size)
- mask[i, start:end] = True
-
- return ~mask
-
-
-def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
- """Create source mask for given lengths.
-
- Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
-
- Args:
- lengths: Sequence lengths. (B,)
-
- Returns:
- : Mask for the sequence lengths. (B, max_len)
-
- """
- max_len = lengths.max()
- batch_size = lengths.size(0)
-
- expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
-
- return expanded_lengths >= lengths.unsqueeze(1)
-
-
-def get_transducer_task_io(
- labels: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ignore_id: int = -1,
- blank_id: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- """Get Transducer loss I/O.
-
- Args:
- labels: Label ID sequences. (B, L)
- encoder_out_lens: Encoder output lengths. (B,)
- ignore_id: Padding symbol ID.
- blank_id: Blank symbol ID.
-
- Returns:
- decoder_in: Decoder inputs. (B, U)
- target: Target label ID sequences. (B, U)
- t_len: Time lengths. (B,)
- u_len: Label lengths. (B,)
-
- """
-
- def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
- """Create padded batch of labels from a list of labels sequences.
-
- Args:
- labels: Labels sequences. [B x (?)]
- padding_value: Padding value.
-
- Returns:
- labels: Batch of padded labels sequences. (B,)
-
- """
- batch_size = len(labels)
-
- padded = (
- labels[0]
- .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
- .fill_(padding_value)
- )
-
- for i in range(batch_size):
- padded[i, : labels[i].size(0)] = labels[i]
-
- return padded
-
- device = labels.device
-
- labels_unpad = [y[y != ignore_id] for y in labels]
- blank = labels[0].new([blank_id])
-
- decoder_in = pad_list(
- [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
- ).to(device)
-
- target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
-
- encoder_out_lens = list(map(int, encoder_out_lens))
- t_len = torch.IntTensor(encoder_out_lens).to(device)
-
- u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
-
- return decoder_in, target, t_len, u_len
-
-def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
- """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
- if t.size(dim) == pad_len:
- return t
- else:
- pad_size = list(t.shape)
- pad_size[dim] = pad_len - t.size(dim)
- return torch.cat(
- [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
- )
diff --git a/funasr/models_transducer/activation.py b/funasr/modules/activation.py
similarity index 100%
rename from funasr/models_transducer/activation.py
rename to funasr/modules/activation.py
diff --git a/funasr/models_transducer/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
similarity index 99%
rename from funasr/models_transducer/beam_search_transducer.py
rename to funasr/modules/beam_search/beam_search_transducer.py
index 8e234e4..eaf5627 100644
--- a/funasr/models_transducer/beam_search_transducer.py
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -6,8 +6,8 @@
import numpy as np
import torch
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.joint_network import JointNetwork
@dataclass
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 92f9079..9b5039c 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -6,6 +6,8 @@
"""Common functions for ASR."""
+from typing import List, Optional, Tuple
+
import json
import logging
import sys
@@ -13,7 +15,11 @@
from itertools import groupby
import numpy as np
import six
+import torch
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
@@ -247,3 +253,148 @@
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
+
+class ErrorCalculatorTransducer:
+ """Calculate CER and WER for transducer models.
+ Args:
+ decoder: Decoder module.
+ joint_network: Joint Network module.
+ token_list: List of token units.
+ sym_space: Space symbol.
+ sym_blank: Blank symbol.
+ report_cer: Whether to compute CER.
+ report_wer: Whether to compute WER.
+ """
+
+ def __init__(
+ self,
+ decoder: AbsDecoder,
+ joint_network: JointNetwork,
+ token_list: List[str],
+ sym_space: str,
+ sym_blank: str,
+ report_cer: bool = False,
+ report_wer: bool = False,
+ ) -> None:
+ """Construct an ErrorCalculatorTransducer object."""
+ super().__init__()
+
+ self.beam_search = BeamSearchTransducer(
+ decoder=decoder,
+ joint_network=joint_network,
+ beam_size=1,
+ search_type="default",
+ score_norm=False,
+ )
+
+ self.decoder = decoder
+
+ self.token_list = token_list
+ self.space = sym_space
+ self.blank = sym_blank
+
+ self.report_cer = report_cer
+ self.report_wer = report_wer
+
+ def __call__(
+ self, encoder_out: torch.Tensor, target: torch.Tensor
+ ) -> Tuple[Optional[float], Optional[float]]:
+ """Calculate sentence-level WER or/and CER score for Transducer model.
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ target: Target label ID sequences. (B, L)
+ Returns:
+ : Sentence-level CER score.
+ : Sentence-level WER score.
+ """
+ cer, wer = None, None
+
+ batchsize = int(encoder_out.size(0))
+
+ encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
+
+ batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+ pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
+
+ char_pred, char_target = self.convert_to_char(pred, target)
+
+ if self.report_cer:
+ cer = self.calculate_cer(char_pred, char_target)
+
+ if self.report_wer:
+ wer = self.calculate_wer(char_pred, char_target)
+
+ return cer, wer
+
+ def convert_to_char(
+ self, pred: torch.Tensor, target: torch.Tensor
+ ) -> Tuple[List, List]:
+ """Convert label ID sequences to character sequences.
+ Args:
+ pred: Prediction label ID sequences. (B, U)
+ target: Target label ID sequences. (B, L)
+ Returns:
+ char_pred: Prediction character sequences. (B, ?)
+ char_target: Target character sequences. (B, ?)
+ """
+ char_pred, char_target = [], []
+
+ for i, pred_i in enumerate(pred):
+ char_pred_i = [self.token_list[int(h)] for h in pred_i]
+ char_target_i = [self.token_list[int(r)] for r in target[i]]
+
+ char_pred_i = "".join(char_pred_i).replace(self.space, " ")
+ char_pred_i = char_pred_i.replace(self.blank, "")
+
+ char_target_i = "".join(char_target_i).replace(self.space, " ")
+ char_target_i = char_target_i.replace(self.blank, "")
+
+ char_pred.append(char_pred_i)
+ char_target.append(char_target_i)
+
+ return char_pred, char_target
+
+ def calculate_cer(
+ self, char_pred: torch.Tensor, char_target: torch.Tensor
+ ) -> float:
+ """Calculate sentence-level CER score.
+ Args:
+ char_pred: Prediction character sequences. (B, ?)
+ char_target: Target character sequences. (B, ?)
+ Returns:
+ : Average sentence-level CER score.
+ """
+ import editdistance
+
+ distances, lens = [], []
+
+ for i, char_pred_i in enumerate(char_pred):
+ pred = char_pred_i.replace(" ", "")
+ target = char_target[i].replace(" ", "")
+ distances.append(editdistance.eval(pred, target))
+ lens.append(len(target))
+
+ return float(sum(distances)) / sum(lens)
+
+ def calculate_wer(
+ self, char_pred: torch.Tensor, char_target: torch.Tensor
+ ) -> float:
+ """Calculate sentence-level WER score.
+ Args:
+ char_pred: Prediction character sequences. (B, ?)
+ char_target: Target character sequences. (B, ?)
+ Returns:
+ : Average sentence-level WER score
+ """
+ import editdistance
+
+ distances, lens = [], []
+
+ for i, char_pred_i in enumerate(char_pred):
+ pred = char_pred_i.replace("▁", " ").split()
+ target = char_target[i].replace("▁", " ").split()
+
+ distances.append(editdistance.eval(pred, target))
+ lens.append(len(target))
+
+ return float(sum(distances)) / sum(lens)
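
For reference, a minimal sketch of what the CER path above computes, with a made-up token list and ID sequences (illustrative only, not part of this patch):

import editdistance

token_list = ["<blank>", "ni", "hao", "ma"]          # hypothetical tokens
pred_ids, target_ids = [1, 2], [1, 2, 3]             # hypothetical ID sequences

pred = "".join(token_list[i] for i in pred_ids)      # "nihao"
target = "".join(token_list[i] for i in target_ids)  # "nihaoma"

# calculate_cer averages edit distance over target lengths:
cer = editdistance.eval(pred, target) / len(target)  # 2 / 7 ~= 0.286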
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 6d77d69..5d4fe1c 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
-from typing import Dict
+from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -506,3 +506,196 @@
}
return activation_funcs[act]()
+
+class TooShortUttError(Exception):
+ """Raised when the utt is too short for subsampling.
+
+ Args:
+ message: Error message to display.
+ actual_size: The size that cannot pass the subsampling.
+ limit: The size limit for subsampling.
+
+ """
+
+ def __init__(self, message: str, actual_size: int, limit: int) -> None:
+ """Construct a TooShortUttError module."""
+ super().__init__(message)
+
+ self.actual_size = actual_size
+ self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+ """Check if the input is too short for subsampling.
+
+ Args:
+ sub_factor: Subsampling factor for Conv2DSubsampling.
+ size: Input size.
+
+ Returns:
+ : Whether an error should be sent.
+ : Size limit for specified subsampling factor.
+
+ """
+ if sub_factor == 2 and size < 3:
+ return True, 7
+ elif sub_factor == 4 and size < 7:
+ return True, 7
+ elif sub_factor == 6 and size < 11:
+ return True, 11
+
+ return False, -1
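
A hedged usage sketch: encoder front-ends typically call check_short_utt before subsampling and raise TooShortUttError when it trips. The feats tensor, its shape, and the message text below are illustrative, not taken from this patch:

import torch

feats = torch.randn(1, 5, 80)  # hypothetical batch: 5 frames, too short for 1/4 subsampling
short, limit = check_short_utt(4, feats.size(1))
if short:
    raise TooShortUttError(
        f"has {feats.size(1)} frames and is too short for subsampling "
        f"(it needs at least {limit} frames)",
        feats.size(1),
        limit,
    )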
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+ """Get conv2D second layer parameters for given subsampling factor.
+
+ Args:
+ sub_factor: Subsampling factor (1/X).
+ input_size: Input size.
+
+ Returns:
+ : Kernel size for second convolution.
+ : Stride for second convolution.
+ : Conv2DSubsampling output size.
+
+ """
+ if sub_factor == 2:
+ return 3, 1, (((input_size - 1) // 2 - 2))
+ elif sub_factor == 4:
+ return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+ elif sub_factor == 6:
+ return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+ else:
+ raise ValueError(
+ "subsampling_factor parameter should be set to either 2, 4 or 6."
+ )
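
A worked example of the arithmetic, assuming 80-dimensional input features and 1/4 subsampling:

kernel_size, stride, output_size = sub_factor_to_params(4, 80)
# kernel_size == 3, stride == 2
# output_size == ((80 - 1) // 2 - 1) // 2 == (39 - 1) // 2 == 19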
+
+
+def make_chunk_mask(
+ size: int,
+ chunk_size: int,
+ left_chunk_size: int = 0,
+ device: torch.device = None,
+) -> torch.Tensor:
+ """Create chunk mask for the subsequent steps (size, size).
+
+ Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+ Args:
+ size: Size of the source mask.
+ chunk_size: Number of frames in chunk.
+ left_chunk_size: Size of the left context in chunks (0 means full context).
+ device: Device for the mask tensor.
+
+ Returns:
+ mask: Chunk mask. (size, size)
+
+ """
+ mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+ for i in range(size):
+ if left_chunk_size <= 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+ end = min((i // chunk_size + 1) * chunk_size, size)
+ mask[i, start:end] = True
+
+ return ~mask
+
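To make the masking convention concrete, a small illustrative case (True marks positions attention must not see; values chosen for demonstration):

m = make_chunk_mask(size=6, chunk_size=2, left_chunk_size=1)
# Frame 0 sees only its own chunk (frames 0-1):
#   m[0] -> [False, False, True, True, True, True]
# Frame 4 sees one left chunk plus its own (frames 2-5):
#   m[4] -> [True, True, False, False, False, False]
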
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+ """Create source mask for given lengths.
+
+ Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+ Args:
+ lengths: Sequence lengths. (B,)
+
+ Returns:
+ : Mask for the sequence lengths. (B, max_len)
+
+ """
+ max_len = lengths.max()
+ batch_size = lengths.size(0)
+
+ expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+ return expanded_lengths >= lengths.unsqueeze(1)
+
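For example, with two sequences of lengths 3 and 5 (and torch in scope, as it is in this module), the mask flags the padded tail of the shorter one:

mask = make_source_mask(torch.tensor([3, 5]))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])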
+
+def get_transducer_task_io(
+ labels: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Get Transducer loss I/O.
+
+ Args:
+ labels: Label ID sequences. (B, L)
+ encoder_out_lens: Encoder output lengths. (B,)
+ ignore_id: Padding symbol ID.
+ blank_id: Blank symbol ID.
+
+ Returns:
+ decoder_in: Decoder inputs. (B, U)
+ target: Target label ID sequences. (B, U)
+ t_len: Time lengths. (B,)
+ u_len: Label lengths. (B,)
+
+ """
+
+ def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+ """Create padded batch of labels from a list of labels sequences.
+
+ Args:
+ labels: Labels sequences. [B x (?)]
+ padding_value: Padding value.
+
+ Returns:
+ labels: Batch of padded labels sequences. (B,)
+
+ """
+ batch_size = len(labels)
+
+ padded = (
+ labels[0]
+ .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+ .fill_(padding_value)
+ )
+
+ for i in range(batch_size):
+ padded[i, : labels[i].size(0)] = labels[i]
+
+ return padded
+
+ device = labels.device
+
+ labels_unpad = [y[y != ignore_id] for y in labels]
+ blank = labels[0].new([blank_id])
+
+ decoder_in = pad_list(
+ [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+ ).to(device)
+
+ target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+ encoder_out_lens = list(map(int, encoder_out_lens))
+ t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+ u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+ return decoder_in, target, t_len, u_len
+
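A toy example of the I/O preparation, assuming ignore_id=-1 padding and a 4-frame encoder output (label values are illustrative):

labels = torch.tensor([[7, 8, -1]])
decoder_in, target, t_len, u_len = get_transducer_task_io(
    labels, torch.tensor([4])
)
# decoder_in -> tensor([[0, 7, 8]])  (blank prepended, padding stripped)
# target     -> tensor([[7, 8]], dtype=torch.int32)
# t_len      -> tensor([4], dtype=torch.int32)
# u_len      -> tensor([2], dtype=torch.int32)
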
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+ """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+ if t.size(dim) == pad_len:
+ return t
+ else:
+ pad_size = list(t.shape)
+ pad_size[dim] = pad_len - t.size(dim)
+ return torch.cat(
+ [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
+ )
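
And a short usage sketch for pad_to_len, right-padding the time axis of a hypothetical (B, T, D) batch:

x = torch.randn(2, 7, 80)
y = pad_to_len(x, pad_len=10, dim=1)
# y.shape == (2, 10, 80); frames 7-9 of each sequence are zeros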
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index be14455..cae18c1 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
LightweightConvolutionTransformerDecoder,
TransformerDecoder,
)
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
-from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
-from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
-from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
-from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.e2e_transducer import TransducerModel
+from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
+from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
"encoder",
classes=dict(
encoder=Encoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
),
default="encoder",
)
@@ -158,7 +155,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetASRTransducerModel),
+ default=get_default_kwargs(TransducerModel),
help="The keyword arguments for the model class.",
)
# group.add_argument(
@@ -354,7 +351,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+ def build_model(cls, args: argparse.Namespace) -> TransducerModel:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
@@ -440,22 +437,8 @@
# 7. Build model
- if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
- model = UniASRTransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- elif encoder.unified_model_training:
- model = ESPnetASRUnifiedTransducerModel(
+ if encoder.unified_model_training:
+ model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
@@ -469,7 +452,7 @@
)
else:
- model = ESPnetASRTransducerModel(
+ model = TransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
--
Gitblit v1.9.1