| | |
| | | |
| | | |
| | | self.embed = nn.Embedding(vocab_size, embed_unit) |
| | | encoder_class = tables.encoder_classes.get(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(**encoder_conf) |
| | | |
| | | self.decoder = nn.Linear(att_unit, punc_size) |
| | |
| | | |
| | | |
| | | |
| | | def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: |
| | | def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs): |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |
| | |
| | | # text = data_in[0] |
| | | # text_lengths = data_lengths[0] if data_lengths is not None else None |
| | | split_size = kwargs.get("split_size", 20) |
| | | |
| | | tokens = split_words(text) |
| | | |
| | | jieba_usr_dict = kwargs.get("jieba_usr_dict", None) |
| | | if jieba_usr_dict and isinstance(jieba_usr_dict, str): |
| | | import jieba |
| | | jieba.load_userdict(jieba_usr_dict) |
| | | jieba_usr_dict = jieba |
| | | kwargs["jieba_usr_dict"] = "jieba_usr_dict" |
| | | tokens = split_words(text, jieba_usr_dict=jieba_usr_dict) |
| | | tokens_int = tokenizer.encode(tokens) |
| | | |
| | | mini_sentences = split_to_mini_sentence(tokens, split_size) |