Dev gzf exp (#1626)
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
| | |
| | | torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) |
| | | torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) |
| | | torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) |
| | | # open tf32 |
| | | torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True) |
| | | |
| | | local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
| | | if local_rank == 0: |
| | |
| | | self.epoch = 0 |
| | | self.sort_size = sort_size * num_replicas |
| | | self.max_token_length = kwargs.get("max_token_length", 2048) |
| | | self.min_token_length = kwargs.get("min_token_length", 0) |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | |
| | | |
| | |
| | | |
| | | for idx in sorted_indices: |
| | | original_sample_length = self.dataset.get_source_len(idx) |
| | | if original_sample_length > self.max_token_length: # Skip samples that exceed the max length |
| | | if original_sample_length < self.min_token_length or original_sample_length > self.max_token_length: # Skip samples that exceed the max length |
| | | continue |
| | | # Set sample_length based on the batch type |
| | | sample_length = 1 if self.batch_type == "example" else original_sample_length |
| | |
| | | |
| | | def __init__(self, path: str, **kwargs): |
| | | super().__init__() |
| | | |
| | | self.max_source_length = kwargs.get("max_source_length", 2048) |
| | | self.min_source_length = kwargs.get("min_source_length", 0) |
| | | self.max_target_length = kwargs.get("max_target_length", 2048) |
| | | self.min_target_length = kwargs.get("min_target_length", 0) |
| | | if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans |
| | | from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list |
| | | jsonl_outdir = os.path.dirname(path[0]) |
| | |
| | | target_len = data.get("target_len", 0) |
| | | if "aishell" in source: |
| | | target = target.replace(" ", "") |
| | | |
| | | if source_len < self.min_source_length or source_len > self.max_source_length: |
| | | continue |
| | | if target_len < self.min_target_length or target_len > self.max_target_length: |
| | | continue |
| | | contents_i = {"source": source, |
| | | "prompt": prompt, |
| | | "target": target, |
| | |
| | | # to print the register_table: |
| | | # from funasr.register import tables |
| | | # tables.print() |
| | | |
| | | # network architecture |
| | | model: SenseVoice |
| | | model_conf: |
| | | lsm_weight: 0.1 |
| | | length_normalized_loss: true |
| | | hub: funasr |
| | | activation_checkpoint: true |
| | | sos: "<|startoftranscript|>" |
| | | eos: "<|endoftext|>" |
| | | downsample_rate: 4 |
| | | use_padmask: true |
| | | |
| | | |
| | | |
| | | # only use for hub == funasr, |
| | | # if hub == openai, dims is automaticall download |
| | | dims: |
| | | n_mels: 128 |
| | | n_vocab: 51866 |
| | | n_vocab: 60515 |
| | | n_audio_ctx: 1500 |
| | | n_audio_state: 1280 |
| | | n_audio_head: 20 |
| | |
| | | frontend: WhisperFrontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | n_mels: ${dims.n_mels} |
| | | do_pad_trim: true |
| | | n_mels: ${model_conf.dims.n_mels} |
| | | do_pad_trim: false |
| | | |
| | | tokenizer: WhisperTokenizer |
| | | tokenizer: SenseVoiceTokenizer |
| | | tokenizer_conf: |
| | | language: null |
| | | task: transcribe |
| | | vocab_path: null |
| | | is_multilingual: true |
| | | num_languages: 100 |
| | | num_languages: 8749 |
| | | |
| | | scope_map: [none, "model."] |
| | | dataset: SenseVoiceDataset |
| | | dataset_conf: |
| | | index_ds: IndexDSJsonl |
| | | batch_sampler: EspnetStyleBatchSampler |
| | | batch_type: length # example or length |
| | | batch_size: 7000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; |
| | | max_token_length: 2000 # filter samples if source_token_len+target_token_len > max_token_length, |
| | | min_token_length: 60 |
| | | shuffle: True |
| | | num_workers: 4 |
| | | sos: ${model_conf.sos} |
| | | eos: ${model_conf.eos} |
| | | |
| | | train_conf: |
| | | accum_grad: 2 |
| | | grad_clip: 5 |
| | | max_epoch: 20 |
| | | keep_nbest_models: 20 |
| | | avg_nbest_model: ${train_conf.keep_nbest_models} |
| | | log_interval: 50 |
| | | |
| | | optim: adamw |
| | | optim_conf: |
| | | lr: 0.00002 |
| | | |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 10000 |
| | | |
| | | specaug: SpecAug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 40 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_ratio_range: |
| | | - 0.0 |
| | | - 0.12 |
| | | num_time_mask: 2 |
| | | |
| | | scope_map: ['encoder.encoders', 'model.encoder', 'decoder.decoders', 'model.decoder'] |
| | |
| | | # Initialize the gradient accumulation |
| | | optim.zero_grad() |
| | | speed_stats = {} |
| | | time5 = time.perf_counter() |
| | | |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | |
| | | dataloader_train.batch_sampler.set_epoch(epoch) |
| | | time_beg = time.perf_counter() |
| | | time5 = time_beg |
| | | for batch_idx, batch in enumerate(dataloader_train): |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| | |
| | | break |
| | | self.batch_total += 1 |
| | | time1 = time.perf_counter() |
| | | speed_stats["data_load"] = f"{time1-time5:0.3f}" |
| | | speed_stats["data_load"] = f"{time1-time_beg:0.3f}" |
| | | |
| | | batch = to_device(batch, self.device) |
| | | |
| | | my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext |
| | | my_context = nullcontext |
| | | if self.use_ddp or self.use_fsdp: |
| | | my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context |
| | | with my_context(): |
| | | time2 = time.perf_counter() |
| | | with maybe_autocast(self.use_fp16): |
| | |
| | | if (batch_idx+1) % self.save_checkpoint_interval == 0: |
| | | self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1) |
| | | |
| | | time_beg = time.perf_counter() |
| | | else: |
| | | if self.use_ddp or self.use_fsdp: |
| | | iterator_stop.fill_(1) |