游雁
2023-03-31 d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a
export
11 files changed
3 files renamed
2 files deleted
1071 ■■■■ lines changed
funasr/bin/punctuation_infer.py 2 ●●●
funasr/bin/punctuation_infer_vadrealtime.py 2 ●●●
funasr/datasets/preprocessor.py 14 ●●●●●
funasr/export/models/__init__.py 8 ●●●●
funasr/export/models/target_delay_transformer.py 87 ●●●●●
funasr/export/models/vad_realtime_transformer.py 7 ●●●●
funasr/lm/espnet_model.py 2 ●●●
funasr/models/encoder/sanm_encoder.py 232 ●●●●●
funasr/models/target_delay_transformer.py 3 ●●●●
funasr/models/vad_realtime_transformer.py 2 ●●●
funasr/punctuation/abs_model.py 31 ●●●●●
funasr/punctuation/sanm_encoder.py 590 ●●●●●
funasr/punctuation/text_preprocessor.py 13 ●●●●●
funasr/tasks/lm.py 8 ●●●●
funasr/tasks/punctuation.py 14 ●●●●
funasr/train/abs_model.py 56 ●●●●
funasr/bin/punctuation_infer.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
funasr/datasets/preprocessor.py
@@ -800,3 +800,17 @@
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences
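For reference, split_to_mini_sentence (now exposed from funasr.datasets.preprocessor) chunks a token list into word_limit-sized pieces and keeps any remainder as a final short piece. A minimal usage sketch with illustrative values:

# Illustrative only: exercising the relocated helper.
from funasr.datasets.preprocessor import split_to_mini_sentence

words = ["this", "is", "a", "longer", "token", "sequence", "tail"]
# 7 words with word_limit=3 -> chunks of 3, 3, and the 1-word remainder.
print(split_to_mini_sentence(words, word_limit=3))
# [['this', 'is', 'a'], ['longer', 'token', 'sequence'], ['tail']]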
funasr/export/models/__init__.py
@@ -3,10 +3,10 @@
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
+from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
def get_model(model, export_config=None):
@@ -16,7 +16,7 @@
        return Paraformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
-    elif isinstance(model, ESPnetPunctuationModel):
+    elif isinstance(model, PunctuationModel):
        if isinstance(model.punc_model, TargetDelayTransformer):
            return TargetDelayTransformer_export(model.punc_model, **export_config)
        elif isinstance(model.punc_model, VadRealtimeTransformer):
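For orientation, a hedged sketch of the renamed dispatch path: get_model now matches on PunctuationModel and unwraps punc_model before choosing an export wrapper. The export_config key below follows the wrappers' kwargs handling, and the model variable is assumed to be an already-built PunctuationModel:

# Hedged sketch: obtaining an export wrapper for a punctuation model.
from funasr.export.models import get_model

export_config = {"onnx": True}  # forwarded to the wrapper as **export_config
export_model = get_model(model, export_config)
# For a TargetDelayTransformer punc_model this resolves to
# TargetDelayTransformer_export(model.punc_model, onnx=True).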
funasr/export/models/target_delay_transformer.py
@@ -1,17 +1,7 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
#from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.punctuation.abs_model import AbsPunctuation
class TargetDelayTransformer(nn.Module):
@@ -32,85 +22,10 @@
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
        from typing import Any
        from typing import List
        from typing import Tuple
        import torch
        import torch.nn as nn
        from funasr.export.utils.torch_function import MakePadMask
        from funasr.export.utils.torch_function import sequence_mask
        # from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
        from funasr.punctuation.sanm_encoder import SANMEncoder
        from funasr.models.encoder.sanm_encoder import SANMEncoder
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
        from funasr.punctuation.abs_model import AbsPunctuation
        # class TargetDelayTransformer(nn.Module):
        #
        #     def __init__(
        #             self,
        #             model,
        #             max_seq_len=512,
        #             model_name='punc_model',
        #             **kwargs,
        #     ):
        #         super().__init__()
        #         onnx = False
        #         if "onnx" in kwargs:
        #             onnx = kwargs["onnx"]
        #         self.embed = model.embed
        #         self.decoder = model.decoder
        #         self.model = model
        #         self.feats_dim = self.embed.embedding_dim
        #         self.num_embeddings = self.embed.num_embeddings
        #         self.model_name = model_name
        #
        #         if isinstance(model.encoder, SANMEncoder):
        #             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        #         else:
        #             assert False, "Only support samn encode."
        #
        #     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        #         """Compute loss value from buffer sequences.
        #
        #         Args:
        #             input (torch.Tensor): Input ids. (batch, len)
        #             hidden (torch.Tensor): Target ids. (batch, len)
        #
        #         """
        #         x = self.embed(input)
        #         # mask = self._target_mask(input)
        #         h, _ = self.encoder(x, text_lengths)
        #         y = self.decoder(h)
        #         return y
        #
        #     def get_dummy_inputs(self):
        #         length = 120
        #         text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        #         text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        #         return (text_indexes, text_lengths)
        #
        #     def get_input_names(self):
        #         return ['input', 'text_lengths']
        #
        #     def get_output_names(self):
        #         return ['logits']
        #
        #     def get_dynamic_axes(self):
        #         return {
        #             'input': {
        #                 0: 'batch_size',
        #                 1: 'feats_length'
        #             },
        #             'text_lengths': {
        #                 0: 'batch_size',
        #             },
        #             'logits': {
        #                 0: 'batch_size',
        #                 1: 'logits_length'
        #             },
        #         }
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
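The wrapper's helper methods (get_dummy_inputs, get_input_names, get_output_names, get_dynamic_axes, visible in the removed duplicate above) exist to drive ONNX export. A hedged sketch of how they would typically plug into torch.onnx.export; export_to_onnx is a hypothetical helper, not FunASR's actual export driver:

# Hypothetical helper showing the intended use of the wrapper's get_* methods.
import torch

def export_to_onnx(wrapper, path="punc_model.onnx"):
    dummy_inputs = wrapper.get_dummy_inputs()       # (text_indexes, text_lengths)
    torch.onnx.export(
        wrapper,
        dummy_inputs,
        path,
        input_names=wrapper.get_input_names(),      # ['input', 'text_lengths']
        output_names=wrapper.get_output_names(),    # ['logits']
        dynamic_axes=wrapper.get_dynamic_axes(),    # batch_size / length axes
        opset_version=13,
    )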
funasr/export/models/vad_realtime_transformer.py
@@ -1,14 +1,9 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.sanm_encoder import SANMVadEncoder
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class VadRealtimeTransformer(nn.Module):
funasr/lm/espnet_model.py
@@ -12,7 +12,7 @@
from funasr.train.abs_espnet_model import AbsESPnetModel
-class ESPnetLanguageModel(AbsESPnetModel):
+class LanguageModel(AbsESPnetModel):
    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        assert check_argument_types()
        super().__init__()
funasr/models/encoder/sanm_encoder.py
@@ -10,7 +10,7 @@
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
    def __init__(
@@ -958,3 +958,231 @@
                                                                                      var_dict_tf[name_tf].shape))
    
        return var_dict_torch_update
class SANMVadEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks-1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
    def output_size(self) -> int:
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
        xs_pad *= self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # This is the last layer: apply the VAD-conditioned mask.
                corner_mask = torch.ones(masks.size(0),
                                         masks.size(-1),
                                         masks.size(-1),
                                         device=xs_pad.device,
                                         dtype=torch.bool)
                for word_index, length in enumerate(ilens):
                    corner_mask[word_index, :, :] = vad_mask(masks.size(-1),
                                                             vad_indexes[word_index],
                                                             device=xs_pad.device)
                layer_mask = masks & corner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
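For orientation, the mask algebra in forward() combines a pad mask of shape (B, 1, L) with a lower-triangular "no-future" mask of shape (1, L, L) by broadcasting; only the last layer swaps the triangular mask for the VAD-conditioned one. A self-contained sketch of the broadcast (values invented; torch.tril mirrors what subsequent_mask builds):

# Hedged illustration of the no-future mask construction above.
import torch

L = 6
pad_mask = torch.ones(1, 1, L, dtype=torch.bool)            # no padding in this toy batch
sub_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))   # lower-triangular, like subsequent_mask
no_future = pad_mask & sub_mask.unsqueeze(0)                # broadcasts to (1, L, L)
print(no_future[0].int())                                   # row i may attend to positions <= i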
funasr/models/target_delay_transformer.py
File was renamed from funasr/punctuation/target_delay_transformer.py
@@ -5,12 +5,11 @@
import torch
import torch.nn as nn
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
funasr/models/vad_realtime_transformer.py
File was renamed from funasr/punctuation/vad_realtime_transformer.py
@@ -7,7 +7,7 @@
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):
funasr/punctuation/abs_model.py
File was deleted
funasr/punctuation/sanm_encoder.py
File was deleted
funasr/punctuation/text_preprocessor.py
@@ -1,12 +1 @@
def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences
funasr/tasks/lm.py
@@ -15,7 +15,7 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.lm.abs_model import AbsLM
-from funasr.lm.espnet_model import ESPnetLanguageModel
+from funasr.lm.espnet_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
-            default=get_default_kwargs(ESPnetLanguageModel),
+            default=get_default_kwargs(LanguageModel),
            help="The keyword arguments for model class.",
        )
@@ -178,7 +178,7 @@
        return retval
    @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
+    def build_model(cls, args: argparse.Namespace) -> LanguageModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@
        # 2. Build ESPnetModel
        # Assume the last-id is sos_and_eos
-        model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
+        model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
        # 3. Initialize
        if args.init is not None:
funasr/tasks/punctuation.py
@@ -14,10 +14,10 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import AbsPunctuation
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
-            default=get_default_kwargs(ESPnetPunctuationModel),
+            default=get_default_kwargs(PunctuationModel),
            help="The keyword arguments for model class.",
        )
@@ -183,7 +183,7 @@
        return retval
    @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
+    def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@
        # Assume the last-id is sos_and_eos
        if "punc_weight" in args.model_conf:
            args.model_conf.pop("punc_weight")
-        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+        model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
        # FIXME(kamo): Should be done in model?
        # 3. Initialize
funasr/train/abs_model.py
File was renamed from funasr/punctuation/espnet_model.py
@@ -1,3 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+import torch
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -7,13 +13,34 @@
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-class ESPnetPunctuationModel(AbsESPnetModel):
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+    """The abstract punctuation model class.
+
+    To share the loss calculation among different models,
+    we use the delegate pattern here:
+    an instance of this class is passed to "PunctuationModel".
+
+    This "model" is one of the mediator objects for the "Task" class.
+    """
+
+    @abstractmethod
+    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
        assert check_argument_types()
        super().__init__()
@@ -21,12 +48,12 @@
        self.punc_weight = torch.Tensor(punc_weight)
        self.sos = 1
        self.eos = 2
        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id
-        #if self.punc_model.with_vad():
+        # if self.punc_model.with_vad():
        #    print("This is a vad puncuation model.")
    def nll(
        self,
        text: torch.Tensor,
@@ -54,7 +81,7 @@
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]
        if self.punc_model.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
@@ -62,7 +89,7 @@
        else:
            # Should be TargetDelayTransformer,
            y, _ = self.punc_model(text, text_lengths)
        # Calc negative log likelihood
        # nll: (BxL,)
        if self.training == False:
@@ -75,7 +102,8 @@
            return nll, text_lengths
        else:
            self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +115,7 @@
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, text_lengths
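As context for the weighted branch above, a self-contained sketch of the unreduced, class-weighted cross entropy that nll() computes (shapes and weights invented for illustration; index 0 doubles as the padding id, matching ignore_id):

# Hedged sketch of the weighted, per-token NLL used in PunctuationModel.nll().
import torch
import torch.nn.functional as F

B, L, V = 2, 5, 4                                  # batch, length, punctuation vocab
y = torch.randn(B, L, V)                           # decoder logits
punc = torch.randint(0, V, (B, L))                 # punctuation targets, 0 = pad/ignore
punc_weight = torch.tensor([0.1, 1.0, 2.0, 2.0])   # up-weight rare punctuation marks

nll = F.cross_entropy(y.view(-1, V), punc.view(-1), punc_weight,
                      reduction="none", ignore_index=0)
nll = nll.view(B, L)                               # per-token losses; ignored targets give 0.0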
    def batchify_nll(self,
                     text: torch.Tensor,
                     punc: torch.Tensor,
@@ -113,7 +141,7 @@
            nlls = []
            x_lengths = []
            max_length = text_lengths.max()
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +160,7 @@
        assert nll.size(0) == total_num
        assert x_lengths.size(0) == total_num
        return nll, x_lengths
    def forward(
        self,
        text: torch.Tensor,
@@ -146,15 +174,15 @@
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {}
    def inference(self,
                  text: torch.Tensor,
                  text_lengths: torch.Tensor,