Merge pull request #165 from alibaba-damo-academy/dev_cmz
punctuation: add training code, support large dataset
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import os |
| | | from funasr.tasks.punctuation import PunctuationTask |
| | | |
| | | |
def parse_args():
    """Parse the command line for punctuation training.

    Starts from the parser supplied by ``PunctuationTask.get_parser()`` and
    registers the two script-local options (``--gpu_id`` and ``--punc_list``)
    on top of it.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    script_options = (
        ("--gpu_id", dict(type=int, default=0, help="local gpu id.")),
        ("--punc_list", dict(type=str, default=None, help="Punctuation list")),
    )
    parser = PunctuationTask.get_parser()
    for flag, spec in script_options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
| | | |
| | | |
def main(args=None, cmd=None):
    """Run punctuation training.

    Thin entry point that forwards both arguments unchanged to
    ``PunctuationTask.main``.

    Args:
        args: Parsed argument namespace, or None.
        cmd: Command-line token list, or None.
    """
    PunctuationTask.main(cmd=cmd, args=args)
| | | |
| | | |
if __name__ == "__main__":
    args = parse_args()

    # Expose only the requested local GPU to this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP is switched on exactly when more than one GPU is requested.
    args.distributed = args.ngpu > 1

    main(args=args)
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import os |
| | | from funasr.tasks.punctuation import PunctuationTask |
| | | |
| | | |
def parse_args():
    """Parse the command line for punctuation training.

    Starts from the parser supplied by ``PunctuationTask.get_parser()`` and
    registers the two script-local options (``--gpu_id`` and ``--punc_list``)
    on top of it.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    script_options = (
        ("--gpu_id", dict(type=int, default=0, help="local gpu id.")),
        ("--punc_list", dict(type=str, default=None, help="Punctuation list")),
    )
    parser = PunctuationTask.get_parser()
    for flag, spec in script_options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
| | | |
| | | |
def main(args=None, cmd=None):
    """Run punctuation training.

    Thin entry point that forwards both arguments unchanged to
    ``PunctuationTask.main``.

    Args:
        args: Parsed argument namespace, or None.
        cmd: Command-line token list, or None.
    """
    PunctuationTask.main(cmd=cmd, args=args)
| | | |
| | | |
if __name__ == "__main__":
    args = parse_args()

    # Expose only the requested local GPU to this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP is switched on exactly when more than one GPU is requested.
    args.distributed = args.ngpu > 1
    if not args.distributed:
        # The non-distributed path insists on a single worker.
        assert args.num_worker_count == 1

    main(args=args)
| | |
| | | return seg_dict |
| | | |
| | | class ArkDataLoader(AbsIterFactory): |
| | | def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"): |
| | | def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, punc_dict_file=None, mode="train"): |
| | | symbol_table = read_symbol_table(dict_file) if dict_file is not None else None |
| | | if seg_dict_file is not None: |
| | | seg_dict = load_seg_dict(seg_dict_file) |
| | | else: |
| | | seg_dict = None |
| | | if punc_dict_file is not None: |
| | | punc_dict = read_symbol_table(punc_dict_file) |
| | | else: |
| | | punc_dict = None |
| | | self.dataset_conf = dataset_conf |
| | | logging.info("dataloader config: {}".format(self.dataset_conf)) |
| | | batch_mode = self.dataset_conf.get("batch_mode", "padding") |
| | | self.dataset = Dataset(data_list, symbol_table, seg_dict, |
| | | self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, |
| | | self.dataset_conf, mode=mode, batch_mode=batch_mode) |
| | | |
| | | def build_iter(self, epoch, shuffle=True): |
| | |
| | | sample_dict["key"] = key |
| | | else: |
| | | text = item |
| | | sample_dict[data_name] = text.strip().split()[1:] |
| | | segs = text.strip().split() |
| | | sample_dict[data_name] = segs[1:] |
| | | if "key" not in sample_dict: |
| | | sample_dict["key"] = segs[0] |
| | | yield sample_dict |
| | | |
| | | self.close_reader(reader_list) |
| | | |
| | | |
def len_fn_example(data):
    """Return the length contribution of one sample: always 1.

    The sample's own content is ignored, so every sample weighs the same.
    NOTE(review): presumably the length function for example-count batching;
    confirm against the caller.
    """
    return 1
