zhifu gao
2023-12-11 172e7ac986f299ad545cbd91a8cecc3ef967af36
Revert "Dev gzf funasr2" (#1164)

17 files changed
1 file renamed
8 files deleted
1608 lines changed
funasr/bin/asr_trainer.py
funasr/cli/model_class_factory.py 298
funasr/cli/models/__init__.py
funasr/cli/models/paraformer.py 652
funasr/cli/train_cli.py 163
funasr/cli/trainer.py 199
funasr/datasets/data_sampler.py 16
funasr/datasets/dataloader_fn.py 5
funasr/datasets/dataset_jsonl.py 18
funasr/datasets/small_datasets/preprocessor.py 1
funasr/models/e2e_asr.py 1
funasr/models/e2e_asr_contextual_paraformer.py 1
funasr/models/e2e_asr_paraformer.py 7
funasr/models/e2e_uni_asr.py 1
funasr/modules/nets_utils.py 2
funasr/optimizers/__init__.py 17
funasr/schedulers/__init__.py 23
funasr/tokenizer/abs_tokenizer.py 74
funasr/tokenizer/build_tokenizer.py 19
funasr/tokenizer/char_tokenizer.py 6
funasr/tokenizer/funtoken.py 75
funasr/tokenizer/phoneme_tokenizer.py 1
funasr/tokenizer/sentencepiece_tokenizer.py 2
funasr/tokenizer/word_tokenizer.py 1
funasr/utils/dynamic_import.py 13
funasr/utils/load_fr_py.py 13
funasr/bin/asr_trainer.py
funasr/cli/model_class_factory.py
File was deleted
funasr/cli/models/__init__.py
File was deleted
funasr/cli/models/paraformer.py
File was deleted
funasr/cli/train_cli.py
File was deleted
funasr/cli/trainer.py
File was deleted
funasr/datasets/data_sampler.py
@@ -4,17 +4,17 @@
class BatchSampler(torch.utils.data.BatchSampler):
    
-    def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
+    def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
        
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        # self.batch_type = args.batch_type
        # self.batch_size_type = args.batch_size_type
        # self.batch_size = args.batch_size
        # self.sort_size = args.sort_size
        # self.max_length_token = args.max_length_token
-        self.batch_type = batch_type
+        self.batch_size_type = batch_size_type
        self.batch_size = batch_size
        self.sort_size = sort_size
        self.max_length_token = kwargs.get("max_length_token", 5000)
@@ -26,7 +26,7 @@
        return self.total_samples
    def __iter__(self):
        # print("in sampler")
        print("in sampler")
        
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
@@ -36,7 +36,7 @@
        num_sample = 0
        iter_num = (self.total_samples-1) // self.sort_size + 1
        # print("iter_num: ", iter_num)
        print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.sort_size):
@@ -46,8 +46,8 @@
                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \
-                                 self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map])
+                sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
+                                 self.dataset.indexed_dataset[idx_map]["target_len"]
                datalen_with_index.append([idx, sample_len_cur])
            
@@ -59,7 +59,7 @@
                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
-                if self.batch_type == 'token':
+                if self.batch_size_type == 'token':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
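The behavioural core of this sampler is unchanged by the rename: with batch_size_type == "token", a sample joins the batch only while the padded cost (num_samples + 1) * max_len stays within batch_size. A minimal, self-contained sketch of that packing rule (all names below are illustrative, not FunASR code):

def pack_batches(sample_lens, batch_size=120, batch_size_type="token"):
    # Hypothetical re-statement of the rule above: the cost of a batch is
    # its sample count times the longest sample in it (padding included).
    batches, batch, max_len = [], [], 0
    for idx, n in enumerate(sample_lens):
        max_len_cur = max(max_len, n)
        cost = (len(batch) + 1) * (max_len_cur if batch_size_type == "token" else 1)
        if cost <= batch_size:
            batch.append(idx)
            max_len = max_len_cur
        else:
            if batch:
                batches.append(batch)
            batch, max_len = [idx], n
    if batch:
        batches.append(batch)
    return batches

print(pack_batches([30, 40, 25, 90, 10]))  # [[0, 1, 2], [3], [4]]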
funasr/datasets/dataloader_fn.py
@@ -38,13 +38,16 @@
batch_sampler = BatchSampler(dataset)
def collator(samples: list = None):
    return samples
if __name__ == "__main__":
    
    dataloader_tr = torch.utils.data.DataLoader(dataset,
                                                collate_fn=dataset.collator,
                                                batch_sampler=batch_sampler,
                                                shuffle=False,
-                                               num_workers=0,
+                                               num_workers=8,
                                                pin_memory=True)
    
    print(len(dataset))
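For reference, the constraint this block relies on: once a DataLoader is given a batch_sampler, the batch_size, shuffle, sampler and drop_last arguments must stay at their defaults. A runnable toy version (not the FunASR dataset):

import torch

class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 10
    def __getitem__(self, i):
        return torch.tensor([i])

dataset = ToyDataset()
batch_sampler = torch.utils.data.BatchSampler(
    torch.utils.data.SequentialSampler(dataset), batch_size=4, drop_last=False)
loader = torch.utils.data.DataLoader(dataset,
                                     batch_sampler=batch_sampler,
                                     num_workers=0,  # the restored file uses 8; 0 keeps the toy portable
                                     pin_memory=False)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1]), torch.Size([4, 1]), torch.Size([2, 1])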
funasr/datasets/dataset_jsonl.py
@@ -78,26 +78,21 @@
    
    def __getitem__(self, index):
        return self.contents[index]
-    def get_source_len(self, data_dict):
-        return data_dict["source_len"]
-    def get_target_len(self, data_dict):
-        return data_dict["target_len"] if "target_len" in data_dict else 0
class AudioDataset(torch.utils.data.Dataset):
-    def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs):
+    def __init__(self, path, frontend=None, tokenizer=None, token_id_converter=None):
        super().__init__()
        self.indexed_dataset = IndexedDatasetJsonl(path)
        self.frontend = frontend.forward
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
+        self.token_id_converter = token_id_converter
-        self.int_pad_value = int_pad_value
-        self.float_pad_value = float_pad_value
+        self.int_pad_value = -1
+        self.float_pad_value = 0.0
    
@@ -113,7 +108,8 @@
        data_src = load_audio(source, fs=self.fs)
        speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
        target = item["target"]
-        ids = self.tokenizer.encode(target)
+        text = self.tokenizer.text2tokens(target)
+        ids = self.token_id_converter.tokens2ids(text)
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
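The swap here is between one-step and two-step tokenization. A hedged toy showing the two call paths agree when the converter shares the tokenizer's vocabulary (toy classes, not the real funasr/tokenizer ones):

TOKEN2ID, UNK_ID = {"a": 0, "b": 1, "c": 2}, -1

class ToyTokenizer:
    def text2tokens(self, line):
        return list(line)                      # char-level split
    def encode(self, line):                    # pre-revert style: one call
        return [TOKEN2ID.get(t, UNK_ID) for t in self.text2tokens(line)]

class ToyIdConverter:                          # restored style: separate object
    def tokens2ids(self, tokens):
        return [TOKEN2ID.get(t, UNK_ID) for t in tokens]

tok, conv = ToyTokenizer(), ToyIdConverter()
assert tok.encode("abc") == conv.tokens2ids(tok.text2tokens("abc")) == [0, 1, 2]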
funasr/datasets/small_datasets/preprocessor.py
@@ -361,7 +361,6 @@
                    tokens = seg_tokenize(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        return data
funasr/models/e2e_asr.py
@@ -223,7 +223,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
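The kept if-branch re-weights the loss for DataParallel gathering; the same pattern recurs in the contextual paraformer, paraformer, and uni_asr hunks below. A small numeric illustration (values are made up):

import torch

text_lengths = torch.tensor([5, 7, 3])       # three utterances
batch_size = text_lengths.size(0)            # default weight: 3 utterances
length_normalized_loss = True
if length_normalized_loss:
    # weight by token count instead, +1 per utterance for the appended symbol
    batch_size = int((text_lengths + 1).sum())
print(batch_size)                            # 18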
funasr/models/e2e_asr_contextual_paraformer.py
@@ -234,7 +234,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    
funasr/models/e2e_asr_paraformer.py
@@ -256,7 +256,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -869,7 +868,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -1497,7 +1495,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -1769,7 +1766,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -1972,7 +1968,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -2267,4 +2262,4 @@
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))
-        return var_dict_torch_update
+        return var_dict_torch_update
funasr/models/e2e_uni_asr.py
@@ -443,7 +443,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
funasr/modules/nets_utils.py
@@ -347,7 +347,7 @@
    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
-        pad_targets (LongTensor): Target label tensors (B, Lmax).
+        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.
    Returns:
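For context on why the docstring shapes matter, a from-scratch sketch of what an accuracy helper of this kind consumes (this is not the repository's th_accuracy, just an illustration):

import torch

def toy_accuracy(pad_outputs, pad_targets, ignore_label=-1):
    # pad_outputs: (B * Lmax, D) logits; pad_targets: (B, Lmax) label ids
    pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), -1).argmax(2)
    mask = pad_targets != ignore_label
    return ((pred == pad_targets) & mask).sum().item() / mask.sum().item()

B, Lmax, D = 2, 3, 5
print(toy_accuracy(torch.randn(B * Lmax, D), torch.randint(0, D, (B, Lmax))))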
funasr/optimizers/__init__.py
@@ -1,17 +0,0 @@
-import torch
-from funasr.optimizers.fairseq_adam import FairseqAdam
-from funasr.optimizers.sgd import SGD
-optim_choices = dict(
-    adam=torch.optim.Adam,
-    fairseq_adam=FairseqAdam,
-    adamw=torch.optim.AdamW,
-    sgd=SGD,
-    adadelta=torch.optim.Adadelta,
-    adagrad=torch.optim.Adagrad,
-    adamax=torch.optim.Adamax,
-    asgd=torch.optim.ASGD,
-    lbfgs=torch.optim.LBFGS,
-    rmsprop=torch.optim.RMSprop,
-    rprop=torch.optim.Rprop,
-)
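A registry like the removed optim_choices is typically consumed by key lookup plus a config splat; a hedged sketch (the config names are illustrative):

import torch

optim_choices = dict(adam=torch.optim.Adam, sgd=torch.optim.SGD)  # trimmed copy
model = torch.nn.Linear(4, 2)
optim_conf = {"lr": 1e-3}                      # e.g. parsed from a YAML config
optimizer = optim_choices["adam"](model.parameters(), **optim_conf)
print(type(optimizer).__name__)                # Adam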
funasr/schedulers/__init__.py
@@ -1,23 +0,0 @@
-import torch
-import torch.multiprocessing
-import torch.nn
-import torch.optim
-from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.tri_stage_scheduler import TriStageLR
-from funasr.schedulers.warmup_lr import WarmupLR
-scheduler_choices = dict(
-    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
-    lambdalr=torch.optim.lr_scheduler.LambdaLR,
-    steplr=torch.optim.lr_scheduler.StepLR,
-    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
-    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
-    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
-    noamlr=NoamLR,
-    warmuplr=WarmupLR,
-    tri_stage=TriStageLR,
-    cycliclr=torch.optim.lr_scheduler.CyclicLR,
-    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
-    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
-)
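The scheduler table pairs with the optimizer registry above; a matching sketch using only stock torch entries:

import torch

scheduler_choices = dict(steplr=torch.optim.lr_scheduler.StepLR)  # trimmed copy
optimizer = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3)
scheduler = scheduler_choices["steplr"](optimizer, step_size=10, gamma=0.5)
print(scheduler.get_last_lr())                 # [0.001]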
funasr/tokenizer/abs_tokenizer.py
@@ -2,87 +2,13 @@
from abc import abstractmethod
from typing import Iterable
from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
-import numpy as np
class AbsTokenizer(ABC):
    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError
    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError
-class BaseTokenizer(ABC):
-    def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
-                 unk_symbol: str = "<unk>",
-                 **kwargs,
-                 ):
-        if token_list is not None:
-            if isinstance(token_list, (Path, str)):
-                token_list = Path(token_list)
-                self.token_list_repr = str(token_list)
-                self.token_list: List[str] = []
-                with token_list.open("r", encoding="utf-8") as f:
-                    for idx, line in enumerate(f):
-                        line = line.rstrip()
-                        self.token_list.append(line)
-            else:
-                self.token_list: List[str] = list(token_list)
-                self.token_list_repr = ""
-                for i, t in enumerate(self.token_list):
-                    if i == 3:
-                        break
-                    self.token_list_repr += f"{t}, "
-                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-            self.token2id: Dict[str, int] = {}
-            for i, t in enumerate(self.token_list):
-                if t in self.token2id:
-                    raise RuntimeError(f'Symbol "{t}" is duplicated')
-                self.token2id[t] = i
-            self.unk_symbol = unk_symbol
-            if self.unk_symbol not in self.token2id:
-                raise RuntimeError(
-                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
-                )
-            self.unk_id = self.token2id[self.unk_symbol]
-    def encode(self, text):
-        tokens = self.text2tokens(text)
-        text_ints = self.tokens2ids(tokens)
-        return text_ints
-    def decode(self, text_ints):
-        return self.ids2tokens(text_ints)
-    def get_num_vocabulary_size(self) -> int:
-        return len(self.token_list)
-    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
-        if isinstance(integers, np.ndarray) and integers.ndim != 1:
-            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
-        return [self.token_list[i] for i in integers]
-    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        return [self.token2id.get(i, self.unk_id) for i in tokens]
-    @abstractmethod
-    def text2tokens(self, line: str) -> List[str]:
-        raise NotImplementedError
-    @abstractmethod
-    def tokens2text(self, tokens: Iterable[str]) -> str:
-        raise NotImplementedError
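What the deletion of BaseTokenizer takes away is the id mapping with <unk> fallback; a self-contained reproduction of that round-trip behaviour (no file I/O, toy vocabulary):

token_list = ["<blank>", "<unk>", "a", "b"]
token2id = {t: i for i, t in enumerate(token_list)}
unk_id = token2id["<unk>"]

def tokens2ids(tokens):
    return [token2id.get(t, unk_id) for t in tokens]

def ids2tokens(ids):
    return [token_list[i] for i in ids]

assert tokens2ids(["a", "z"]) == [2, 1]        # unknown "z" maps to <unk>
assert ids2tokens([2, 1]) == ["a", "<unk>"]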
funasr/tokenizer/build_tokenizer.py
@@ -1,17 +1,7 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
-import numpy as np
+from pathlib import Path
+from typing import Iterable
+from typing import Union
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
@@ -28,8 +18,7 @@
    space_symbol: str = "<space>",
    delimiter: str = None,
    g2p_type: str = None,
-    **kwargs,
-):
+) -> AbsTokenizer:
    """A helper function to instantiate Tokenizer"""
    if token_type == "bpe":
        if bpemodel is None:
@@ -39,7 +28,7 @@
            raise RuntimeError(
                "remove_non_linguistic_symbols is not implemented for token_type=bpe"
            )
-        return SentencepiecesTokenizer(bpemodel, **kwargs)
+        return SentencepiecesTokenizer(bpemodel)
    elif token_type == "word":
        if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
@@ -49,14 +38,13 @@
                remove_non_linguistic_symbols=True,
            )
        else:
-            return WordTokenizer(delimiter=delimiter, **kwargs)
+            return WordTokenizer(delimiter=delimiter)
    elif token_type == "char":
        return CharTokenizer(
            non_linguistic_symbols=non_linguistic_symbols,
            space_symbol=space_symbol,
            remove_non_linguistic_symbols=remove_non_linguistic_symbols,
-            **kwargs
        )
    elif token_type == "phn":
@@ -65,7 +53,6 @@
            non_linguistic_symbols=non_linguistic_symbols,
            space_symbol=space_symbol,
            remove_non_linguistic_symbols=remove_non_linguistic_symbols,
-            **kwargs
        )
    else:
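build_tokenizer is a plain dispatch on token_type; a hedged, self-contained sketch of the pattern with toy stand-ins for the real classes:

class CharTok:
    def text2tokens(self, line):
        return ["<space>" if c == " " else c for c in line]

class WordTok:
    def __init__(self, delimiter=None):
        self.delimiter = delimiter
    def text2tokens(self, line):
        return line.split(self.delimiter)

def build_toy_tokenizer(token_type, delimiter=None):
    if token_type == "char":
        return CharTok()
    elif token_type == "word":
        return WordTok(delimiter=delimiter)
    raise ValueError(f"unsupported token_type: {token_type}")

print(build_toy_tokenizer("char").text2tokens("ab c"))  # ['a', 'b', '<space>', 'c']
print(build_toy_tokenizer("word").text2tokens("ab c"))  # ['ab', 'c']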
funasr/tokenizer/char_tokenizer.py
@@ -6,17 +6,15 @@
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
-from funasr.tokenizer.abs_tokenizer import BaseTokenizer
-class CharTokenizer(BaseTokenizer):
+class CharTokenizer(AbsTokenizer):
    def __init__(
        self,
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
-        **kwargs,
    ):
-        super().__init__(**kwargs)
        self.space_symbol = space_symbol
        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
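The constructor arguments kept above drive char splitting with optional removal of marker symbols; a runnable approximation of that behaviour (not the class itself):

def char_split(line, non_linguistic_symbols={"<noise>"},
               space_symbol="<space>", remove_non_linguistic_symbols=True):
    tokens = []
    while line:
        sym = next((s for s in non_linguistic_symbols if line.startswith(s)), None)
        if sym:
            if not remove_non_linguistic_symbols:
                tokens.append(sym)
            line = line[len(sym):]
        else:
            tokens.append(space_symbol if line[0] == " " else line[0])
            line = line[1:]
    return tokens

print(char_split("a<noise> b"))  # ['a', '<space>', 'b']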
funasr/tokenizer/funtoken.py
File was deleted
funasr/tokenizer/phoneme_tokenizer.py
@@ -363,7 +363,6 @@
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
-        **kwargs,
    ):
        if g2p_type is None:
            self.g2p = split_by_space
funasr/tokenizer/sentencepiece_tokenizer.py
@@ -9,7 +9,7 @@
class SentencepiecesTokenizer(AbsTokenizer):
-    def __init__(self, model: Union[Path, str], **kwargs):
+    def __init__(self, model: Union[Path, str]):
        self.model = str(model)
        # NOTE(kamo):
        # Don't build SentencePieceProcessor in __init__()
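The NOTE this hunk sits next to reflects a real constraint: a SentencePieceProcessor is not picklable, so it must be built lazily rather than in __init__ when DataLoader workers fork. A dependency-free sketch of that lazy pattern (the spm calls in the comments are the assumed real ones):

class LazySpTokenizer:
    def __init__(self, model_path):
        self.model = str(model_path)
        self._sp = None                 # deliberately not built here

    def _build(self):
        if self._sp is None:
            # real code: self._sp = spm.SentencePieceProcessor(); self._sp.load(self.model)
            self._sp = object()

    def text2tokens(self, line):
        self._build()
        return line.split()             # placeholder for self._sp.EncodeAsPieces(line)

print(LazySpTokenizer("bpe.model").text2tokens("hello world"))  # ['hello', 'world']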
funasr/tokenizer/word_tokenizer.py
@@ -14,7 +14,6 @@
        delimiter: str = None,
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        remove_non_linguistic_symbols: bool = False,
-        **kwargs,
    ):
        self.delimiter = delimiter
funasr/utils/dynamic_import.py
File was deleted
funasr/utils/load_fr_py.py
File was deleted