zhifu gao
2024-02-06 d92cd5ae037ae85ab9730499d99e5c1bd475eed2
Funasr1.0 (#1362)

* funasr1.0.5

* funasr1.0.5 audio samples input

* batch_type token

* batch_type token

* huggingface model zoo

* dataloader

* dataloader

* fbank input

* vad is_final=True bugfix
7个文件已修改
275 ■■■■■ 已修改文件
funasr/auto/auto_model.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 54 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/samplers.py 193 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/model.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/cif_predictor.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/model.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/auto/auto_model.py
@@ -171,7 +171,7 @@
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
        model.eval()
        model.to(device)
        
        # init_param
@@ -206,6 +206,7 @@
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
        model.eval()
        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
funasr/datasets/audio_datasets/index_ds.py
@@ -6,8 +6,8 @@
from funasr.register import tables
@tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset):
@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
    
    def __init__(self, path):
        super().__init__()
@@ -66,3 +66,53 @@
    def get_target_len(self, data_dict):
        
        return data_dict["target_len"] if "target_len" in data_dict else 0
@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):
    """Index dataset that loads every JSONL entry of ``path`` (no per-rank split).

    Each non-blank line of ``path`` is parsed as a JSON object:
      * a ``"text"`` key contributes the raw text value (SFT data);
      * a ``"source"`` key contributes a dict with ``source``, ``prompt``
        (default ``"<ASR>"``), ``target``, ``source_len`` (default 1) and
        ``target_len`` (default 0).
    A line carrying both keys contributes two entries.
    """

    def __init__(self, path):
        super().__init__()
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:  # tolerate blank lines instead of failing in json.loads
                    continue
                data = json.loads(line)
                if "text" in data:  # for sft
                    # Append to the local list: self.contents is only assigned below.
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    contents.append({"source": data["source"],
                                     "prompt": data.get("prompt", "<ASR>"),
                                     "target": data["target"],
                                     "source_len": data.get("source_len", 1),
                                     "target_len": data.get("target_len", 0),
                                     })
        self.contents = contents
        logging.info(
            "total_num of samplers across ranks: {}".format(len(self.contents)))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        """Return the stored entry (a dict, or a str for SFT text lines).

        Raises:
            IndexError: if ``index`` is out of range (logged before re-raising,
                instead of failing later with an unbound local).
        """
        try:
            return self.contents[index]
        except IndexError:
            logging.error("index %s out of range (%d entries)", index, len(self.contents))
            raise

    def get_source_len(self, data_dict):
        """Return an entry's ``source_len`` (default 1); plain-text entries count as 1."""
        if not isinstance(data_dict, dict):
            return 1
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        """Return an entry's ``target_len`` (default 0); plain-text entries count as 0."""
        if not isinstance(data_dict, dict):
            return 0
        return data_dict.get("target_len", 0)
funasr/datasets/audio_datasets/samplers.py
@@ -1,5 +1,7 @@
import torch
import numpy as np
import logging
import torch.distributed as dist
from funasr.register import tables
@@ -82,3 +84,194 @@
                    max_token = sample_len_cur_raw
                    num_sample = 1
