zhifu gao
2024-04-29 11cf10e433c173efd892766b669e0bba57253fed
Dev gzf exp (#1678)

* resume from step

* batch

* batch

* batch
4 files changed, 59 lines changed
funasr/datasets/audio_datasets/scp2jsonl.py  13
funasr/models/sense_voice/model.py            2
funasr/schedulers/lambdalr_cus.py            42
funasr/tokenizer/abs_tokenizer.py             2
funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,6 +7,7 @@
import concurrent.futures
import librosa
import torch.distributed as dist
from tqdm import tqdm
def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                print("")
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                # import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@
                                    i * lines_for_each_th : (i + 1) * lines_for_each_th
                                ],
                                data_type,
                                i,
                            )
                            for i in range(task_num)
                        ]
@@ -69,11 +72,15 @@
        dist.barrier()
def parse_context_length(data_list: list, data_type: str):
def parse_context_length(data_list: list, data_type: str, id=0):
    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
    res = {}
    for i, line in enumerate(data_list):
        key, line = line.strip().split(maxsplit=1)
        pbar.update(1)
        pbar.set_description(f"cpu: {id}")
        lines = line.strip().split(maxsplit=1)
        key = lines[0]
        line = lines[1] if len(lines) > 1 else ""
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
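Note: the new `id` argument exists so each worker can label its own tqdm bar (`cpu: {id}`), and the extra `i,` in the submit call is what supplies that index. A minimal sketch of how the chunked submission lines up with the updated `parse_context_length`; the helper name and the choice of `ProcessPoolExecutor` are illustrative assumptions, not code from the commit:

```python
import concurrent.futures

# parse_context_length is the function patched above in this commit.
from funasr.datasets.audio_datasets.scp2jsonl import parse_context_length

def parse_lists_in_parallel(data_file_lists, data_type, cpu_cores=4):
    # Split the lines evenly across workers; each worker receives its index `i`
    # so parse_context_length can tag its progress bar with "cpu: {i}".
    lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
    task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
    # Assumption: a process pool; the repo may use a different executor.
    with concurrent.futures.ProcessPoolExecutor(max_workers=task_num) as executor:
        futures = [
            executor.submit(
                parse_context_length,
                data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
                data_type,
                i,
            )
            for i in range(task_num)
        ]
        results = [f.result() for f in futures]
    # Merge the per-worker dicts back into one mapping of key -> length.
    merged = {}
    for r in results:
        merged.update(r)
    return merged
```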
funasr/models/sense_voice/model.py
@@ -329,6 +329,8 @@
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
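Note: the two new stats record how many of the batched frames are real speech versus padding. A hypothetical snippet that derives the padding ratio from them; only the stat keys come from the diff, the helper itself is illustrative:

```python
def padding_ratio(stats: dict) -> float:
    # Hypothetical post-processing of the stats populated in forward().
    total = stats["batch_size_x_frames"]    # batch_size * padded frame length
    real = stats["batch_size_real_frames"]  # sum of true speech lengths
    return (total - real) / max(total, 1)   # == padding_frames / total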
funasr/schedulers/lambdalr_cus.py
@@ -2,28 +2,36 @@
from torch.optim.lr_scheduler import _LRScheduler
# class CustomLambdaLR(_LRScheduler):
#     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
#         self.warmup_steps = warmup_steps
#         super().__init__(optimizer, last_epoch)
#
#     def get_lr(self):
#         if self.last_epoch < self.warmup_steps:
#             return [
#                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
#             ]
#         else:
#             return [base_lr for base_lr in self.base_lrs]
class CustomLambdaLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
    def __init__(
        self,
        optimizer,
        warmup_steps: int = 25000,
        total_steps: int = 500000,
        last_epoch=-1,
        verbose=False,
    ):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch, verbose)
    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [
                base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
            ]
        else:
            return [base_lr for base_lr in self.base_lrs]
class CustomLambdaLR(_LRScheduler):
    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
        self.warmup_steps = train_config.warmup_steps
        self.total_steps = train_config.total_steps
        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
    def get_lr(self):
        step = self._step_count
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps
        else:
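Note: the hunk ends right after the `else:`, so the post-warmup branch of the new scheduler is not visible. Below is a self-contained sketch of a warmup-then-decay schedule with the same interface; the class name and the linear-decay formula after the `else:` are assumptions, not the code in lambdalr_cus.py:

```python
from torch.optim.lr_scheduler import _LRScheduler

class WarmupDecayLR(_LRScheduler):  # illustrative name; mirrors CustomLambdaLR's signature
    def __init__(self, optimizer, warmup_steps: int = 25000, total_steps: int = 500000,
                 last_epoch=-1, verbose=False):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        step = self.last_epoch + 1  # last_epoch counts scheduler.step() calls
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps  # linear warmup to the base lr
        else:
            # Assumption: linear decay to zero over the remaining steps;
            # the actual decay used by the commit is not shown in the hunk.
            progress = (step - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
            lr_scale = max(0.0, 1.0 - progress)
        return [base_lr * lr_scale for base_lr in self.base_lrs]
```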
funasr/tokenizer/abs_tokenizer.py
@@ -62,7 +62,7 @@
                raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
            self.unk_id = self.token2id[self.unk_symbol]
    def encode(self, text):
    def encode(self, text, **kwargs):
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)