Revert "Dev gzf funasr2" (#1164)
17 files modified
1 file renamed
8 files deleted
| | |
| | | |
| | | class BatchSampler(torch.utils.data.BatchSampler): |
| | | |
| | | def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs): |
| | | def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs): |
| | | |
| | | self.drop_last = drop_last |
| | | self.pre_idx = -1 |
| | | self.dataset = dataset |
| | | self.total_samples = len(dataset) |
| | | # self.batch_type = args.batch_type |
| | | # self.batch_size_type = args.batch_size_type |
| | | # self.batch_size = args.batch_size |
| | | # self.sort_size = args.sort_size |
| | | # self.max_length_token = args.max_length_token |
| | | self.batch_type = batch_type |
| | | self.batch_size_type = batch_size_type |
| | | self.batch_size = batch_size |
| | | self.sort_size = sort_size |
| | | self.max_length_token = kwargs.get("max_length_token", 5000) |
| | |
| | | return self.total_samples |
| | | |
| | | def __iter__(self): |
| | | # print("in sampler") |
| | | print("in sampler") |
| | | |
| | | if self.shuffle: |
| | | np.random.shuffle(self.shuffle_idx) |
| | |
| | | num_sample = 0 |
| | | |
| | | iter_num = (self.total_samples-1) // self.sort_size + 1 |
| | | # print("iter_num: ", iter_num) |
| | | print("iter_num: ", iter_num) |
| | | for iter in range(self.pre_idx + 1, iter_num): |
| | | datalen_with_index = [] |
| | | for i in range(self.sort_size): |
| | |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \ |
| | | self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map]) |
| | | sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \ |
| | | self.dataset.indexed_dataset[idx_map]["target_len"] |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |
| | |
| | | |
| | | max_token_cur = max(max_token, sample_len_cur_raw) |
| | | max_token_padding = 1 + num_sample |
| | | if self.batch_type == 'token': |
| | | if self.batch_size_type == 'token': |
| | | max_token_padding *= max_token_cur |
| | | if max_token_padding <= self.batch_size: |
| | | batch.append(idx) |
| | |
| | | batch_sampler = BatchSampler(dataset) |
| | | |
| | | |
def collator(samples: list = None):
    """Identity collate function for a DataLoader.

    Args:
        samples: list of dataset items assembled by the sampler.
            NOTE(review): annotated ``list`` but defaults to ``None`` —
            callers that pass nothing get ``None`` back unchanged.

    Returns:
        The input ``samples`` object, untouched; batching/padding is
        expected to happen elsewhere.
    """
    return samples
| | | |
| | | if __name__ == "__main__": |
| | | |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset, |
| | | collate_fn=dataset.collator, |
| | | batch_sampler=batch_sampler, |
| | | shuffle=False, |
| | | num_workers=0, |
| | | num_workers=8, |
| | | pin_memory=True) |
| | | |
| | | print(len(dataset)) |
| | |
| | | |
def __getitem__(self, index):
    """Return the stored record at ``index``.

    Plain passthrough to the underlying ``self.contents`` sequence;
    supports whatever indexing that sequence supports (int, slice, ...).
    """
    return self.contents[index]
| | | |
def get_source_len(self, data_dict):
    """Length of the source side of one sample record.

    Args:
        data_dict: sample record; must contain a ``"source_len"`` key
            (KeyError otherwise — sources are assumed mandatory).

    Returns:
        The stored source length.
    """
    return data_dict["source_len"]
| | | |
def get_target_len(self, data_dict):
    """Length of the target side of one sample record.

    Args:
        data_dict: sample record; ``"target_len"`` is optional
            (e.g. unlabeled/inference-only data).

    Returns:
        The stored target length, or 0 when the record has no target.
    """
    # dict.get with a default is the idiomatic form of the original
    # ``x["k"] if "k" in x else 0`` conditional.
    return data_dict.get("target_len", 0)
