11 files modified
3 files renamed
2 files deleted
| | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.punctuation.text_preprocessor import split_to_mini_sentence |
| | | from funasr.datasets.preprocessor import split_to_mini_sentence |
| | | |
| | | |
| | | class Text2Punc: |
| | |
| | | data[self.vad_name] = np.array([vad], dtype=np.int64) |
| | | text_ints = self.token_id_converter[i].tokens2ids(tokens) |
| | | data[text_name] = np.array(text_ints, dtype=np.int64) |
| | | |
| | | |
| | | def split_to_mini_sentence(words: list, word_limit: int = 20):
| | |     assert word_limit > 1
| | |     if len(words) <= word_limit:
| | |         return [words]
| | |     sentences = []
| | |     length = len(words)
| | |     sentence_len = length // word_limit
| | |     for i in range(sentence_len):
| | |         sentences.append(words[i * word_limit:(i + 1) * word_limit])
| | |     if length % word_limit > 0:
| | |         sentences.append(words[sentence_len * word_limit:])
| | |     return sentences
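A quick usage sketch of the chunking behavior above (using the function as defined in this hunk): full word_limit-sized chunks come first, then the shorter remainder.

    # e.g. 7 tokens with word_limit=3:
    words = ["a", "b", "c", "d", "e", "f", "g"]
    print(split_to_mini_sentence(words, word_limit=3))
    # -> [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]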
| | |
| | | from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export |
| | | from funasr.models.e2e_vad import E2EVadModel |
| | | from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export |
| | | |
| | | def get_model(model, export_config=None): |
| | |
| | | return Paraformer_export(model, **export_config) |
| | | elif isinstance(model, E2EVadModel): |
| | | return E2EVadModel_export(model, **export_config) |
| | | elif isinstance(model, ESPnetPunctuationModel): |
| | | elif isinstance(model, PunctuationModel): |
| | | if isinstance(model.punc_model, TargetDelayTransformer): |
| | | return TargetDelayTransformer_export(model.punc_model, **export_config) |
| | | elif isinstance(model.punc_model, VadRealtimeTransformer): |
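A hedged usage sketch of this dispatcher; the call site, the `trained_model` variable, and the config keys are assumptions, not confirmed API:

    # Wrap a trained model with its export-friendly counterpart.
    export_model = get_model(trained_model, export_config={"onnx": True})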
| | |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask |
| | | from funasr.export.utils.torch_function import sequence_mask |
| | | #from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder |
| | | from funasr.punctuation.sanm_encoder import SANMEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class TargetDelayTransformer(nn.Module): |
| | | |
| | |
| | | self.feats_dim = self.embed.embedding_dim |
| | | self.num_embeddings = self.embed.num_embeddings |
| | | self.model_name = model_name |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask |
| | | from funasr.export.utils.torch_function import sequence_mask |
| | | # from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder |
| | | from funasr.punctuation.sanm_encoder import SANMEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | |
| | | # class TargetDelayTransformer(nn.Module): |
| | | # |
| | | # def __init__( |
| | | # self, |
| | | # model, |
| | | # max_seq_len=512, |
| | | # model_name='punc_model', |
| | | # **kwargs, |
| | | # ): |
| | | # super().__init__() |
| | | # onnx = False |
| | | # if "onnx" in kwargs: |
| | | # onnx = kwargs["onnx"] |
| | | # self.embed = model.embed |
| | | # self.decoder = model.decoder |
| | | # self.model = model |
| | | # self.feats_dim = self.embed.embedding_dim |
| | | # self.num_embeddings = self.embed.num_embeddings |
| | | # self.model_name = model_name |
| | | # |
| | | # if isinstance(model.encoder, SANMEncoder): |
| | | # self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) |
| | | # else: |
| | | # assert False, "Only support SANM encoder."
| | | # |
| | | # def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: |
| | | # """Compute loss value from buffer sequences. |
| | | # |
| | | # Args: |
| | | # input (torch.Tensor): Input ids. (batch, len) |
| | | # hidden (torch.Tensor): Target ids. (batch, len) |
| | | # |
| | | # """ |
| | | # x = self.embed(input) |
| | | # # mask = self._target_mask(input) |
| | | # h, _ = self.encoder(x, text_lengths) |
| | | # y = self.decoder(h) |
| | | # return y |
| | | # |
| | | # def get_dummy_inputs(self): |
| | | # length = 120 |
| | | # text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)) |
| | | # text_lengths = torch.tensor([length - 20, length], dtype=torch.int32) |
| | | # return (text_indexes, text_lengths) |
| | | # |
| | | # def get_input_names(self): |
| | | # return ['input', 'text_lengths'] |
| | | # |
| | | # def get_output_names(self): |
| | | # return ['logits'] |
| | | # |
| | | # def get_dynamic_axes(self): |
| | | # return { |
| | | # 'input': { |
| | | # 0: 'batch_size', |
| | | # 1: 'feats_length' |
| | | # }, |
| | | # 'text_lengths': { |
| | | # 0: 'batch_size', |
| | | # }, |
| | | # 'logits': { |
| | | # 0: 'batch_size', |
| | | # 1: 'logits_length' |
| | | # }, |
| | | # } |
| | | |
| | | if isinstance(model.encoder, SANMEncoder): |
| | | self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) |
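The get_dummy_inputs / get_input_names / get_output_names / get_dynamic_axes hooks in the (commented) wrapper above are the usual interface for torch.onnx.export. A minimal sketch of how such a wrapper is typically driven; `punc_model`, the output path, and the opset version are assumptions:

    import torch
    m = TargetDelayTransformer(punc_model, onnx=True)  # punc_model: placeholder
    torch.onnx.export(
        m,
        m.get_dummy_inputs(),               # (text_indexes, text_lengths)
        "punc_model.onnx",
        input_names=m.get_input_names(),    # ['input', 'text_lengths']
        output_names=m.get_output_names(),  # ['logits']
        dynamic_axes=m.get_dynamic_axes(),  # batch and length axes stay dynamic
        opset_version=13,
    )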
| | |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.sanm_encoder import SANMVadEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMVadEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export |
| | | |
| | | class VadRealtimeTransformer(nn.Module): |
| | |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | |
| | | class ESPnetLanguageModel(AbsESPnetModel): |
| | | class LanguageModel(AbsESPnetModel): |
| | | def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | |
| | | from typeguard import check_argument_types |
| | | import numpy as np |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | |
| | | from funasr.modules.mask import subsequent_mask, vad_mask |
| | | |
| | | class EncoderLayerSANM(nn.Module): |
| | | def __init__( |
| | |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | |
| | | |
| | | class SANMVadEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | input_layer: Optional[str] = "conv2d", |
| | | pos_enc_class=SinusoidalPositionEncoder, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | positionwise_layer_type: str = "linear", |
| | | positionwise_conv_kernel_size: int = 1, |
| | | padding_idx: int = -1, |
| | | interctc_layer_idx: List[int] = [], |
| | | interctc_use_conditioning: bool = False, |
| | | kernel_size: int = 11,
| | | sanm_shfit: int = 0,
| | | selfattention_layer_type: str = "sanm", |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | | if input_layer == "linear": |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Linear(input_size, output_size), |
| | | torch.nn.LayerNorm(output_size), |
| | | torch.nn.Dropout(dropout_rate), |
| | | torch.nn.ReLU(), |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | elif input_layer == "conv2d": |
| | | self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d2": |
| | | self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d6": |
| | | self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d8": |
| | | self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) |
| | | elif input_layer == "embed": |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), |
| | | SinusoidalPositionEncoder(), |
| | | ) |
| | | elif input_layer is None: |
| | | if input_size == output_size: |
| | | self.embed = None |
| | | else: |
| | | self.embed = torch.nn.Linear(input_size, output_size) |
| | | elif input_layer == "pe": |
| | | self.embed = SinusoidalPositionEncoder() |
| | | else: |
| | | raise ValueError("unknown input_layer: " + input_layer) |
| | | self.normalize_before = normalize_before |
| | | if positionwise_layer_type == "linear": |
| | | positionwise_layer = PositionwiseFeedForward |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | dropout_rate, |
| | | ) |
| | | elif positionwise_layer_type == "conv1d": |
| | | positionwise_layer = MultiLayeredConv1d |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positionwise_conv_kernel_size, |
| | | dropout_rate, |
| | | ) |
| | | elif positionwise_layer_type == "conv1d-linear": |
| | | positionwise_layer = Conv1dLinear |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positionwise_conv_kernel_size, |
| | | dropout_rate, |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Support only linear, conv1d or conv1d-linear.")
| | | |
| | | if selfattention_layer_type == "selfattn":
| | | self.encoder_selfattn_layer = MultiHeadedAttention
| | | # Plain MHA takes (heads, size, dropout) and has no separate input
| | | # projection, so this branch assumes input_size == output_size.
| | | encoder_selfattn_layer_args0 = encoder_selfattn_layer_args = (
| | | attention_heads,
| | | output_size,
| | | attention_dropout_rate,
| | | )
| | | |
| | | elif selfattention_layer_type == "sanm": |
| | | self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask |
| | | encoder_selfattn_layer_args0 = ( |
| | | attention_heads, |
| | | input_size, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | ) |
| | | |
| | | encoder_selfattn_layer_args = ( |
| | | attention_heads, |
| | | output_size, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | ) |
| | | |
| | | self.encoders0 = repeat( |
| | | 1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | input_size, |
| | | output_size, |
| | | self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | |
| | | self.encoders = repeat( |
| | | num_blocks-1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | output_size, |
| | | output_size, |
| | | self.encoder_selfattn_layer(*encoder_selfattn_layer_args), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | if self.normalize_before: |
| | | self.after_norm = LayerNorm(output_size) |
| | | |
| | | self.interctc_layer_idx = interctc_layer_idx |
| | | if len(interctc_layer_idx) > 0: |
| | | assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks |
| | | self.interctc_use_conditioning = interctc_use_conditioning |
| | | self.conditioning_layer = None |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | vad_indexes: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
| | | ctc: CTC = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | | """Embed positions in tensor. |
| | | |
| | | Args: |
| | | xs_pad: input tensor (B, L, D) |
| | | ilens: input length (B) |
| | | prev_states: Not to be used now. |
| | | Returns: |
| | | position embedded tensor and mask |
| | | """ |
| | | masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) |
| | | sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0) |
| | | no_future_masks = masks & sub_masks |
| | | xs_pad *= self.output_size()**0.5 |
| | | if self.embed is None: |
| | | xs_pad = xs_pad |
| | | elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2) |
| | | or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)): |
| | | short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {xs_pad.size(1)} frames and is too short for subsampling " + |
| | | f"(it needs more than {limit_size} frames), return empty results", |
| | | xs_pad.size(1), |
| | | limit_size, |
| | | ) |
| | | xs_pad, masks = self.embed(xs_pad, masks) |
| | | else: |
| | | xs_pad = self.embed(xs_pad) |
| | | |
| | | # xs_pad = self.dropout(xs_pad) |
| | | mask_tup0 = [masks, no_future_masks] |
| | | encoder_outs = self.encoders0(xs_pad, mask_tup0) |
| | | xs_pad, _ = encoder_outs[0], encoder_outs[1] |
| | | intermediate_outs = [] |
| | | |
| | | |
| | | for layer_idx, encoder_layer in enumerate(self.encoders): |
| | | if layer_idx + 1 == len(self.encoders): |
| | | # This is last layer. |
| | | # build the per-sample VAD-based attention mask for the last layer
| | | corner_mask = torch.ones(masks.size(0),
| | | masks.size(-1),
| | | masks.size(-1),
| | | device=xs_pad.device,
| | | dtype=torch.bool)
| | | for word_index, length in enumerate(ilens):
| | | corner_mask[word_index, :, :] = vad_mask(masks.size(-1),
| | | vad_indexes[word_index],
| | | device=xs_pad.device)
| | | layer_mask = masks & corner_mask
| | | else: |
| | | layer_mask = no_future_masks |
| | | mask_tup1 = [masks, layer_mask] |
| | | encoder_outs = encoder_layer(xs_pad, mask_tup1) |
| | | xs_pad, layer_mask = encoder_outs[0], encoder_outs[1] |
| | | |
| | | if self.normalize_before: |
| | | xs_pad = self.after_norm(xs_pad) |
| | | |
| | | olens = masks.squeeze(1).sum(1) |
| | | if len(intermediate_outs) > 0: |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
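A minimal sketch of the padding-plus-causal mask construction at the top of forward(), using the mask helpers imported earlier in this file (sequence lengths are illustrative):

    import torch
    from funasr.modules.nets_utils import make_pad_mask
    from funasr.modules.mask import subsequent_mask

    ilens = torch.tensor([4, 2])                              # two sequences
    masks = ~make_pad_mask(ilens)[:, None, :]                 # (B, 1, T): True on real frames
    sub_masks = subsequent_mask(masks.size(-1)).unsqueeze(0)  # (1, T, T): no look-ahead
    no_future_masks = masks & sub_masks                       # (B, T, T): padding AND causal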
| File was renamed from funasr/punctuation/target_delay_transformer.py |
| | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder |
| | | from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder |
| | | #from funasr.modules.mask import subsequent_n_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class TargetDelayTransformer(AbsPunctuation): |
| File was renamed from funasr/punctuation/vad_realtime_transformer.py |
| | |
| | | |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class VadRealtimeTransformer(AbsPunctuation): |
| | |
| | | def split_to_mini_sentence(words: list, word_limit: int = 20):
| | |     assert word_limit > 1
| | |     if len(words) <= word_limit:
| | |         return [words]
| | |     sentences = []
| | |     length = len(words)
| | |     sentence_len = length // word_limit
| | |     for i in range(sentence_len):
| | |         sentences.append(words[i * word_limit:(i + 1) * word_limit])
| | |     if length % word_limit > 0:
| | |         sentences.append(words[sentence_len * word_limit:])
| | |     return sentences
| | | |
| | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.lm.abs_model import AbsLM |
| | | from funasr.lm.espnet_model import ESPnetLanguageModel |
| | | from funasr.lm.espnet_model import LanguageModel |
| | | from funasr.lm.seq_rnn_lm import SequentialRNNLM |
| | | from funasr.lm.transformer_lm import TransformerLM |
| | | from funasr.tasks.abs_task import AbsTask |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetLanguageModel), |
| | | default=get_default_kwargs(LanguageModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
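get_default_kwargs harvests the defaulted constructor arguments of the given model class so they become the argparse default. Roughly equivalent to this simplified sketch (not the actual implementation):

    import inspect

    def get_default_kwargs_sketch(cls):
        sig = inspect.signature(cls.__init__)
        return {name: p.default for name, p in sig.parameters.items()
                if p.default is not inspect.Parameter.empty}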
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel: |
| | | def build_model(cls, args: argparse.Namespace) -> LanguageModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | |
| | | # 2. Build ESPnetModel |
| | | # Assume the last-id is sos_and_eos |
| | | model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | |
| | | # 3. Initialize |
| | | if args.init is not None: |
| | |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetPunctuationModel), |
| | | default=get_default_kwargs(PunctuationModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel: |
| | | def build_model(cls, args: argparse.Namespace) -> PunctuationModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | # Assume the last-id is sos_and_eos |
| | | if "punc_weight" in args.model_conf: |
| | | args.model_conf.pop("punc_weight") |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | |
| | | # FIXME(kamo): Should be done in model? |
| | | # 3. Initialize |
| File was renamed from funasr/punctuation/espnet_model.py |
| | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | |
| | | from typing import Dict |
| | | from typing import Optional |
| | | from typing import Tuple |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | from funasr.modules.scorers.scorer_interface import BatchScorerInterface |
| | | |
| | | class ESPnetPunctuationModel(AbsESPnetModel): |
| | | |
| | | class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC): |
| | | """The abstract class |
| | | |
| | | To share the loss calculation way among different models, |
| | | We uses delegate pattern here: |
| | | The instance of this class should be passed to "LanguageModel" |
| | | |
| | | This "model" is one of mediator objects for "Task" class. |
| | | |
| | | """ |
| | | |
| | | @abstractmethod |
| | | def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | raise NotImplementedError |
| | | |
| | | @abstractmethod |
| | | def with_vad(self) -> bool: |
| | | raise NotImplementedError |
| | | |
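For orientation, a hypothetical toy implementation of this interface; the class name and layer sizes are illustrative and not part of FunASR (it relies on the AbsPunctuation class defined just above):

    import torch

    class ToyPunc(AbsPunctuation):
        def __init__(self, vocab_size: int, punc_size: int, dim: int = 64):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, dim)
            self.out = torch.nn.Linear(dim, punc_size)

        def forward(self, input: torch.Tensor, hidden: torch.Tensor):
            # One punctuation logit vector per input token.
            return self.out(self.embed(input)), hidden

        def with_vad(self) -> bool:
            return False  # this toy model ignores VAD information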
| | | |
| | | class PunctuationModel(AbsESPnetModel): |
| | | |
| | | def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None): |
| | | assert check_argument_types() |
| | |
| | | return nll, text_lengths |
| | | else: |
| | | self.punc_weight = self.punc_weight.to(punc.device) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", |
| | | ignore_index=self.ignore_id) |
| | | # nll: (BxL,) -> (BxL,) |
| | | if max_length is None: |
| | | nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) |
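The weighted, padding-aware token loss above can be reproduced in isolation. A minimal sketch with illustrative shapes (batch size, lengths, and vocabulary size are assumptions; ignore_index mirrors the ignore_id=0 default):

    import torch
    import torch.nn.functional as F
    from funasr.modules.nets_utils import make_pad_mask

    B, L, V = 2, 5, 4
    y = torch.randn(B, L, V)                    # punctuation logits
    punc = torch.randint(0, V, (B, L))          # punctuation targets
    text_lengths = torch.tensor([5, 3])
    punc_weight = torch.ones(V)                 # per-class weights
    nll = F.cross_entropy(y.view(-1, V), punc.view(-1), punc_weight,
                          reduction="none", ignore_index=0)
    # Zero out the padded positions, as in the code above.
    nll.masked_fill_(make_pad_mask(text_lengths).view(-1), 0.0)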