@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that shuffles all indices, sorts by length locally and splits batches across ranks.

    Indices are (optionally) shuffled at the start of each iteration and consumed in
    windows of ``buffer_size``. Inside a window, samples are sorted by length so a
    batch holds similarly sized samples. A global batch holds
    ``batch_size * world_size`` samples, and this rank yields only its own contiguous
    ``batch_size`` slice of it.

    Args:
        dataset: object providing ``__len__``, ``get_source_len(index)`` and
            ``get_target_len(index)``.
        batch_type: only ``"length"`` changes behaviour here (the target length is
            added to the sort key); batches are always counted in samples.
        batch_size: number of samples per rank per batch.
        buffer_size: number of shuffled indices sorted together.
        drop_last: stored as an attribute; a trailing partial global batch is always
            discarded so every rank yields the same number of batches (assuming all
            ranks shuffle with the same seed).
        shuffle, is_training: shuffling happens only when both are true.
        **kwargs: ``max_token_length`` (default 1500): samples whose length exceeds it
            are skipped; ``length_scale_source`` (default 1.0): divisor applied to the
            source length.
    """

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1  # last consumed buffer; iteration starts at pre_idx + 1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        # Fall back to a single process when torch.distributed is unavailable or not
        # initialised, instead of swallowing every exception with a bare except.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

    def __len__(self):
        """Estimated batches per rank: ceil(total_samples / (batch_size * world_size))."""
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        """Seed NumPy's global RNG with ``epoch`` so the next shuffle is reproducible."""
        np.random.seed(epoch)

    def __iter__(self):
        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        rank_start = self.rank * self.batch_size
        rank_stop = rank_start + self.batch_size
        batch = []
        num_buffers = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_id in range(self.pre_idx + 1, num_buffers):
            start = buffer_id * self.buffer_size
            stop = min(start + self.buffer_size, self.total_samples)
            datalen_with_index = []
            for pos in range(start, stop):
                # Yield the shuffled sample index itself, so length sorting and the
                # max_token_length filter apply to the sample that is actually returned.
                idx_map = int(self.shuffle_idx[pos])
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                datalen_with_index.append((idx_map, source_len + target_len))
            datalen_with_index.sort(key=lambda item: item[1])
            for idx, sample_len in datalen_with_index:
                if sample_len > self.max_token_length:
                    continue
                batch.append(idx)
                # Emit as soon as the global batch is full, so the last full batch
                # is not lost for lack of a following sample.
                if len(batch) == batch_size_total:
                    yield batch[rank_start:rank_stop]
                    batch = []
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
    """Dynamic batch sampler: length-sorted local buffers, one batch per rank per step.

    Indices are (optionally) shuffled at the start of each iteration and consumed in
    windows of ``buffer_size``. Inside a window, samples are sorted by length. Batches
    are filled greedily: with ``batch_type == "example"``, ``batch_size`` caps the
    number of samples. Otherwise, ``batch_size`` caps
    ``num_samples * longest_sample_length`` (a padded-token budget). Once
    ``world_size`` batches are complete, this rank yields the one at position ``rank``.

    Args:
        dataset: object providing ``__len__``, ``get_source_len(index)`` and
            ``get_target_len(index)``.
        batch_type: ``"example"`` (count samples), ``"length"`` (token budget; target
            length is added to the sample length), or any other value (token budget
            on source length only).
        batch_size: per-rank sample count or token budget, see above.
        buffer_size: number of shuffled indices sorted together.
        drop_last: if False, the trailing, not-yet-closed batch is also emitted (when it
            completes a full set of per-rank batches). In ``"example"`` mode a trailing
            batch that already holds ``batch_size`` samples is always emitted.
        shuffle, is_training: shuffling happens only when both are true.
        **kwargs: ``max_token_length`` (default 1500): samples whose length exceeds it
            are skipped; ``length_scale_source`` (default 1.0): divisor applied to the
            source length.
    """

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1  # last consumed buffer; iteration starts at pre_idx + 1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        # Fall back to a single process when torch.distributed is unavailable or not
        # initialised, instead of swallowing every exception with a bare except.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

    def __len__(self):
        """Rough estimate: ceil(total_samples / (batch_size * world_size)).

        In token-budget mode ``batch_size`` is not a sample count, so this is only
        an approximation of the number of batches actually yielded.
        """
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        """Seed NumPy's global RNG with ``epoch`` so the next shuffle is reproducible."""
        np.random.seed(epoch)

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch_list_all_rank = []  # closed batches for the current step, one per rank
        batch_list_cur = []  # batch currently being filled
        max_token = 0  # longest sample length in batch_list_cur
        num_buffers = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_id in range(self.pre_idx + 1, num_buffers):
            start = buffer_id * self.buffer_size
            stop = min(start + self.buffer_size, self.total_samples)
            datalen_with_index = []
            for pos in range(start, stop):
                # Yield the shuffled sample index itself, so length sorting and the
                # max_token_length filter apply to the sample that is actually returned.
                idx_map = int(self.shuffle_idx[pos])
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                datalen_with_index.append((idx_map, source_len + target_len))
            datalen_with_index.sort(key=lambda item: item[1])
            for idx, sample_len in datalen_with_index:
                if sample_len > self.max_token_length:
                    continue
                max_token_cur = max(max_token, sample_len)
                max_token_padding = len(batch_list_cur) + 1
                if self.batch_type != 'example':
                    max_token_padding *= max_token_cur
                # An empty batch always accepts the sample, so a sample larger than the
                # budget forms a batch of its own instead of being lost.
                if not batch_list_cur or max_token_padding <= self.batch_size:
                    batch_list_cur.append(idx)
                    max_token = max_token_cur
                    continue
                # Current batch is full: close it and start a new one with this sample.
                batch_list_all_rank.append(batch_list_cur)
                batch_list_cur = [idx]
                max_token = sample_len
                if len(batch_list_all_rank) == self.world_size:
                    yield batch_list_all_rank[self.rank]
                    batch_list_all_rank = []
        # Handle the batch that was still open when the data ran out.
        trailing_complete = self.batch_type == 'example' and len(batch_list_cur) >= self.batch_size
        if batch_list_cur and (trailing_complete or not self.drop_last):
            batch_list_all_rank.append(batch_list_cur)
            # Only yield when every rank gets a batch, keeping per-rank counts equal.
            if len(batch_list_all_rank) == self.world_size:
                yield batch_list_all_rank[self.rank]
funasr/models/fsmn_vad_streaming/model.py
@@ -575,7 +575,8 @@
        
        time1 = time.perf_counter()
        is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
        cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input}
        is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
        cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
funasr/models/paraformer/cif_predictor.py
@@ -186,7 +186,7 @@
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
            target_length = target_label_length.squeeze(-1)
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
funasr/models/paraformer/model.py
@@ -491,6 +491,8 @@
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key*b
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
funasr/train_utils/trainer.py
@@ -204,7 +204,25 @@
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                print("before, GPU, memory: {:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                                     torch.cuda.max_memory_allocated()/1024/1024/1024,
                                     torch.cuda.memory_reserved()/1024/1024/1024,
                                     torch.cuda.max_memory_reserved()/1024/1024/1024,
                                     ))
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                print("after, GPU, memory: {:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                                     torch.cuda.max_memory_allocated()/1024/1024/1024,
                                     torch.cuda.memory_reserved()/1024/1024/1024,
                                     torch.cuda.max_memory_reserved()/1024/1024/1024,
                                     ))
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval