From d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 31 三月 2023 15:05:37 +0800
Subject: [PATCH] export
---
funasr/export/models/vad_realtime_transformer.py | 7
funasr/bin/punctuation_infer.py | 2
funasr/models/encoder/sanm_encoder.py | 232 ++++++++++++
funasr/train/abs_model.py | 56 ++
funasr/tasks/lm.py | 8
funasr/export/models/target_delay_transformer.py | 87 ----
funasr/export/models/__init__.py | 8
funasr/models/vad_realtime_transformer.py | 2
funasr/tasks/punctuation.py | 14
/dev/null | 590 --------------------------------
funasr/punctuation/text_preprocessor.py | 13
funasr/lm/espnet_model.py | 2
funasr/models/target_delay_transformer.py | 3
funasr/datasets/preprocessor.py | 14
funasr/bin/punctuation_infer_vadrealtime.py | 2
15 files changed, 309 insertions(+), 731 deletions(-)
diff --git a/funasr/bin/punctuation_infer.py b/funasr/bin/punctuation_infer.py
index a801ee8..dd28ef8 100644
--- a/funasr/bin/punctuation_infer.py
+++ b/funasr/bin/punctuation_infer.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index ce1cee8..81f9d7a 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 98cca1d..afeff4e 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -800,3 +800,17 @@
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+ assert word_limit > 1
+ if len(words) <= word_limit:
+ return [words]
+ sentences = []
+ length = len(words)
+ sentence_len = length // word_limit
+ for i in range(sentence_len):
+ sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ if length % word_limit > 0:
+ sentences.append(words[sentence_len * word_limit:])
+ return sentences
\ No newline at end of file
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 62ee723..4ac0456 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -3,10 +3,10 @@
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
+from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
def get_model(model, export_config=None):
@@ -16,7 +16,7 @@
return Paraformer_export(model, **export_config)
elif isinstance(model, E2EVadModel):
return E2EVadModel_export(model, **export_config)
- elif isinstance(model, ESPnetPunctuationModel):
+ elif isinstance(model, PunctuationModel):
if isinstance(model.punc_model, TargetDelayTransformer):
return TargetDelayTransformer_export(model.punc_model, **export_config)
elif isinstance(model.punc_model, VadRealtimeTransformer):
diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/target_delay_transformer.py
index fd90835..bfe3ec4 100644
--- a/funasr/export/models/target_delay_transformer.py
+++ b/funasr/export/models/target_delay_transformer.py
@@ -1,17 +1,7 @@
-from typing import Any
-from typing import List
from typing import Tuple
import torch
import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-#from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-from funasr.punctuation.sanm_encoder import SANMEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.punctuation.abs_model import AbsPunctuation
-
class TargetDelayTransformer(nn.Module):
@@ -32,85 +22,10 @@
self.feats_dim = self.embed.embedding_dim
self.num_embeddings = self.embed.num_embeddings
self.model_name = model_name
- from typing import Any
- from typing import List
- from typing import Tuple
- import torch
- import torch.nn as nn
-
- from funasr.export.utils.torch_function import MakePadMask
- from funasr.export.utils.torch_function import sequence_mask
# from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
- from funasr.punctuation.sanm_encoder import SANMEncoder
+ from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
- from funasr.punctuation.abs_model import AbsPunctuation
-
- # class TargetDelayTransformer(nn.Module):
- #
- # def __init__(
- # self,
- # model,
- # max_seq_len=512,
- # model_name='punc_model',
- # **kwargs,
- # ):
- # super().__init__()
- # onnx = False
- # if "onnx" in kwargs:
- # onnx = kwargs["onnx"]
- # self.embed = model.embed
- # self.decoder = model.decoder
- # self.model = model
- # self.feats_dim = self.embed.embedding_dim
- # self.num_embeddings = self.embed.num_embeddings
- # self.model_name = model_name
- #
- # if isinstance(model.encoder, SANMEncoder):
- # self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- # else:
- # assert False, "Only support samn encode."
- #
- # def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
- # """Compute loss value from buffer sequences.
- #
- # Args:
- # input (torch.Tensor): Input ids. (batch, len)
- # hidden (torch.Tensor): Target ids. (batch, len)
- #
- # """
- # x = self.embed(input)
- # # mask = self._target_mask(input)
- # h, _ = self.encoder(x, text_lengths)
- # y = self.decoder(h)
- # return y
- #
- # def get_dummy_inputs(self):
- # length = 120
- # text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
- # text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
- # return (text_indexes, text_lengths)
- #
- # def get_input_names(self):
- # return ['input', 'text_lengths']
- #
- # def get_output_names(self):
- # return ['logits']
- #
- # def get_dynamic_axes(self):
- # return {
- # 'input': {
- # 0: 'batch_size',
- # 1: 'feats_length'
- # },
- # 'text_lengths': {
- # 0: 'batch_size',
- # },
- # 'logits': {
- # 0: 'batch_size',
- # 1: 'logits_length'
- # },
- # }
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py
index 093e71d..693b9c8 100644
--- a/funasr/export/models/vad_realtime_transformer.py
+++ b/funasr/export/models/vad_realtime_transformer.py
@@ -1,14 +1,9 @@
-from typing import Any
-from typing import List
from typing import Tuple
import torch
import torch.nn as nn
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.sanm_encoder import SANMVadEncoder
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class VadRealtimeTransformer(nn.Module):
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
index db11b67..a9b8130 100644
--- a/funasr/lm/espnet_model.py
+++ b/funasr/lm/espnet_model.py
@@ -12,7 +12,7 @@
from funasr.train.abs_espnet_model import AbsESPnetModel
-class ESPnetLanguageModel(AbsESPnetModel):
+class LanguageModel(AbsESPnetModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 57890ef..2a3a353 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -10,7 +10,7 @@
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-
+from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -958,3 +958,231 @@
var_dict_tf[name_tf].shape))
return var_dict_torch_update
+
+
+class SANMVadEncoder(AbsEncoder):
+ """
+ author: Speech Lab, Alibaba Group, China
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size : int = 11,
+ sanm_shfit : int = 0,
+ selfattention_layer_type: str = "sanm",
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ SinusoidalPositionEncoder(),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ elif input_layer == "pe":
+ self.embed = SinusoidalPositionEncoder()
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+
+ elif selfattention_layer_type == "sanm":
+ self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+ encoder_selfattn_layer_args0 = (
+ attention_heads,
+ input_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ self.encoders0 = repeat(
+ 1,
+ lambda lnum: EncoderLayerSANM(
+ input_size,
+ output_size,
+ self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.encoders = repeat(
+ num_blocks-1,
+ lambda lnum: EncoderLayerSANM(
+ output_size,
+ output_size,
+ self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ vad_indexes: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ prev_states: Not to be used now.
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+ no_future_masks = masks & sub_masks
+ xs_pad *= self.output_size()**0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+ f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ # xs_pad = self.dropout(xs_pad)
+ mask_tup0 = [masks, no_future_masks]
+ encoder_outs = self.encoders0(xs_pad, mask_tup0)
+ xs_pad, _ = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+
+
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ if layer_idx + 1 == len(self.encoders):
+ # This is last layer.
+ coner_mask = torch.ones(masks.size(0),
+ masks.size(-1),
+ masks.size(-1),
+ device=xs_pad.device,
+ dtype=torch.bool)
+ for word_index, length in enumerate(ilens):
+ coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+ vad_indexes[word_index],
+ device=xs_pad.device)
+ layer_mask = masks & coner_mask
+ else:
+ layer_mask = no_future_masks
+ mask_tup1 = [masks, layer_mask]
+ encoder_outs = encoder_layer(xs_pad, mask_tup1)
+ xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
similarity index 97%
rename from funasr/punctuation/target_delay_transformer.py
rename to funasr/models/target_delay_transformer.py
index 219af26..a71952b 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -5,12 +5,11 @@
import torch
import torch.nn as nn
-from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
diff --git a/funasr/punctuation/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
similarity index 98%
rename from funasr/punctuation/vad_realtime_transformer.py
rename to funasr/models/vad_realtime_transformer.py
index 35224f9..2945572 100644
--- a/funasr/punctuation/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -7,7 +7,7 @@
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):
diff --git a/funasr/punctuation/abs_model.py b/funasr/punctuation/abs_model.py
deleted file mode 100644
index 404d5e8..0000000
--- a/funasr/punctuation/abs_model.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-
-
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
- """The abstract class
-
- To share the loss calculation way among different models,
- We uses delegate pattern here:
- The instance of this class should be passed to "LanguageModel"
-
- >>> from funasr.punctuation.abs_model import AbsPunctuation
- >>> punc = AbsPunctuation()
- >>> model = ESPnetPunctuationModel(punc=punc)
-
- This "model" is one of mediator objects for "Task" class.
-
- """
-
- @abstractmethod
- def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- @abstractmethod
- def with_vad(self) -> bool:
- raise NotImplementedError
diff --git a/funasr/punctuation/sanm_encoder.py b/funasr/punctuation/sanm_encoder.py
deleted file mode 100644
index 8962093..0000000
--- a/funasr/punctuation/sanm_encoder.py
+++ /dev/null
@@ -1,590 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
-import numpy as np
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.multi_layer_conv import Conv1dLinear
-from funasr.modules.multi_layer_conv import MultiLayeredConv1d
-from funasr.modules.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.subsampling import Conv2dSubsampling
-from funasr.modules.subsampling import Conv2dSubsampling2
-from funasr.modules.subsampling import Conv2dSubsampling6
-from funasr.modules.subsampling import Conv2dSubsampling8
-from funasr.modules.subsampling import TooShortUttError
-from funasr.modules.subsampling import check_short_utt
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.mask import subsequent_mask, vad_mask
-
-class EncoderLayerSANM(nn.Module):
- def __init__(
- self,
- in_size,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
- ):
- """Construct an EncoderLayer object."""
- super(EncoderLayerSANM, self).__init__()
- self.self_attn = self_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(in_size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.in_size = in_size
- self.size = size
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
- self.stochastic_depth_rate = stochastic_depth_rate
- self.dropout_rate = dropout_rate
-
- def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
- """Compute encoded features.
-
- Args:
- x_input (torch.Tensor): Input tensor (#batch, time, size).
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
- cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time, size).
- torch.Tensor: Mask tensor (#batch, time).
-
- """
- skip_layer = False
- # with stochastic depth, residual connection `x + f(x)` becomes
- # `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
- if self.training and self.stochastic_depth_rate > 0:
- skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
- stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
- if skip_layer:
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
- return x, mask
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- x = stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- else:
- x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
-
- return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-class SANMEncoder(AbsEncoder):
- """
- author: Speech Lab, Alibaba Group, China
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- selfattention_layer_type: str = "sanm",
- ):
- assert check_argument_types()
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- SinusoidalPositionEncoder(),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
-
- elif selfattention_layer_type == "sanm":
- self.encoder_selfattn_layer = MultiHeadedAttentionSANM
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad *= self.output_size()**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- encoder_outs = self.encoders0(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
-class SANMVadEncoder(AbsEncoder):
- """
- author: Speech Lab, Alibaba Group, China
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- selfattention_layer_type: str = "sanm",
- ):
- assert check_argument_types()
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- SinusoidalPositionEncoder(),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
-
- elif selfattention_layer_type == "sanm":
- self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- vad_indexes: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
- no_future_masks = masks & sub_masks
- xs_pad *= self.output_size()**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling " +
- f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- mask_tup0 = [masks, no_future_masks]
- encoder_outs = self.encoders0(xs_pad, mask_tup0)
- xs_pad, _ = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- #if len(self.interctc_layer_idx) == 0:
- if False:
- # Here, we should not use the repeat operation to do it for all layers.
- encoder_outs = self.encoders(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- if layer_idx + 1 == len(self.encoders):
- # This is last layer.
- coner_mask = torch.ones(masks.size(0),
- masks.size(-1),
- masks.size(-1),
- device=xs_pad.device,
- dtype=torch.bool)
- for word_index, length in enumerate(ilens):
- coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
- vad_indexes[word_index],
- device=xs_pad.device)
- layer_mask = masks & coner_mask
- else:
- layer_mask = no_future_masks
- mask_tup1 = [masks, layer_mask]
- encoder_outs = encoder_layer(xs_pad, mask_tup1)
- xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
diff --git a/funasr/punctuation/text_preprocessor.py b/funasr/punctuation/text_preprocessor.py
index c9c4bac..8b13789 100644
--- a/funasr/punctuation/text_preprocessor.py
+++ b/funasr/punctuation/text_preprocessor.py
@@ -1,12 +1 @@
-def split_to_mini_sentence(words: list, word_limit: int = 20):
- assert word_limit > 1
- if len(words) <= word_limit:
- return [words]
- sentences = []
- length = len(words)
- sentence_len = length // word_limit
- for i in range(sentence_len):
- sentences.append(words[i * word_limit:(i + 1) * word_limit])
- if length % word_limit > 0:
- sentences.append(words[sentence_len * word_limit:])
- return sentences
+
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 608c1d3..dc8fd3e 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -15,7 +15,7 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.lm.abs_model import AbsLM
-from funasr.lm.espnet_model import ESPnetLanguageModel
+from funasr.lm.espnet_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetLanguageModel),
+ default=get_default_kwargs(LanguageModel),
help="The keyword arguments for model class.",
)
@@ -178,7 +178,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
+ def build_model(cls, args: argparse.Namespace) -> LanguageModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@
# 2. Build ESPnetModel
# Assume the last-id is sos_and_eos
- model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
+ model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
# 3. Initialize
if args.init is not None:
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index ea1e102..0170f28 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,10 +14,10 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import AbsPunctuation
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetPunctuationModel),
+ default=get_default_kwargs(PunctuationModel),
help="The keyword arguments for model class.",
)
@@ -183,7 +183,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
+ def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@
# Assume the last-id is sos_and_eos
if "punc_weight" in args.model_conf:
args.model_conf.pop("punc_weight")
- model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+ model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
# FIXME(kamo): Should be done in model?
# 3. Initialize
diff --git a/funasr/punctuation/espnet_model.py b/funasr/train/abs_model.py
similarity index 85%
rename from funasr/punctuation/espnet_model.py
rename to funasr/train/abs_model.py
index 7266b38..8bfba45 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/train/abs_model.py
@@ -1,3 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -7,13 +13,34 @@
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-class ESPnetPunctuationModel(AbsESPnetModel):
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+ """The abstract class
+
+ To share the loss calculation way among different models,
+ We uses delegate pattern here:
+ The instance of this class should be passed to "LanguageModel"
+
+ This "model" is one of mediator objects for "Task" class.
+
+ """
+
+ @abstractmethod
+ def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def with_vad(self) -> bool:
+ raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
+
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
@@ -21,12 +48,12 @@
self.punc_weight = torch.Tensor(punc_weight)
self.sos = 1
self.eos = 2
-
+
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
- #if self.punc_model.with_vad():
+ # if self.punc_model.with_vad():
# print("This is a vad puncuation model.")
-
+
def nll(
self,
text: torch.Tensor,
@@ -54,7 +81,7 @@
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
-
+
if self.punc_model.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
@@ -62,7 +89,7 @@
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_model(text, text_lengths)
-
+
# Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
@@ -75,7 +102,8 @@
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
- nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+ nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+ ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +115,7 @@
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
-
+
def batchify_nll(self,
text: torch.Tensor,
punc: torch.Tensor,
@@ -113,7 +141,7 @@
nlls = []
x_lengths = []
max_length = text_lengths.max()
-
+
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +160,7 @@
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
-
+
def forward(
self,
text: torch.Tensor,
@@ -146,15 +174,15 @@
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
-
+
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
-
+
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
return {}
-
+
def inference(self,
text: torch.Tensor,
text_lengths: torch.Tensor,
--
Gitblit v1.9.1