| | | |
| | | |
| | | class AudioDataset(torch.utils.data.Dataset): |
| | | def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs): |
| | | def __init__(self, path, frontend=None, tokenizer=None, token_id_converter=None): |
| | | |
| | | super().__init__() |
| | | self.indexed_dataset = IndexedDatasetJsonl(path) |
| | | self.frontend = frontend.forward |
| | | self.fs = 16000 if frontend is None else frontend.fs |
| | | self.data_type = "sound" |
| | | self.tokenizer = tokenizer |
| | | self.token_id_converter = token_id_converter |
| | | |
| | | self.int_pad_value = int_pad_value |
| | | self.float_pad_value = float_pad_value |
| | | self.int_pad_value = -1 |
| | | self.float_pad_value = 0.0 |
| | | |
| | | |
| | | |
| | |
| | | data_src = load_audio(source, fs=self.fs) |
| | | speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend) |
| | | target = item["target"] |
| | | ids = self.tokenizer.encode(target) |
| | | text = self.tokenizer.text2tokens(target) |
| | | ids = self.token_id_converter.tokens2ids(text) |
| | | ids_lengths = len(ids) |
| | | text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32) |
| | | |
| | |
| | | tokens = seg_tokenize(tokens, self.seg_dict) |
| | | else: |
| | | tokens = self.tokenizer.text2tokens(text) |
| | | |
| | | text_ints = self.token_id_converter.tokens2ids(tokens) |
| | | data[self.text_name] = np.array(text_ints, dtype=np.int64) |
| | | return data |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + self.predictor_bias).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + self.predictor_bias).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + self.predictor_bias).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + self.predictor_bias).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + self.predictor_bias).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + self.predictor_bias).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | return var_dict_torch_update |
| | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | |
| | | Args: |
| | | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). |
| | | pad_targets (LongTensor): Target label tensors (B, Lmax). |
| | | pad_targets (LongTensor): Target label tensors (B, Lmax, D). |
| | | ignore_label (int): Ignore label id. |
| | | |
| | | Returns: |
| | |
| | | import torch |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.optimizers.sgd import SGD |
| | | |
# Registry mapping optimizer names (as used in configs/CLI) to their classes.
# Torch built-ins plus the project-local FairseqAdam and SGD variants.
optim_choices = {
    "adam": torch.optim.Adam,
    "fairseq_adam": FairseqAdam,
    "adamw": torch.optim.AdamW,
    "sgd": SGD,
    "adadelta": torch.optim.Adadelta,
    "adagrad": torch.optim.Adagrad,
    "adamax": torch.optim.Adamax,
    "asgd": torch.optim.ASGD,
    "lbfgs": torch.optim.LBFGS,
    "rmsprop": torch.optim.RMSprop,
    "rprop": torch.optim.Rprop,
}
| | |
| | | import torch |
| | | import torch.multiprocessing |
| | | import torch.nn |
| | | import torch.optim |
| | | |
| | | from funasr.schedulers.noam_lr import NoamLR |
| | | from funasr.schedulers.tri_stage_scheduler import TriStageLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | |
# Registry mapping LR-scheduler names (as used in configs/CLI) to classes.
# Torch built-ins plus the project-local Noam/warmup/tri-stage schedulers.
# NOTE(review): key casing is inconsistent (e.g. "ReduceLROnPlateau" vs
# "steplr") but is preserved — configs in the wild depend on these strings.
scheduler_choices = {
    "ReduceLROnPlateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
    "lambdalr": torch.optim.lr_scheduler.LambdaLR,
    "steplr": torch.optim.lr_scheduler.StepLR,
    "multisteplr": torch.optim.lr_scheduler.MultiStepLR,
    "exponentiallr": torch.optim.lr_scheduler.ExponentialLR,
    "CosineAnnealingLR": torch.optim.lr_scheduler.CosineAnnealingLR,
    "noamlr": NoamLR,
    "warmuplr": WarmupLR,
    "tri_stage": TriStageLR,
    "cycliclr": torch.optim.lr_scheduler.CyclicLR,
    "onecyclelr": torch.optim.lr_scheduler.OneCycleLR,
    "CosineAnnealingWarmRestarts": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
}
| | |
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union

import numpy as np
| | | |
class AbsTokenizer(ABC):
    """Abstract tokenizer interface: text <-> token-string conversion."""

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split a line of raw text into a list of token strings."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join token strings back into a line of raw text."""
        raise NotImplementedError