| | | |
| | | |
| | | def len_fn_token(data): |
| | |
| | | def Dataset(data_list_file, |
| | | dict, |
| | | seg_dict, |
| | | punc_dict, |
| | | conf, |
| | | mode="train", |
| | | batch_mode="padding"): |
| | |
| | | dataset = FilterIterDataPipe(dataset, fn=filter_fn) |
| | | |
| | | if "text" in data_names: |
| | | vocab = {'vocab': dict, 'seg_dict': seg_dict} |
| | | vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict} |
| | | tokenize_fn = partial(tokenize, **vocab) |
| | | dataset = MapperIterDataPipe(dataset, fn=tokenize_fn) |
| | | |
| | |
| | | sort_size=sort_size, |
| | | batch_mode=batch_mode) |
| | | |
| | | dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping) |
| | | int_pad_value = conf.get("int_pad_value", -1) |
| | | float_pad_value = conf.get("float_pad_value", 0.0) |
| | | padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value} |
| | | padding_fn = partial(padding, **padding_conf) |
| | | dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping) |
| | | |
| | | return dataset |
| | |
| | | def padding(data, float_pad_value=0.0, int_pad_value=-1): |
| | | assert isinstance(data, list) |
| | | assert "key" in data[0] |
| | | assert "speech" in data[0] |
| | | assert "text" in data[0] |
| | | |
| | | assert "speech" in data[0] or "text" in data[0] |
| | | |
| | | keys = [x["key"] for x in data] |
| | | |
| | | batch = {} |
| | |
| | | |
def tokenize(data,
             vocab=None,
             seg_dict=None,
             punc_dict=None):
    """Convert the text (and optional punctuation labels) of a sample to ids.

    Args:
        data: Sample dict. Must contain ``"text"`` (a sequence of string
            tokens). May contain ``"punc"`` (a sequence of punctuation
            labels). The dict is modified in place and returned.
        vocab: Token -> id mapping; must contain ``'<unk>'`` if any token is
            out of vocabulary.
        seg_dict: Optional segmentation dictionary. When given, the text is
            joined, lower-cased and re-segmented with ``forward_segment`` /
            ``seg_tokenize`` before the lookup.
        punc_dict: Optional punctuation label -> id mapping. Unknown labels
            map to the id of ``"_"``.

    Returns:
        The same ``data`` dict with ``"text"`` (and ``"punc"`` when
        applicable) replaced by numpy arrays. When the sample has ``"punc"``
        and its last text token is ``"vad:<n>"``, that token is removed from
        the text and ``"vad_indexes"`` is set to ``[n]`` (``[-1]`` when
        ``<n>`` is empty).

    Raises:
        ValueError: If the ``vad:`` suffix is non-empty and not an integer.
    """
    assert "text" in data
    assert isinstance(vocab, dict)
    text = data["text"]
    token = []
    vad = None  # None means no trailing "vad:" marker was found

    if seg_dict is not None:
        assert isinstance(seg_dict, dict)
        txt = forward_segment("".join(text).lower(), seg_dict)
        text = seg_tokenize(txt, seg_dict)

    has_punc = "punc" in data
    last = len(text) - 1
    for i, x in enumerate(text):
        if i == last and has_punc and x.startswith("vad:"):
            # Slice the token itself, not its last character, so the number
            # after "vad:" is actually parsed.
            suffix = x[4:]
            vad = int(suffix) if suffix else -1
        elif x in vocab:
            token.append(vocab[x])
        else:
            token.append(vocab['<unk>'])

    if has_punc and punc_dict is not None:
        punc_token = [
            punc_dict[punc] if punc in punc_dict else punc_dict["_"]
            for punc in data["punc"]
        ]
        data["punc"] = np.array(punc_token)

    data["text"] = np.array(token)
    if vad is not None:
        data["vad_indexes"] = np.array([vad], dtype=np.int64)
    return data
