zhifu gao
2024-05-08 b1c186fd00fef54bcad3aa1d073a1a313642d641
Dev gzf exp (#1700)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* log step

* wav does not exist

* wav does not exist

* decoding

* decoding

* decoding

* wechat

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key
11个文件已修改
128 ■■■■ 已修改文件
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/demo.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/auto/auto_model.py 29 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/export.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/inference.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/sense_voice_datasets/datasets.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/decoder.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/search.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/decoding.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/misc.py 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,11 +9,9 @@
model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
mm = model.model
for p in mm.parameters():
    print(f"{p.numel()}")
res = model.generate(input=wav_file)
print(res)
# [[beg1, end1], [beg2, end2], .., [begN, endN]]
# beg/end: ms
examples/industrial_data_pretraining/paraformer/demo.py
@@ -14,14 +14,8 @@
)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    cache={},
)
print(res)
funasr/auto/auto_model.py
@@ -26,6 +26,7 @@
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils import export_utils
from funasr.utils import misc
try:
    from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
@@ -35,14 +36,7 @@
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
    :param input:
    :param input_len:
    :param data_type:
    :param frontend:
    :return:
    """
    """ """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
@@ -73,7 +67,8 @@
                    key_list.append(key)
        else:
            if key is None:
                key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key = misc.extract_filename_without_extension(data_in)
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
@@ -90,10 +85,14 @@
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = [
                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                for _ in range(len(data_in))
            ]
            key_list = []
            for data_i in data_in:
                if isinstance(data_i, str) and os.path.exists(data_i):
                    key = misc.extract_filename_without_extension(data_i)
                else:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key_list.append(key)
    else:  # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes):  # audio bytes
            data_in = load_bytes(data_in)
@@ -108,6 +107,10 @@
class AutoModel:
    def __init__(self, **kwargs):
        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)
        if not kwargs.get("disable_log", True):
            tables.print()
funasr/bin/export.py
@@ -17,9 +17,6 @@
            return cfg_item
    kwargs = to_plain_list(cfg)
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)
    if kwargs.get("debug", False):
        import pdb
funasr/bin/inference.py
@@ -16,9 +16,6 @@
            return cfg_item
    kwargs = to_plain_list(cfg)
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)
    if kwargs.get("debug", False):
        import pdb
funasr/datasets/sense_voice_datasets/datasets.py
@@ -112,7 +112,7 @@
            eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
            ids = prompt_ids + target_ids + eos
            ids = prompt_ids + target_ids + eos  # [sos, task, lid, text, eos]
            ids_lengths = len(ids)
            text = torch.tensor(ids, dtype=torch.int64)
funasr/models/sense_voice/decoder.py
@@ -472,7 +472,7 @@
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None
        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
@@ -599,5 +599,6 @@
    def score(self, ys, state, x):
        """Score.

        Run the decoder on the token prefix ``ys`` against ``x`` and return the
        log-probabilities at the final position, together with ``state``
        (passed back as received by this method).

        NOTE(review): both ``ys`` and ``x`` receive a leading batch axis via
        ``unsqueeze(0)``, so they appear to be unbatched (presumably ``ys`` is
        a 1-D token-id tensor) -- confirm against callers.

        Returns:
            tuple: (log-probabilities for the last position of the prefix,
            ``state``).
        """
        # NOTE(review): ys_mask is computed but never used below.
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        # NOTE(review): this result is overwritten on the next line, so the
        # forward pass looks like a leftover duplicate. It may still populate
        # ``state`` if forward mutates the cache it is given -- confirm before
        # removing.
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        # The value actually used: a forward pass with no cache.
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
        # Normalize over the last dimension (presumably the vocabulary).
        logp = torch.log_softmax(logp, dim=-1)
        # Drop the batch axis and keep only the last position's distribution.
        return logp.squeeze(0)[-1, :], state
funasr/models/sense_voice/model.py
@@ -378,14 +378,19 @@
        stats = {}
        # 1. Forward decoder
        # ys_pad: [sos, task, lid, text, eos]
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )
        # 2. Compute attention loss
        mask = torch.ones_like(ys_pad) * (-1)
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        ys_pad_mask[ys_pad_mask == 0] = -1
        mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos]: [-1, -1, -1, -1]
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
            torch.int64
        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
        ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]
        # decoder_out: [sos, task, lid, text]
        # ys_pad_mask: [-1, lid, text, eos]
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
        with torch.no_grad():
@@ -797,6 +802,16 @@
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            if (
                isinstance(kwargs.get("data_type", None), (list, tuple))
                and len(kwargs.get("data_type", [])) > 1
            ):
                audio_sample_list, text_token_int_list = audio_sample_list
                text_token_int = text_token_int_list[0]
            else:
                text_token_int = None
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
@@ -832,6 +847,37 @@
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
        if text_token_int is not None:
            i = 0
            results = []
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"1best_recog"]
            # 1. Forward decoder
            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
                None, :
            ]
            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
                kwargs["device"]
            )[None, :]
            decoder_out = self.model.decoder(
                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
            )
            token_int = decoder_out.argmax(-1)[0, :].tolist()
            text = tokenizer.decode(token_int)
            result_i = {"key": key[i], "text": text}
            results.append(result_i)
            if ibest_writer is not None:
                # ibest_writer["token"][key[i]] = " ".join(token)
                ibest_writer["text"][key[i]] = text
            return results, meta_data
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
funasr/models/sense_voice/search.py
@@ -370,6 +370,8 @@
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            # if len(ended_hyps) > 0:
            #     print(f"ended_hyps: {ended_hyps}")
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -62,8 +62,10 @@
    else:
        x = x.to(mel.device)
    # FIX(funasr): sense vocie
    # logits = model.logits(x[:, :-1], mel)[:, -1]
    logits = model.logits(x[:, :], mel)[:, -1]
    logits = model.logits(x[:, :-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
funasr/utils/misc.py
@@ -78,3 +78,17 @@
    #     config_json = os.path.join(model_path, "configuration.json")
    #     if os.path.exists(config_json):
    #         shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
def extract_filename_without_extension(file_path):
    """Return the last path component of ``file_path`` without its extension.

    Only the final extension is stripped (``os.path.splitext`` semantics), so
    ``"a/b.tar.gz"`` yields ``"b.tar"``. A path that ends in a separator yields
    ``""``, and a leading-dot name such as ``".bashrc"`` is returned unchanged.

    :param file_path: full path to a file.
    :return: the file name with no directory part and no extension.
    """
    stem, _ = os.path.splitext(os.path.basename(file_path))
    return stem