From 4ee715e70e36cdba7b05fe044fecab9cf4fa16ff Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 03 Jul 2023 17:23:02 +0800
Subject: [PATCH] Fix chunk-size selection in streaming conformer encoder (websocket bug)
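
The streaming ConformerChunkEncoder drew a random chunk size even in eval
mode, which appears to be the websocket-service bug this patch targets.
Changes:

- randomize the chunk size only when self.training is set; inference uses
  default_chunk_size in both the unified and dynamic-chunk branches
- move the source mask onto the input tensor's device
- inherit from AbsEncoder and expose output_size() as a method (the
  attribute becomes _output_size)
- add full_utt_forward() for whole-utterance encoding without chunk masks,
  and return a trailing None from forward() (an extra-outputs placeholder,
  by the look of the signature)
- drop the typeguard asserts, the unused **activation_parameters, and the
  get_normalization import; LayerNorm becomes the default norm class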
---
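Reviewer notes: the behavioural core of the fix is that chunk-size
randomization is now gated on self.training. A minimal standalone sketch of
the selection logic; default_chunk_size=16 and jitter_range=4 mirror the
constructor defaults visible in this diff, while the short-chunk defaults
are illustrative assumptions:

    # Standalone sketch of the patched chunk-size selection in
    # ConformerChunkEncoder.forward (the real code reads self.* attributes).
    import torch

    def pick_chunk_size(
        training: bool,
        max_len: int,
        unified_model_training: bool = True,
        default_chunk_size: int = 16,         # constructor default in this diff
        jitter_range: int = 4,                # constructor default in this diff
        short_chunk_threshold: float = 0.75,  # assumed value
        short_chunk_size: int = 25,           # assumed value
    ) -> int:
        if unified_model_training:
            if training:
                # jitter the chunk size around the default during training
                return default_chunk_size + torch.randint(
                    -jitter_range, jitter_range + 1, (1,)
                ).item()
            return default_chunk_size  # deterministic at inference
        # dynamic chunk training: random short chunks, occasionally full context
        if training:
            chunk_size = torch.randint(1, max_len, (1,)).item()
            if chunk_size > max_len * short_chunk_threshold:
                return max_len
            return (chunk_size % short_chunk_size) + 1
        return default_chunk_size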
funasr/models/encoder/conformer_encoder.py | 84 ++++++++++++++++++++++++++++++------------
 1 file changed, 60 insertions(+), 24 deletions(-)
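A hypothetical usage sketch of the patched interface follows; the
instantiation is illustrative only (input_size and any further required
constructor arguments are assumptions, not taken from this diff):

    # Hypothetical usage of the patched encoder; input_size=80 and the
    # two-utterance batch are made-up values, and the constructor may need
    # more arguments than shown here.
    import torch
    from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder

    enc = ConformerChunkEncoder(input_size=80)
    feats = torch.randn(2, 200, 80)        # (B, T_in, F)
    feats_len = torch.tensor([200, 160])   # (B,)

    enc.train()
    x, olens, _ = enc(feats, feats_len)    # forward() now returns a 3-tuple

    enc.eval()
    with torch.no_grad():
        x_utt = enc.full_utt_forward(feats, feats_len)  # no chunk masking

    print(enc.output_size())  # now a method, matching AbsEncoder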
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index c837cf5..e5fac62 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -12,16 +12,15 @@
import torch
from torch import nn
-from typeguard import check_argument_types
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
+from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
@@ -30,7 +29,6 @@
StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
@@ -308,7 +306,7 @@
feed_forward: torch.nn.Module,
feed_forward_macaron: torch.nn.Module,
conv_mod: torch.nn.Module,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_class: torch.nn.Module = LayerNorm,
norm_args: Dict = {},
dropout_rate: float = 0.0,
) -> None:
@@ -534,7 +532,6 @@
interctc_use_conditioning: bool = False,
stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -895,7 +892,7 @@
return x, cache
-class ConformerChunkEncoder(torch.nn.Module):
+class ConformerChunkEncoder(AbsEncoder):
"""Encoder module definition.
Args:
input_size: Input size.
@@ -940,12 +937,10 @@
default_chunk_size: int = 16,
jitter_range: int = 4,
subsampling_factor: int = 1,
- **activation_parameters,
) -> None:
"""Construct an Encoder object."""
super().__init__()
- assert check_argument_types()
self.embed = StreamingConvInput(
input_size,
@@ -961,7 +956,7 @@
)
activation = get_activation(
- activation_type, **activation_parameters
+ activation_type
)
pos_wise_args = (
@@ -991,9 +986,6 @@
simplified_att_score,
)
- norm_class, norm_args = get_normalization(
- norm_type,
- )
fn_modules = []
for _ in range(num_blocks):
@@ -1003,8 +995,6 @@
PositionwiseFeedForward(*pos_wise_args),
PositionwiseFeedForward(*pos_wise_args),
CausalConvolution(*conv_mod_args),
- norm_class=norm_class,
- norm_args=norm_args,
dropout_rate=dropout_rate,
)
fn_modules.append(module)
@@ -1012,11 +1002,9 @@
self.encoders = MultiBlocks(
[fn() for fn in fn_modules],
output_size,
- norm_class=norm_class,
- norm_args=norm_args,
)
- self.output_size = output_size
+ self._output_size = output_size
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
@@ -1028,6 +1016,9 @@
self.jitter_range = jitter_range
self.time_reduction_factor = time_reduction_factor
+
+ def output_size(self) -> int:
+ return self._output_size
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@@ -1084,10 +1075,13 @@
limit_size,
)
- mask = make_source_mask(x_len)
+ mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ if self.training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ else:
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
@@ -1119,12 +1113,15 @@
elif self.dynamic_chunk_training:
max_len = x.size(1)
- chunk_size = torch.randint(1, max_len, (1,)).item()
+ if self.training:
+ chunk_size = torch.randint(1, max_len, (1,)).item()
- if chunk_size > (max_len * self.short_chunk_threshold):
- chunk_size = max_len
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
else:
- chunk_size = (chunk_size % self.short_chunk_size) + 1
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
@@ -1151,7 +1148,46 @@
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
- return x, olens
+ return x, olens, None
+
+ def full_utt_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> torch.Tensor:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x_utt: Encoder outputs for the full utterance,
+ after optional time reduction. (B, T_out, D_enc)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len).to(x.device)
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ return x_utt
def simu_chunk_forward(
self,
--
Gitblit v1.9.1