| | |
| | | del data[self.split_text_name] |
| | | return result |
| | | |
| | | class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor): |
| | | def __init__( |
| | | self, |
| | | train: bool, |
| | | token_type: List[str] = [None], |
| | | token_list: List[Union[Path, str, Iterable[str]]] = [None], |
| | | bpemodel: List[Union[Path, str, Iterable[str]]] = [None], |
| | | text_cleaner: Collection[str] = None, |
| | | g2p_type: str = None, |
| | | unk_symbol: str = "<unk>", |
| | | space_symbol: str = "<space>", |
| | | non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, |
| | | delimiter: str = None, |
| | | rir_scp: str = None, |
| | | rir_apply_prob: float = 1.0, |
| | | noise_scp: str = None, |
| | | noise_apply_prob: float = 1.0, |
| | | noise_db_range: str = "3_10", |
| | | speech_volume_normalize: float = None, |
| | | speech_name: str = "speech", |
| | | text_name: List[str] = ["text"], |
| | | vad_name: str = "vad_indexes", |
| | | ): |
| | | # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor |
| | | super().__init__( |
| | | train=train, |
| | | token_type=token_type[0], |
| | | token_list=token_list[0], |
| | | bpemodel=bpemodel[0], |
| | | text_cleaner=text_cleaner, |
| | | g2p_type=g2p_type, |
| | | unk_symbol=unk_symbol, |
| | | space_symbol=space_symbol, |
| | | non_linguistic_symbols=non_linguistic_symbols, |
| | | delimiter=delimiter, |
| | | speech_name=speech_name, |
| | | text_name=text_name[0], |
| | | rir_scp=rir_scp, |
| | | rir_apply_prob=rir_apply_prob, |
| | | noise_scp=noise_scp, |
| | | noise_apply_prob=noise_apply_prob, |
| | | noise_db_range=noise_db_range, |
| | | speech_volume_normalize=speech_volume_normalize, |
| | | ) |
| | | |
| | | assert ( |
| | | len(token_type) == len(token_list) == len(bpemodel) == len(text_name) |
| | | ), "token_type, token_list, bpemodel, or processing text_name mismatched" |
| | | self.num_tokenizer = len(token_type) |
| | | self.tokenizer = [] |
| | | self.token_id_converter = [] |
| | | |
| | | for i in range(self.num_tokenizer): |
| | | if token_type[i] is not None: |
| | | if token_list[i] is None: |
| | | raise ValueError("token_list is required if token_type is not None") |
| | | |
| | | self.tokenizer.append( |
| | | build_tokenizer( |
| | | token_type=token_type[i], |
| | | bpemodel=bpemodel[i], |
| | | delimiter=delimiter, |
| | | space_symbol=space_symbol, |
| | | non_linguistic_symbols=non_linguistic_symbols, |
| | | g2p_type=g2p_type, |
| | | ) |
| | | ) |
| | | self.token_id_converter.append( |
| | | TokenIDConverter( |
| | | token_list=token_list[i], |
| | | unk_symbol=unk_symbol, |
| | | ) |
| | | ) |
| | | else: |
| | | self.tokenizer.append(None) |
| | | self.token_id_converter.append(None) |
| | | |
| | | self.text_cleaner = TextCleaner(text_cleaner) |
| | | self.text_name = text_name # override the text_name from CommonPreprocessor |
| | | self.vad_name = vad_name |
| | | |
| | | def _text_process( |
| | | self, data: Dict[str, Union[str, np.ndarray]] |
| | | ) -> Dict[str, np.ndarray]: |
| | | for i in range(self.num_tokenizer): |
| | | text_name = self.text_name[i] |
| | | if text_name in data and self.tokenizer[i] is not None: |
| | | text = data[text_name] |
| | | text = self.text_cleaner(text) |
| | | tokens = self.tokenizer[i].text2tokens(text) |
| | | if "vad:" in tokens[-1]: |
| | | vad = tokens[-1][4:] |
| | | tokens = tokens[:-1] |
| | | if len(vad) == 0: |
| | | vad = -1 |
| | | else: |
| | | vad = int(vad) |
| | | data[self.vad_name] = np.array([vad], dtype=np.int64) |
| | | text_ints = self.token_id_converter[i].tokens2ids(tokens) |
| | | data[text_name] = np.array(text_ints, dtype=np.int64) |
| | |
| | | att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) |
| | | return att_outs + fsmn_memory |
| | | |
class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
    """SANM attention that takes separate masks for the FSMN and attention paths.

    ``mask`` is indexed as a pair: ``mask[0]`` goes to ``forward_fsmn`` and
    ``mask[1]`` goes to ``forward_attention``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Return the attention output plus the FSMN memory for ``x``."""
        fsmn_mask, att_mask = mask[0], mask[1]
        query, key, value_heads, value = self.forward_qkv(x)
        memory = self.forward_fsmn(value, fsmn_mask, mask_shfit_chunk)
        # Scale the queries before the dot product.
        query = query * self.d_k ** (-0.5)
        attn_scores = torch.matmul(query, key.transpose(-2, -1))
        attended = self.forward_attention(value_heads, attn_scores, att_mask, mask_att_chunk_encoder)
        return attended + memory
