zhifu gao
2024-06-05 db9ec58cb430fa10b273a7a365457bf33dc46adc
Dev gzf exp (#1785)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* log step

* wav does not exist

* wav does not exist

* decoding

* decoding

* decoding

* wechat

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* dynamic batch

* start_data_split_i=0

* total_time/accum_grad

* total_time/accum_grad

* total_time/accum_grad

* update avg slice

* update avg slice

* sensevoice sanm

* sensevoice sanm

* sensevoice sanm

---------

Co-authored-by: 北念 <lzr265946@alibaba-inc.com>
3 files changed
31 ■■■■ changed files
funasr/datasets/audio_datasets/espnet_samplers.py 2 ●●●●● patch | view | raw | blame | history
funasr/models/sense_voice/decoder.py 1 ●●●● patch | view | raw | blame | history
funasr/models/sense_voice/model.py 28 ●●●● patch | view | raw | blame | history
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -147,7 +147,9 @@
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
        self.batch_num = len(rank_batches)
        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
        )
funasr/models/sense_voice/decoder.py
@@ -360,6 +360,7 @@
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state
funasr/models/sense_voice/model.py
@@ -1264,15 +1264,29 @@
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        sos = kwargs.get("model_conf").get("sos")
        if isinstance(sos, str):
            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
            language = DecodingOptions.get("language", None)
            language = None if language == "auto" else language
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
            sos_int = tokenizer.encode(sos, allowed_special="all")
        else:
            language = DecodingOptions.get("language", None)
            language = None if language == "auto" else language
            initial_prompt = kwargs.get("initial_prompt", f"{task}")
            initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
            sos_int = [sos] + initial_prompt_lid_int
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        if isinstance(eos, str):
            eos_int = tokenizer.encode(eos, allowed_special="all")
        else:
            eos_int = [eos]
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]
@@ -1298,7 +1312,7 @@
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
            speech[None, :, :], speech_lengths
        )
        if text_token_int is not None: