| | |
| | | |
| | | model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch") |
| | | |
| | | mm = model.model |
| | | for p in mm.parameters(): |
| | | print(f"{p.numel()}") |
| | | res = model.generate(input=wav_file) |
| | | print(res) |
| | | |
| | | # [[beg1, end1], [beg2, end2], .., [begN, endN]] |
| | | # beg/end: ms |
| | | |
| | |
| | | ) |
| | | |
| | | res = model.generate( |
| | | input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" |
| | | ) |
| | | res = model.generate( |
| | | input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" |
| | | ) |
| | | |
| | | res = model.generate( |
| | | input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" |
| | | input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", |
| | | cache={}, |
| | | ) |
| | | |
| | | print(res) |
| | |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.utils import export_utils |
| | | from funasr.utils import misc |
| | | |
| | | try: |
| | | from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk |
| | |
| | | |
| | | |
| | | def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): |
| | | """ |
| | | |
| | | :param input: |
| | | :param input_len: |
| | | :param data_type: |
| | | :param frontend: |
| | | :return: |
| | | """ |
| | | """ """ |
| | | data_list = [] |
| | | key_list = [] |
| | | filelist = [".scp", ".txt", ".json", ".jsonl", ".text"] |
| | |
| | | key_list.append(key) |
| | | else: |
| | | if key is None: |
| | | key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) |
| | | # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) |
| | | key = misc.extract_filename_without_extension(data_in) |
| | | data_list = [data_in] |
| | | key_list = [key] |
| | | elif isinstance(data_in, (list, tuple)): |
| | |
| | | else: |
| | | # [audio sample point, fbank, text] |
| | | data_list = data_in |
| | | key_list = [ |
| | | "rand_key_" + "".join(random.choice(chars) for _ in range(13)) |
| | | for _ in range(len(data_in)) |
| | | ] |
| | | key_list = [] |
| | | for data_i in data_in: |
| | | if isinstance(data_i, str) and os.path.exists(data_i): |
| | | key = misc.extract_filename_without_extension(data_i) |
| | | else: |
| | | key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) |
| | | key_list.append(key) |
| | | |
| | | else: # raw text; audio sample point, fbank; bytes |
| | | if isinstance(data_in, bytes): # audio bytes |
| | | data_in = load_bytes(data_in) |
| | |
| | | |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | |
| | | ids = prompt_ids + target_ids + eos |
| | | ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] |
| | | ids_lengths = len(ids) |
| | | |
| | | text = torch.tensor(ids, dtype=torch.int64) |
| | |
| | | stats = {} |
| | | |
| | | # 1. Forward decoder |
| | | # ys_pad: [sos, task, lid, text, eos] |
| | | decoder_out = self.model.decoder( |
| | | x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens |
| | | ) |
| | | |
| | | # 2. Compute attention loss |
| | | mask = torch.ones_like(ys_pad) * (-1) |
| | | ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64) |
| | | ys_pad_mask[ys_pad_mask == 0] = -1 |
| | | mask = torch.ones_like(ys_pad) * (-1) # [sos, task, lid, text, eos]: [-1, -1, -1, -1] |
| | | ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to( |
| | | torch.int64 |
| | | ) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0] |
| | | ys_pad_mask[ys_pad_mask == 0] = -1 # [-1, -1, lid, text, eos] |
| | | # decoder_out: [sos, task, lid, text] |
| | | # ys_pad_mask: [-1, lid, text, eos] |
| | | loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:]) |
| | | |
| | | with torch.no_grad(): |
| | |
| | | # config_json = os.path.join(model_path, "configuration.json") |
| | | # if os.path.exists(config_json): |
| | | # shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")) |
| | | |
| | | |
def extract_filename_without_extension(file_path):
    """Return the bare file name for *file_path*, without directory or extension.

    :param file_path: full path to a file (e.g. ``/data/audio/example.wav``)
    :return: file name with both the leading directories and the final
        extension stripped (e.g. ``example``)
    """
    # basename drops the directory part; splitext then separates the final
    # extension — index [0] keeps only the stem.
    return os.path.splitext(os.path.basename(file_path))[0]