| | | |
| | | class MultiHeadedAttentionSANMDecoder(nn.Module): |
| | | """Multi-Head Attention layer. |
| | | |
| | |
| | | ys_mask = ys_in_pad != ignore_id |
| | | m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) |
| | | return ys_mask.unsqueeze(-2) & m |
| | | |
def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
    """Create a square mask that is cleared in one top-right corner.

    The mask is all ones, except that rows ``[0, vad_pos - 1)`` are set to
    zero in columns ``[vad_pos, size)``. When ``vad_pos`` is not strictly
    between 0 and ``size`` the mask is returned unchanged (all ones).

    :param int size: size of mask
    :param int vad_pos: index of vad index
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor (size, size)
    """
    mask = torch.ones(size, size, device=device, dtype=dtype)
    if 0 < vad_pos < size:
        mask[:vad_pos - 1, vad_pos:] = 0
    return mask
| | |
| | | @abstractmethod |
| | | def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | raise NotImplementedError |
| | | |
| | | @abstractmethod |
| | | def with_vad(self) -> bool: |
| | | raise NotImplementedError |
| | |
| | | |
| | | class ESPnetPunctuationModel(AbsESPnetModel): |
| | | |
| | | def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0): |
| | | def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self.punc_model = punc_model |
| | | self.punc_weight = torch.Tensor(punc_weight) |
| | | self.sos = 1 |
| | | self.eos = 2 |
| | | |
| | | # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR. |
| | | self.ignore_id = ignore_id |
| | | if self.punc_model.with_vad(): |
| | | print("This is a vad puncuation model.") |
| | | |
| | | def nll( |
| | | self, |
| | |
| | | text_lengths: torch.Tensor, |
| | | punc_lengths: torch.Tensor, |
| | | max_length: Optional[int] = None, |
| | | vad_indexes: Optional[torch.Tensor] = None, |
| | | vad_indexes_lengths: Optional[torch.Tensor] = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Compute negative log likelihood(nll) |
| | | |
| | |
| | | else: |
| | | text = text[:, :max_length] |
| | | punc = punc[:, :max_length] |
| | | # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>' |
| | | # text: (Batch, Length) -> x, y: (Batch, Length + 1) |
| | | #x = F.pad(text, [1, 0], "constant", self.eos) |
| | | #t = F.pad(text, [0, 1], "constant", self.ignore_id) |
| | | #for i, l in enumerate(text_lengths): |
| | | # t[i, l] = self.sos |
| | | #x_lengths = text_lengths + 1 |
| | | |
| | | if self.punc_model.with_vad(): |
| | | # Should be VadRealtimeTransformer |
| | | assert vad_indexes is not None |
| | | y, _ = self.punc_model(text, text_lengths, vad_indexes) |
| | | else: |
| | | # Should be TargetDelayTransformer, |
| | | y, _ = self.punc_model(text, text_lengths) |
| | | |
| | | # 2. Forward Language model |
| | | # x: (Batch, Length) -> y: (Batch, Length, NVocab) |
| | | y, _ = self.punc_model(text, text_lengths) |
| | | |
| | | # 3. Calc negative log likelihood |
| | | # Calc negative log likelihood |
| | | # nll: (BxL,) |
| | | if self.training == False: |
| | | _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) |
| | |
| | | nll = torch.Tensor([f1_score]).repeat(text_lengths.sum()) |
| | | return nll, text_lengths |
| | | else: |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), reduction="none", ignore_index=self.ignore_id) |
| | | self.punc_weight = self.punc_weight.to(punc.device) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id) |
| | | # nll: (BxL,) -> (BxL,) |
| | | if max_length is None: |
| | | nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) |
| | |
| | | assert x_lengths.size(0) == total_num |
| | | return nll, x_lengths |
| | | |
| | | def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor, |
| | | punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths) |
| | | def forward( |
| | | self, |
| | | text: torch.Tensor, |
| | | punc: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | punc_lengths: torch.Tensor, |
| | | vad_indexes: Optional[torch.Tensor] = None, |
| | | vad_indexes_lengths: Optional[torch.Tensor] = None, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes) |
| | | ntokens = y_lengths.sum() |
| | | loss = nll.sum() / ntokens |
| | | stats = dict(loss=loss.detach()) |
| | |
| | | text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]: |
| | | return {} |
| | | |
| | | def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: |
| | | return self.punc_model(text, text_lengths) |
| | | def inference(self, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]: |
| | | if self.punc_model.with_vad(): |
| | | assert vad_indexes is not None |
| | | return self.punc_model(text, text_lengths, vad_indexes) |
| | | else: |
| | | return self.punc_model(text, text_lengths) |
| New file |
| | |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | import logging |
| | | import torch |
| | | import torch.nn as nn |
| | | from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk |
| | | from typeguard import check_argument_types |
| | | import numpy as np |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | | from funasr.modules.multi_layer_conv import MultiLayeredConv1d |
| | | from funasr.modules.positionwise_feed_forward import ( |
| | | PositionwiseFeedForward, # noqa: H301 |
| | | ) |
| | | from funasr.modules.repeat import repeat |
| | | from funasr.modules.subsampling import Conv2dSubsampling |
| | | from funasr.modules.subsampling import Conv2dSubsampling2 |
| | | from funasr.modules.subsampling import Conv2dSubsampling6 |
| | | from funasr.modules.subsampling import Conv2dSubsampling8 |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.mask import subsequent_mask, vad_mask |
| | | |
class EncoderLayerSANM(nn.Module):
    """One SANM encoder layer: self-attention then a feed-forward block.

    Both sub-blocks use a residual connection, except that the attention
    residual is dropped when ``in_size != size`` (the layer then changes the
    feature width).
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor.
            mask: Mask handed to the self-attention module unchanged.
            cache (torch.Tensor): Optional tensor; only used when the layer is
                skipped by stochastic depth, where it is prepended to ``x``
                along dim 1.
            mask_shfit_chunk: Forwarded to the self-attention module.
            mask_att_chunk_encoder: Forwarded to the self-attention module.

        Returns:
            tuple: ``(x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder)``.
            The same 5-tuple shape is returned on every path so that stacked
            layers keep receiving the chunk masks. On the skip path the
            returned ``cache`` is None because it has been consumed.
        """
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            if torch.rand(1).item() < self.stochastic_depth_rate:
                if cache is not None:
                    x = torch.cat([cache, x], dim=1)
                return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder
            # Computed only for layers that actually run, so a rate of 1.0
            # (always skip) cannot divide by zero.
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        attn_out = self.self_attn(
            x,
            mask,
            mask_shfit_chunk=mask_shfit_chunk,
            mask_att_chunk_encoder=mask_att_chunk_encoder,
        )
        if self.concat_after:
            branch = self.concat_linear(torch.cat((x, attn_out), dim=-1))
        else:
            branch = self.dropout(attn_out)
        branch = stoch_layer_coeff * branch
        # The residual only applies when the layer keeps the feature width.
        x = residual + branch if self.in_size == self.size else branch
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
| | | |
class SANMEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China

    SANM encoder: an input layer, one ``EncoderLayerSANM`` taking
    ``input_size`` features to ``output_size``, then ``num_blocks - 1``
    further layers operating at ``output_size``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size : int = 11,
        sanm_shfit : int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        """Build the input layer and the stack of SANM encoder layers.

        Raises:
            ValueError: If ``input_layer`` is unknown, or if
                ``selfattention_layer_type`` is anything other than
                ``"sanm"`` (the only type the layer stack below supports).
            NotImplementedError: If ``positionwise_layer_type`` is unknown.
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        conv_subsampling = {
            "conv2d": Conv2dSubsampling,
            "conv2d2": Conv2dSubsampling2,
            "conv2d6": Conv2dSubsampling6,
            "conv2d8": Conv2dSubsampling8,
        }
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer in conv_subsampling:
            self.embed = conv_subsampling[input_layer](input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type in ("conv1d", "conv1d-linear"):
            positionwise_layer = (
                MultiLayeredConv1d if positionwise_layer_type == "conv1d" else Conv1dLinear
            )
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # The layer stack below is built from ``self.encoder_selfattn_layer``
        # with SANM-style arguments, so any other attention type cannot work;
        # fail here with a clear message instead of a late AttributeError.
        if selfattention_layer_type != "sanm":
            raise ValueError(
                "unsupported selfattention_layer_type: " + str(selfattention_layer_type)
            )
        self.encoder_selfattn_layer = MultiHeadedAttentionSANM
        # The first layer consumes ``input_size`` features; the rest consume
        # ``output_size``.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        """Return the configured output feature size."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module; only used when intermediate-CTC conditioning is
                enabled.
        Returns:
            ``(output, olens, None)``. ``output`` is the encoded tensor, or
            ``(encoded, intermediate_outs)`` when intermediate outputs were
            collected.
        Raises:
            TooShortUttError: If the input is too short for the configured
                convolutional subsampling layer.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Out-of-place scaling: the caller's tensor must not be modified.
        xs_pad = xs_pad * self.output_size() ** 0.5

        subsampling_types = (
            Conv2dSubsampling,
            Conv2dSubsampling2,
            Conv2dSubsampling6,
            Conv2dSubsampling8,
        )
        if self.embed is None:
            pass
        elif isinstance(self.embed, subsampling_types):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 not in self.interctc_layer_idx:
                    continue

                encoder_out = xs_pad
                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)

                intermediate_outs.append((layer_idx + 1, encoder_out))

                if self.interctc_use_conditioning:
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
| | | |
| | | class SANMVadEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | input_layer: Optional[str] = "conv2d", |
| | | pos_enc_class=SinusoidalPositionEncoder, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | positionwise_layer_type: str = "linear", |
| | | positionwise_conv_kernel_size: int = 1, |
| | | padding_idx: int = -1, |
| | | interctc_layer_idx: List[int] = [], |
| | | interctc_use_conditioning: bool = False, |
| | | kernel_size : int = 11, |
| | | sanm_shfit : int = 0, |
| | | selfattention_layer_type: str = "sanm", |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | | if input_layer == "linear": |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Linear(input_size, output_size), |
| | | torch.nn.LayerNorm(output_size), |
| | | torch.nn.Dropout(dropout_rate), |
| | | torch.nn.ReLU(), |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | elif input_layer == "conv2d": |
| | | self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d2": |
| | | self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d6": |
| | | self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d8": |
| | | self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) |
| | | elif input_layer == "embed": |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), |
| | | SinusoidalPositionEncoder(), |
| | | ) |
| | | elif input_layer is None: |
| | | if input_size == output_size: |
| | | self.embed = None |
| | | else: |
| | | self.embed = torch.nn.Linear(input_size, output_size) |
| | | elif input_layer == "pe": |
| | | self.embed = SinusoidalPositionEncoder() |
| | | else: |
| | | raise ValueError("unknown input_layer: " + input_layer) |
| | | self.normalize_before = normalize_before |
| | | if positionwise_layer_type == "linear": |
| | | positionwise_layer = PositionwiseFeedForward |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | dropout_rate, |
| | | ) |
| | | elif positionwise_layer_type == "conv1d": |
| | | positionwise_layer = MultiLayeredConv1d |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positionwise_conv_kernel_size, |
| | | dropout_rate, |
| | | ) |
| | | elif positionwise_layer_type == "conv1d-linear": |
| | | positionwise_layer = Conv1dLinear |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positionwise_conv_kernel_size, |
| | | dropout_rate, |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Support only linear or conv1d.") |
| | | |
| | | if selfattention_layer_type == "selfattn": |
| | | encoder_selfattn_layer = MultiHeadedAttention |
| | | encoder_selfattn_layer_args = ( |
| | | attention_heads, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | ) |
| | | |
| | | elif selfattention_layer_type == "sanm": |
| | | self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask |
| | | encoder_selfattn_layer_args0 = ( |
| | | attention_heads, |
| | | input_size, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | ) |
| | | |
| | | encoder_selfattn_layer_args = ( |
| | | attention_heads, |
| | | output_size, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | ) |
| | | |
| | | self.encoders0 = repeat( |
| | | 1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | input_size, |
| | | output_size, |
| | | self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | |
| | | self.encoders = repeat( |
| | | num_blocks-1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | output_size, |
| | | output_size, |
| | | self.encoder_selfattn_layer(*encoder_selfattn_layer_args), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | if self.normalize_before: |
| | | self.after_norm = LayerNorm(output_size) |
| | | |
| | | self.interctc_layer_idx = interctc_layer_idx |
| | | if len(interctc_layer_idx) > 0: |
| | | assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks |
| | | self.interctc_use_conditioning = interctc_use_conditioning |
| | | self.conditioning_layer = None |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | | |
| | |     def output_size(self) -> int: |
| | |         """Return the ``output_size`` value stored by ``__init__``.""" |
| | |         return self._output_size |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | vad_indexes: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
| | | ctc: CTC = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | | """Embed positions in tensor. |
| | | |
| | | Args: |
| | | xs_pad: input tensor (B, L, D) |
| | | ilens: input length (B) |
| | | prev_states: Not to be used now. |
| | | Returns: |
| | | position embedded tensor and mask |
| | | """ |
| | | masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) |
| | | sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0) |
| | | no_future_masks = masks & sub_masks |
| | | xs_pad *= self.output_size()**0.5 |
| | | if self.embed is None: |
| | | xs_pad = xs_pad |
| | | elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2) |
| | | or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)): |
| | | short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {xs_pad.size(1)} frames and is too short for subsampling " + |
| | | f"(it needs more than {limit_size} frames), return empty results", |
| | | xs_pad.size(1), |
| | | limit_size, |
| | | ) |
| | | xs_pad, masks = self.embed(xs_pad, masks) |
| | | else: |
| | | xs_pad = self.embed(xs_pad) |
| | | |
| | | # xs_pad = self.dropout(xs_pad) |
| | | mask_tup0 = [masks, no_future_masks] |
| | | encoder_outs = self.encoders0(xs_pad, mask_tup0) |
| | | xs_pad, _ = encoder_outs[0], encoder_outs[1] |
| | | intermediate_outs = [] |
| | | #if len(self.interctc_layer_idx) == 0: |
| | | if False: |
| | | # Here, we should not use the repeat operation to do it for all layers. |
| | | encoder_outs = self.encoders(xs_pad, masks) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | else: |
| | | for layer_idx, encoder_layer in enumerate(self.encoders): |
| | | if layer_idx + 1 == len(self.encoders): |
| | | # This is last layer. |
| | | coner_mask = torch.ones(masks.size(0), |
| | | masks.size(-1), |
| | | masks.size(-1), |
| | | device=xs_pad.device, |
| | | dtype=torch.bool) |
| | | for word_index, length in enumerate(ilens): |
| | | coner_mask[word_index, :, :] = vad_mask(masks.size(-1), |
| | | vad_indexes[word_index], |
| | | device=xs_pad.device) |
| | | layer_mask = masks & coner_mask |
| | | else: |
| | | layer_mask = no_future_masks |
| | | mask_tup1 = [masks, layer_mask] |
| | | encoder_outs = encoder_layer(xs_pad, mask_tup1) |
| | | xs_pad, layer_mask = encoder_outs[0], encoder_outs[1] |
| | | |
| | | if layer_idx + 1 in self.interctc_layer_idx: |
| | | encoder_out = xs_pad |
| | | |
| | | # intermediate outputs are also normalized |
| | | if self.normalize_before: |
| | | encoder_out = self.after_norm(encoder_out) |
| | | |
| | | intermediate_outs.append((layer_idx + 1, encoder_out)) |
| | | |
| | | if self.interctc_use_conditioning: |
| | | ctc_out = ctc.softmax(encoder_out) |
| | | xs_pad = xs_pad + self.conditioning_layer(ctc_out) |
| | | |
| | | if self.normalize_before: |
| | | xs_pad = self.after_norm(xs_pad) |
| | | |
| | | olens = masks.squeeze(1).sum(1) |
| | | if len(intermediate_outs) > 0: |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |
| | |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder |
| | | from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder |
| | | #from funasr.modules.mask import subsequent_n_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | |
| | |
| | | y = self.decoder(h) |
| | | return y, None |
| | | |
| | |     def with_vad(self) -> bool: |
| | |         """Return False. |
| | | |
| | |         NOTE(review): presumably tells callers that ``forward`` does not |
| | |         take VAD indexes -- confirm against the call sites. |
| | |         """ |
| | |         return False |
| | | |
| | | def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: |
| | | """Score new token. |
| | | |
| New file |
| | |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class VadRealtimeTransformer(AbsPunctuation): |
| | | |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | punc_size: int, |
| | | pos_enc: str = None, |
| | | embed_unit: int = 128, |
| | | att_unit: int = 256, |
| | | head: int = 2, |
| | | unit: int = 1024, |
| | | layer: int = 4, |
| | | dropout_rate: float = 0.5, |
| | | kernel_size: int = 11, |
| | | sanm_shfit: int = 0, |
| | | ): |
| | | super().__init__() |
| | | if pos_enc == "sinusoidal": |
| | | # pos_enc_class = PositionalEncoding |
| | | pos_enc_class = SinusoidalPositionEncoder |
| | | elif pos_enc is None: |
| | | |
| | | def pos_enc_class(*args, **kwargs): |
| | | return nn.Sequential() # indentity |
| | | |
| | | else: |
| | | raise ValueError(f"unknown pos-enc option: {pos_enc}") |
| | | |
| | | self.embed = nn.Embedding(vocab_size, embed_unit) |
| | | self.encoder = Encoder( |
| | | input_size=embed_unit, |
| | | output_size=att_unit, |
| | | attention_heads=head, |
| | | linear_units=unit, |
| | | num_blocks=layer, |
| | | dropout_rate=dropout_rate, |
| | | input_layer="pe", |
| | | # pos_enc_class=pos_enc_class, |
| | | padding_idx=0, |
| | | kernel_size=kernel_size, |
| | | sanm_shfit=sanm_shfit, |
| | | ) |
| | | self.decoder = nn.Linear(att_unit, punc_size) |
| | | |
| | | |
| | | # def _target_mask(self, ys_in_pad): |
| | | # ys_mask = ys_in_pad != 0 |
| | | # m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0) |
| | | # return ys_mask.unsqueeze(-2) & m |
| | | |
| | | def forward(self, input: torch.Tensor, text_lengths: torch.Tensor, |
| | | vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]: |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |
| | | input (torch.Tensor): Input ids. (batch, len) |
| | | hidden (torch.Tensor): Target ids. (batch, len) |
| | | |
| | | """ |
| | | x = self.embed(input) |
| | | # mask = self._target_mask(input) |
| | | h, _, _ = self.encoder(x, text_lengths, vad_indexes) |
| | | y = self.decoder(h) |
| | | return y, None |
| | | |
| | | def with_vad(self): |
| | | return True |
| | | |
| | | def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: |
| | | """Score new token. |
| | | |
| | | Args: |
| | | y (torch.Tensor): 1D torch.int64 prefix tokens. |
| | | state: Scorer state for prefix tokens |
| | | x (torch.Tensor): encoder feature that generates ys. |
| | | |
| | | Returns: |
| | | tuple[torch.Tensor, Any]: Tuple of |
| | | torch.float32 scores for next token (vocab_size) |
| | | and next state for ys |
| | | |
| | | """ |
| | | y = y.unsqueeze(0) |
| | | h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) |
| | | h = self.decoder(h[:, -1]) |
| | | logp = h.log_softmax(dim=-1).squeeze(0) |
| | | return logp, cache |
| | | |
| | | def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: |
| | | """Score new token batch. |
| | | |
| | | Args: |
| | | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). |
| | | states (List[Any]): Scorer states for prefix tokens. |
| | | xs (torch.Tensor): |
| | | The encoder feature that generates ys (n_batch, xlen, n_feat). |
| | | |
| | | Returns: |
| | | tuple[torch.Tensor, List[Any]]: Tuple of |
| | | batchfied scores for next token with shape of `(n_batch, vocab_size)` |
| | | and next state list for ys. |
| | | |
| | | """ |
| | | # merge states |
| | | n_batch = len(ys) |
| | | n_layers = len(self.encoder.encoders) |
| | | if states[0] is None: |
| | | batch_state = None |
| | | else: |
| | | # transpose state of [batch, layer] into [layer, batch] |
| | | batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] |
| | | |
| | | # batch decoding |
| | | h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) |
| | | h = self.decoder(h[:, -1]) |
| | | logp = h.log_softmax(dim=-1) |
| | | |
| | | # transpose state of [layer, batch] into [batch, layer] |
| | | state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] |
| | | return logp, state_list |
| | |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | mode="eval") |
| | | elif args.dataset_type == "small": |
| | | train_iter_factory = cls.build_iter_factory( |
| | |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor |
| | | from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | |
| | | punc_choices = ClassChoices( |
| | | "punctuation", |
| | | classes=dict( |
| | | target_delay=TargetDelayTransformer, |
| | | ), |
| | | classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer), |
| | | type_check=AbsPunctuation, |
| | | default="TargetDelayTransformer", |
| | | default="target_delay", |
| | | ) |
| | | |
| | | |
| | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | required = parser.get_default("required") |
| | | #import pdb;pdb.set_trace() |
| | | #required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |
| | |
| | | bpemodels = [args.bpemodel, args.bpemodel] |
| | | text_names = ["text", "punc"] |
| | | if args.use_preprocessor: |
| | | retval = MutliTokenizerCommonPreprocessor( |
| | | retval = PuncTrainTokenizerCommonPreprocessor( |
| | | train=train, |
| | | token_type=token_types, |
| | | token_list=token_lists, |
| | |
| | | def optional_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = () |
| | | retval = ("vad",) |
| | | return retval |
| | | |
| | | @classmethod |
| | |
| | | args.token_list = token_list.copy() |
| | | if isinstance(args.punc_list, str): |
| | | with open(args.punc_list, encoding="utf-8") as f2: |
| | | punc_list = [line.rstrip() for line in f2] |
| | | pairs = [line.rstrip().split(":") for line in f2] |
| | | punc_list = [pair[0] for pair in pairs] |
| | | punc_weight_list = [float(pair[1]) for pair in pairs] |
| | | args.punc_list = punc_list.copy() |
| | | elif isinstance(args.punc_list, list): |
| | | # This is in the inference code path. |
| | | punc_list = args.punc_list.copy() |
| | | punc_weight_list = [1] * len(punc_list) |
| | | if isinstance(args.token_list, (tuple, list)): |
| | | token_list = args.token_list.copy() |
| | | else: |
| | |
| | | |
| | | # 2. Build ESPnetModel |
| | | # Assume the last-id is sos_and_eos |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf) |
| | | if "punc_weight" in args.model_conf: |
| | | args.model_conf.pop("punc_weight") |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | |
| | | # FIXME(kamo): Should be done in model? |
| | | # 3. Initialize |