| | | |
| | | |
class BaseTokenizer(ABC):
    """Vocabulary-backed tokenizer base class.

    Builds token<->id lookup tables from a token list given either as a
    vocabulary file (one token per line, line index = id) or as an iterable
    of token strings.  Subclasses supply the actual ``text2tokens`` /
    ``tokens2text`` logic.
    """

    def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        """
        Args:
            token_list: path to a vocabulary file or an iterable of token
                strings.  ``None`` yields an empty vocabulary.
            unk_symbol: token substituted for out-of-vocabulary symbols;
                must be present in a non-empty vocabulary.
            **kwargs: accepted for subclass signature compatibility; unused.

        Raises:
            RuntimeError: a symbol is duplicated, or ``unk_symbol`` is
                missing from a non-empty vocabulary.
        """
        if token_list is None:
            # Fix: the original never assigned ``self.token_list`` for the
            # default ``token_list=None`` and crashed later with
            # AttributeError when building ``token2id``.
            self.token_list: List[str] = []
            self.token_list_repr = ""
        elif isinstance(token_list, (Path, str)):
            token_list = Path(token_list)
            self.token_list_repr = str(token_list)
            self.token_list: List[str] = []
            with token_list.open("r", encoding="utf-8") as f:
                for line in f:
                    # rstrip (not strip): tokens may carry meaningful
                    # leading whitespace.
                    self.token_list.append(line.rstrip())
        else:
            self.token_list: List[str] = list(token_list)
            # Repr shows at most the first three tokens plus the vocab size.
            self.token_list_repr = ""
            for i, t in enumerate(self.token_list):
                if i == 3:
                    break
                self.token_list_repr += f"{t}, "
            self.token_list_repr += f"... (NVocab={(len(self.token_list))})"

        self.token2id: Dict[str, int] = {}
        for i, t in enumerate(self.token_list):
            if t in self.token2id:
                raise RuntimeError(f'Symbol "{t}" is duplicated')
            self.token2id[t] = i

        self.unk_symbol = unk_symbol
        if self.token_list:
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]
        else:
            # Empty vocabulary: no unk id is available.
            self.unk_id = None

    def encode(self, text):
        """Text -> list of token ids (via the subclass ``text2tokens``)."""
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)

        return text_ints

    def decode(self, text_ints):
        """Token ids -> token strings.

        NOTE(review): despite the name this returns a list of tokens, not
        joined text; preserved for caller compatibility.
        """
        return self.ids2tokens(text_ints)

    def get_num_vocabulary_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Map ids back to token strings; ndarray input must be 1-D."""
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map tokens to ids, substituting ``unk_id`` for OOV tokens."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError
| | |
| | | from pathlib import Path |
| | | from typing import Iterable |
| | | from typing import Union |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | from typing import Iterable |
| | | from typing import List |
| | | from pathlib import Path |
| | | from typing import Dict |
| | | from typing import Iterable |
| | | from typing import List |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | |
| | | from funasr.tokenizer.abs_tokenizer import AbsTokenizer |
| | | from funasr.tokenizer.char_tokenizer import CharTokenizer |
| | |
| | | space_symbol: str = "<space>", |
| | | delimiter: str = None, |
| | | g2p_type: str = None, |
| | | **kwargs, |
| | | ): |
| | | ) -> AbsTokenizer: |
| | | """A helper function to instantiate Tokenizer""" |
| | | if token_type == "bpe": |
| | | if bpemodel is None: |
| | |
| | | raise RuntimeError( |
| | | "remove_non_linguistic_symbols is not implemented for token_type=bpe" |
| | | ) |
| | | return SentencepiecesTokenizer(bpemodel, **kwargs) |
| | | return SentencepiecesTokenizer(bpemodel) |
| | | |
| | | elif token_type == "word": |
| | | if remove_non_linguistic_symbols and non_linguistic_symbols is not None: |
| | |
| | | remove_non_linguistic_symbols=True, |
| | | ) |
| | | else: |
| | | return WordTokenizer(delimiter=delimiter, **kwargs) |
| | | return WordTokenizer(delimiter=delimiter) |
| | | |
| | | elif token_type == "char": |
| | | return CharTokenizer( |
| | | non_linguistic_symbols=non_linguistic_symbols, |
| | | space_symbol=space_symbol, |
| | | remove_non_linguistic_symbols=remove_non_linguistic_symbols, |
| | | **kwargs |
| | | ) |
| | | |
| | | elif token_type == "phn": |
| | |
| | | non_linguistic_symbols=non_linguistic_symbols, |
| | | space_symbol=space_symbol, |
| | | remove_non_linguistic_symbols=remove_non_linguistic_symbols, |
| | | **kwargs |
| | | ) |
| | | |
| | | else: |
| | |
| | | |
| | | |
| | | from funasr.tokenizer.abs_tokenizer import AbsTokenizer |
| | | from funasr.tokenizer.abs_tokenizer import BaseTokenizer |
| | | |
| | | class CharTokenizer(BaseTokenizer): |
| | | |
| | | class CharTokenizer(AbsTokenizer): |
| | | def __init__( |
| | | self, |
| | | non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, |
| | | space_symbol: str = "<space>", |
| | | remove_non_linguistic_symbols: bool = False, |
| | | **kwargs, |
| | | ): |
| | | super().__init__(**kwargs) |
| | | self.space_symbol = space_symbol |
| | | if non_linguistic_symbols is None: |
| | | self.non_linguistic_symbols = set() |
| | |
| | | non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, |
| | | space_symbol: str = "<space>", |
| | | remove_non_linguistic_symbols: bool = False, |
| | | **kwargs, |
| | | ): |
| | | if g2p_type is None: |
| | | self.g2p = split_by_space |
| | |
| | | |
| | | |
| | | class SentencepiecesTokenizer(AbsTokenizer): |
| | | def __init__(self, model: Union[Path, str], **kwargs): |
| | | def __init__(self, model: Union[Path, str]): |
| | | self.model = str(model) |
| | | # NOTE(kamo): |
| | | # Don't build SentencePieceProcessor in __init__() |
| | |
| | | delimiter: str = None, |
| | | non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, |
| | | remove_non_linguistic_symbols: bool = False, |
| | | **kwargs, |
| | | ): |
| | | self.delimiter = delimiter |
| | | |