zhifu gao
2024-04-17 824377d2aae11dc9ebbde871e3b23a0e0cadc7af
Dev gzf exp (#1626)

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune
5个文件已修改
97 ■■■■ 已修改文件
funasr/bin/train.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/espnet_samplers.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/template.yaml 71 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py
@@ -55,6 +55,8 @@
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # allow TF32 for CUDA matmuls (on by default; override with enable_tf32)
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
    
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if local_rank == 0:
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -61,6 +61,7 @@
        self.epoch = 0
        self.sort_size = sort_size * num_replicas
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.min_token_length = kwargs.get("min_token_length", 0)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
@@ -85,7 +86,7 @@
        
        for idx in sorted_indices:
            original_sample_length = self.dataset.get_source_len(idx)
            if original_sample_length > self.max_token_length:  # Skip samples that exceed the max length
            if original_sample_length < self.min_token_length or original_sample_length > self.max_token_length:  # Skip samples shorter than min_token_length or longer than max_token_length
                continue
            # Set sample_length based on the batch type
            sample_length = 1 if self.batch_type == "example" else original_sample_length
funasr/datasets/audio_datasets/index_ds.py
@@ -76,7 +76,10 @@
    
    def __init__(self, path: str, **kwargs):
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
            from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
            jsonl_outdir = os.path.dirname(path[0])
@@ -101,7 +104,10 @@
                    target_len = data.get("target_len", 0)
                    if "aishell" in source:
                        target = target.replace(" ", "")
                    if source_len < self.min_source_length or source_len > self.max_source_length:
                        continue
                    if target_len < self.min_target_length or target_len > self.max_target_length:
                        continue
                    contents_i = {"source": source,
                                 "prompt": prompt,
                                 "target": target,
funasr/models/sense_voice/template.yaml
@@ -4,21 +4,20 @@
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: SenseVoice
model_conf:
    lsm_weight: 0.1
    length_normalized_loss: true
    hub: funasr
    activation_checkpoint: true
    sos: "<|startoftranscript|>"
    eos: "<|endoftext|>"
    downsample_rate: 4
    use_padmask: true
# only used when hub == funasr;
#  if hub == openai, dims are downloaded automatically
dims:
    n_mels: 128
    n_vocab: 51866
        n_vocab: 60515
    n_audio_ctx: 1500
    n_audio_state: 1280
    n_audio_head: 20
@@ -32,14 +31,58 @@
frontend: WhisperFrontend
frontend_conf:
    fs: 16000
    n_mels: ${dims.n_mels}
    do_pad_trim: true
    n_mels: ${model_conf.dims.n_mels}
    do_pad_trim: false
tokenizer: WhisperTokenizer
tokenizer: SenseVoiceTokenizer
tokenizer_conf:
  language: null
  task: transcribe
  vocab_path: null
  is_multilingual: true
  num_languages: 100
  num_languages: 8749
scope_map: [none, "model."]
dataset: SenseVoiceDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: EspnetStyleBatchSampler
    batch_type: length # example or length
    batch_size: 7000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
    max_token_length: 2000 # filter out samples whose source_token_len+target_token_len > max_token_length
    min_token_length: 60
    shuffle: True
    num_workers: 4
    sos: ${model_conf.sos}
    eos: ${model_conf.eos}
train_conf:
  accum_grad: 2
  grad_clip: 5
  max_epoch: 20
  keep_nbest_models: 20
  avg_nbest_model: ${train_conf.keep_nbest_models}
  log_interval: 50
optim: adamw
optim_conf:
    lr: 0.00002
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 10000
specaug: SpecAug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 40
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.0
    - 0.12
    num_time_mask: 2
scope_map: ['encoder.encoders', 'model.encoder', 'decoder.decoders', 'model.decoder']
funasr/train_utils/trainer.py
@@ -268,10 +268,12 @@
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        iterator_stop = torch.tensor(0).to(self.device)
        dataloader_train.batch_sampler.set_epoch(epoch)
        time_beg = time.perf_counter()
        time5 = time_beg
        for batch_idx, batch in enumerate(dataloader_train):
            if self.use_ddp or self.use_fsdp:
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
@@ -279,11 +281,13 @@
                    break
            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
            speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
            batch = to_device(batch, self.device)
            
            my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            my_context = nullcontext
            if self.use_ddp or self.use_fsdp:
                my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
            with my_context():
                time2 = time.perf_counter()
                with maybe_autocast(self.use_fp16):
@@ -384,6 +388,7 @@
            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
            time_beg = time.perf_counter()
        else:
            if self.use_ddp or self.use_fsdp:
                iterator_stop.fill_(1)