Shi Xian
2024-02-27 ba589e05c1448d0487198f8603cba247f22d67e1
Merge pull request #1393 from alibaba-damo-academy/dev_gzf

Dev gzf
11 files modified
8 files added
1182 lines changed
.gitignore 1
README.md 25
README_zh.md 24
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py 8
funasr/auto/auto_model.py 27
funasr/bin/train.py 8
funasr/datasets/llm_datasets/__init__.py
funasr/datasets/llm_datasets/datasets.py 131
funasr/datasets/llm_datasets/preprocessor.py 37
funasr/datasets/llm_datasets/samplers.py 277
funasr/metrics/compute_acc.py 17
funasr/models/llm_asr/__init__.py
funasr/models/llm_asr/adaptor.py 29
funasr/models/llm_asr/model.py 333
funasr/models/paraformer/cif_predictor.py 200
funasr/tokenizer/hf_tokenizer.py 15
funasr/train_utils/load_pretrained_model.py 2
funasr/train_utils/trainer.py 46
setup.py 2
.gitignore
@@ -25,3 +25,4 @@
emotion2vec*
GPT-SoVITS*
modelscope_models
examples/aishell/llm_asr_nar/*
README.md
@@ -105,10 +105,8 @@
from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
                  vad_model="fsmn-vad", vad_model_revision="v2.0.4",
                  punc_model="ct-punc-c", punc_model_revision="v2.0.4",
                  # spk_model="cam++", spk_model_revision="v2.0.2",
model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad",  punc_model="ct-punc-c",
                  # spk_model="cam++",
                  )
res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
                     batch_size_s=300, 
@@ -125,7 +123,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
model = AutoModel(model="paraformer-zh-streaming")
import soundfile
import os
@@ -148,17 +146,19 @@
```python
from funasr import AutoModel
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
model = AutoModel(model="fsmn-vad")
wav_file = f"{model.model_path}/example/asr_example.wav"
res = model.generate(input=wav_file)
print(res)
```
Note: The output format of the VAD model is: `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`, where `begN/endN` indicates the starting/ending point of the `N-th` valid audio segment, measured in milliseconds.
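For instance, the segment list can be consumed directly; a minimal sketch, assuming the segments are returned under `res[0]["value"]` as in the streaming demo below:
```python
segments = res[0]["value"]  # [[beg1, end1], ..., [begN, endN]], in ms
total_speech_ms = sum(end - beg for beg, end in segments)
print(f"{len(segments)} segments, {total_speech_ms} ms of speech")
```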
### Voice Activity Detection (Streaming)
```python
from funasr import AutoModel
chunk_size = 200 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
model = AutoModel(model="fsmn-vad")
import soundfile
@@ -175,11 +175,18 @@
    if len(res[0]["value"]):
        print(res)
```
Note: The output format for the streaming VAD model can be one of four scenarios:
- `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`: the same as the offline VAD output result mentioned above.
- `[[beg, -1]]`: indicates that only a starting point has been detected.
- `[[-1, end]]`: indicates that only an ending point has been detected.
- `[]`: indicates that neither a starting point nor an ending point has been detected.

The output is measured in milliseconds and represents the absolute time from the starting point.
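A minimal sketch of dispatching on these four shapes (illustrative, not part of the API):
```python
def handle_vad_result(value):
    if not value:            # []: nothing detected in this chunk
        return
    for beg, end in value:
        if end == -1:        # [[beg, -1]]: start only
            print(f"speech started at {beg} ms")
        elif beg == -1:      # [[-1, end]]: end only
            print(f"speech ended at {end} ms")
        else:                # complete [beg, end] segment
            print(f"segment: {beg}-{end} ms")
```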
### Punctuation Restoration
```python
from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.4")
model = AutoModel(model="ct-punc")
res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
```
@@ -187,7 +194,7 @@
```python
from funasr import AutoModel
model = AutoModel(model="fa-zh", model_revision="v2.0.4")
model = AutoModel(model="fa-zh")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
README_zh.md
@@ -101,10 +101,8 @@
from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
                  vad_model="fsmn-vad", vad_model_revision="v2.0.4",
                  punc_model="ct-punc-c", punc_model_revision="v2.0.4",
                  # spk_model="cam++", spk_model_revision="v2.0.2",
model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad", punc_model="ct-punc-c",
                  # spk_model="cam++"
                  )
res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
            batch_size_s=300, 
@@ -122,7 +120,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
model = AutoModel(model="paraformer-zh-streaming")
import soundfile
import os
@@ -146,19 +144,21 @@
```python
from funasr import AutoModel
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
model = AutoModel(model="fsmn-vad")
wav_file = f"{model.model_path}/example/asr_example.wav"
res = model.generate(input=wav_file)
print(res)
```
Note: The VAD model output format is `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`, where `begN/endN` denotes the starting/ending point of the `N`-th valid audio segment, in milliseconds.
### Voice Activity Detection (Streaming)
```python
from funasr import AutoModel
chunk_size = 200 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
model = AutoModel(model="fsmn-vad")
import soundfile
@@ -175,12 +175,18 @@
    if len(res[0]["value"]):
        print(res)
```
Note: The output of the streaming VAD model falls into one of four scenarios:
- `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`: the same as the offline VAD output above.
- `[[beg, -1]]`: only a starting point has been detected.
- `[[-1, end]]`: only an ending point has been detected.
- `[]`: neither a starting point nor an ending point has been detected.

The output is measured in milliseconds and represents the absolute time from the starting point.
### Punctuation Restoration
```python
from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.4")
model = AutoModel(model="ct-punc")
res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
@@ -190,7 +196,7 @@
```python
from funasr import AutoModel
model = AutoModel(model="fa-zh", model_revision="v2.0.0")
model = AutoModel(model="fa-zh")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -10,6 +10,8 @@
res = model.generate(input=wav_file)
print(res)
# [[beg1, end1], [beg2, end2], .., [begN, endN]]
# beg/end: ms
@@ -37,3 +39,9 @@
    # print(res)
    if len(res[0]["value"]):
        print(res)
# 1. [[beg1, end1], [beg2, end2], .., [begN, endN]]; [[beg, end]]; [[beg1, end1], [beg2, end2]]
# 2. [[beg, -1]]
# 3. [[-1, end]]
# beg/end: ms
funasr/auto/auto_model.py
@@ -162,8 +162,10 @@
            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
            kwargs["tokenizer"] = tokenizer
            kwargs["token_list"] = tokenizer.token_list
            vocab_size = len(tokenizer.token_list)
            kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
            kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
            vocab_size = len(kwargs["token_list"])
        else:
            vocab_size = -1
        
@@ -184,15 +186,18 @@
        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            logging.info(f"Loading pretrained params from {init_param}")
            load_pretrained_model(
                model=model,
                path=init_param,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                oss_bucket=kwargs.get("oss_bucket", None),
                scope_map=kwargs.get("scope_map", None),
                excludes=kwargs.get("excludes", None),
            )
            if os.path.exists(init_param):
                logging.info(f"Loading pretrained params from {init_param}")
                load_pretrained_model(
                    model=model,
                    path=init_param,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", None),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                print(f"error, init_param does not exist!: {init_param}")
        
        return model, kwargs
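The tokenizer handling above now probes for either a `token_list` attribute (FunASR tokenizers) or a `get_vocab()` method (Hugging Face tokenizers). A standalone sketch of the same fallback, assuming nothing beyond those two attribute names:
```python
def infer_vocab_size(tokenizer) -> int:
    # HF tokenizers expose get_vocab(); FunASR tokenizers expose token_list.
    if hasattr(tokenizer, "get_vocab"):
        return len(tokenizer.get_vocab())
    if hasattr(tokenizer, "token_list"):
        return len(tokenizer.token_list)
    return -1  # unknown tokenizer type
```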
    
funasr/bin/train.py
@@ -85,7 +85,9 @@
    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
@@ -108,8 +110,10 @@
                )
            else:
                logging.info(f"Checkpoint does not exist, init randomly: {p}")
    else:
    elif kwargs.get("init", None):
        initialize(model, kwargs.get("init", "kaiming_normal"))
    else:
        print("No initialize method")
    # freeze_param
funasr/datasets/llm_datasets/__init__.py
funasr/datasets/llm_datasets/datasets.py
New file
@@ -0,0 +1,131 @@
import torch
import copy
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@tables.register("dataset_classes", "AudioLLMDataset")
class AudioLLMDataset(torch.utils.data.Dataset):
    """
    AudioLLMDataset
    """
    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                  **kwargs):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
        self.preprocessor_text = preprocessor_text
        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.float_pad_value = float_pad_value
        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
            self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
        self.prompt_af = ""
        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
        self.int_pad_value = self.IGNORE_INDEX
    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)
    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)
    def __len__(self):
        return len(self.index_ds)
    def __getitem__(self, index):
        item = self.index_ds[index]
        # import pdb;
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
        speech = speech.squeeze(0)
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
        prompt_pre_length = len(prompt_ids_pre)
        prompt_input = "{}{}".format(self.prompt_pre, target)
        prompt_input_ids = self.tokenizer.encode(prompt_input)
        audio_length = len(prompt_input_ids) - prompt_pre_length
        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
        input_ids = torch.tensor(input_ids, dtype=torch.int64) #[bos, prompt, input, pad]
        input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
        attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
        prompt_answer = "{}{}".format(self.prompt_pre, target)
        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
        answer_length = len(prompt_answer_ids) - prompt_pre_length
        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
        label_mask = labels_ids.ge(0)  # [False,False,True,True]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
        ids = self.tokenizer.encode(target) # token ids is different from labels_ids
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels_ids": labels_ids,
                "label_mask": label_mask,
                "audio_mask": audio_mask,
                }
    def collator(self, samples: list=None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])
        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs
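A hedged usage sketch of the collator with a plain PyTorch `DataLoader`; in real training the dataset, tokenizer, and frontend are built from the config, so the setup here is illustrative only:
```python
from torch.utils.data import DataLoader

# `dataset` is an AudioLLMDataset built elsewhere from the config.
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collator)
for batch in loader:
    # int64 tensors are padded with IGNORE_INDEX (-100), floats with 0.0
    print(batch["input_ids"].shape, batch["speech"].shape)
    break
```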
funasr/datasets/llm_datasets/preprocessor.py
New file
@@ -0,0 +1,37 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torch
import torchaudio
from torch import nn
import random
import re
import string
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessSegDict(nn.Module):
    def __init__(self,
                 **kwargs):
        super().__init__()
    def forward(self, text, **kwargs):
        # English punctuation
        en_punct = string.punctuation
        # common Chinese punctuation
        cn_punct = '。?!,、;:“”‘’()《》【】…—~·'
        # merge English and Chinese punctuation
        all_punct = en_punct + cn_punct
        # build a regex pattern matching any character in all_punct
        punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
        # strip those characters via re.sub
        return punct_pattern.sub('', text)
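The class is registered under the name `TextPreprocessRemovePunctuation`; a quick check of its behavior (the sample sentence is illustrative):
```python
pre = TextPreprocessSegDict()
print(pre("那今天的会就到这里吧,happy new year!明年见。"))
# -> 那今天的会就到这里吧happy new year明年见
```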
funasr/datasets/llm_datasets/samplers.py
New file
@@ -0,0 +1,277 @@
import torch
import numpy as np
import logging
import torch.distributed as dist
from funasr.register import tables
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
    def __len__(self):
        return (self.total_samples-1) // self.batch_size + 1
    def set_epoch(self, epoch):
        np.random.seed(epoch)
    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch = []
        max_token = 0
        num_sample = 0
        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue
                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != 'example':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size
    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
    def set_epoch(self, epoch):
        np.random.seed(epoch)
    def __iter__(self):
        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch = []
        max_token = 0
        num_sample = 0
        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            # if iter == iter_num -1 and self.drop_last:
            #     continue
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue
                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                # if self.batch_type != 'example':
                #     max_token_padding *= max_token_cur
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size
    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
    def set_epoch(self, epoch):
        np.random.seed(epoch)
    def __iter__(self):
        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch_list_all_rank = []
        batch_list_cur = []
        max_token = 0
        num_sample = 0
        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            # if iter == iter_num - 1 and self.drop_last:
            #     continue
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for ii, item in enumerate(datalen_with_index_sort):
                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) - 1
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue
                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != 'example':
                    max_token_padding *= max_token_cur
                if len(batch_list_all_rank) < self.world_size:
                    if max_token_padding <= self.batch_size:
                        batch_list_cur.append(idx)
                        max_token = max_token_cur
                        num_sample += 1
                    else:
                        batch_list_all_rank.append(batch_list_cur)
                        batch_list_cur = []
                else:
                    batch_rank = batch_list_all_rank[self.rank]
                    yield batch_rank
                    batch_list_all_rank = []
                    batch_list_cur = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
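A hedged sketch of wiring one of these samplers into a `DataLoader`; since each `__iter__` yields lists of dataset indices, it goes in as `batch_sampler` (the numbers are illustrative):
```python
from torch.utils.data import DataLoader
from funasr.datasets.llm_datasets.samplers import BatchSampler

sampler = BatchSampler(dataset, batch_type="length", batch_size=6000)
sampler.set_epoch(0)  # reseeds the shuffle each epoch
loader = DataLoader(dataset, batch_sampler=sampler,
                    collate_fn=dataset.collator)
```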
funasr/metrics/compute_acc.py
@@ -21,3 +21,20 @@
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.
    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.
    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
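A worked check of `compute_accuracy` with `ignore_label=-100`, matching how the LLM-ASR model calls it on shifted predictions and labels:
```python
import torch
from funasr.metrics.compute_acc import compute_accuracy

preds  = torch.tensor([[5, 7, 2], [3, 9, 9]])
labels = torch.tensor([[5, 7, 2], [3, -100, -100]])
# only the 4 non-ignored positions count; all of them match here
print(compute_accuracy(preds, labels, ignore_label=-100))  # tensor(1.)
```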
funasr/models/llm_asr/__init__.py
funasr/models/llm_asr/adaptor.py
New file
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from funasr.register import tables
@tables.register("adaptor_classes", "Linear")
class Linear(nn.Module):
    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)
        x = x.contiguous()
        x = x.view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
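A hedged shape check for the adaptor: with `downsample_rate` k=4, every 4 encoder frames are concatenated and projected to the LLM width, and leftover frames are discarded:
```python
import torch
from funasr.models.llm_asr.adaptor import Linear

adaptor = Linear(downsample_rate=4, encoder_dim=512, llm_dim=4096)
x = torch.randn(2, 103, 512)   # (batch, frames, encoder_dim)
y = adaptor(x)                 # 103 % 4 = 3 trailing frames dropped
print(y.shape)                 # torch.Size([2, 25, 4096])
```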
funasr/models/llm_asr/model.py
New file
@@ -0,0 +1,333 @@
import logging
from typing import Union, Dict, List, Tuple, Optional
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from funasr.models.scama.utils import sequence_mask
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@tables.register("model_classes", "LLMASRNAR")
class LLMASRNAR(nn.Module):
    """ """
    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        adaptor: str = None,
        adaptor_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        # audio encoder
        hub = encoder_conf.get("hub", None)
        if hub == "funasr":
            from funasr import AutoModel
            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
            # frontend = model.kwargs.get("frontend")
            model.model.decoder = None
            self.audio_encoder = model.model
            # self.frontend = frontend
        elif hub == "hf":
            pass
        else:
            encoder_class = tables.encoder_classes.get(encoder)
            encoder = encoder_class(input_size=input_size, **encoder_conf)
            encoder_output_size = encoder.output_size()
        # llm
        hub = llm_conf.get("hub", "hf")
        self.llm = None
        if hub == "hf":
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
                device_map=None,
                use_cache=None,
            )
            freeze = llm_conf.get("freeze", True)
            if freeze:
                for name, param in model.named_parameters():
                    param.requires_grad = False
                model.eval()
            self.llm = model
        # adaptor
        adaptor_class = tables.adaptor_classes.get(adaptor)
        adaptor = adaptor_class(**adaptor_conf)
        self.adaptor = adaptor
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask:torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
        # adaptor
        encoder_out = self.adaptor(encoder_out)
        if input_ids is not None:
            input_ids[input_ids == -1] = 0
            input_ids[input_ids == -100] = 0
            if hasattr(self.llm.model, "embed_tokens"):
                inputs_embeds = self.llm.model.embed_tokens(input_ids)
            elif hasattr(self.llm.model.model, "embed_tokens"):
                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
            else:
                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
            if audio_mask is not None:
                batch_size, token_num, dims = inputs_embeds.shape
                _, l, _ = encoder_out.shape
                encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
                inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
        loss = model_outputs.loss
        stats = {}
        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
            stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
        with autocast(False):
            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
                                                                               mask=enc_mask,
                                                                               target_label_length=audio_token_lengths,
                                                                               )
        return pre_acoustic_embeds, pre_token_length
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        prompt = kwargs.get("prompt", "Transcribe speech to text.")
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # adaptor
        encoder_out = self.adaptor(encoder_out)
        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
        prompt_ids = tokenizer.encode(prompt_pre)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
        text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"{0 + 1}best_recog"]
        results = []
        result_i = {"key": key[0], "text": text}
        results.append(result_i)
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = text
        return results, meta_data
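The key step in `forward()` is splicing the adaptor output into the token embeddings at the positions flagged by `audio_mask`. A hedged, self-contained illustration of that splice with toy shapes:
```python
import torch
import torch.nn.functional as F

inputs_embeds = torch.randn(1, 10, 8)  # [bos, prompt..., audio slots, pad]
encoder_out   = torch.randn(1, 4, 8)   # adaptor output: l audio frames
audio_mask    = torch.zeros(1, 10)
audio_mask[0, 5:9] = 1.0               # the four audio positions

token_num, l = inputs_embeds.size(1), encoder_out.size(1)
# left-pad so the audio frames line up with their slot positions
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - l - 1, 1, 0, 0))
mixed = (encoder_outs_pad * audio_mask[:, :, None]
         + inputs_embeds * (1.0 - audio_mask[:, :, None]))
print(mixed.shape)  # torch.Size([1, 10, 8])
```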
funasr/models/paraformer/cif_predictor.py
@@ -10,7 +10,7 @@
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@
    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
            
        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
        return acoustic_embeds, token_num, alphas, cif_peak
    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -169,41 +171,43 @@
    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length.squeeze(-1)
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            if self.tail_mask:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            output = torch.relu(self.cif_conv1d(queries))
            output = output.transpose(1, 2)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length.squeeze(-1)
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                if self.tail_mask:
                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
                else:
                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
        return acoustic_embeds, token_num, alphas, cif_peak
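The change above wraps the whole CIF computation in `autocast(False)`: the alpha accumulation is numerically sensitive, so it is forced back to fp32 even when the surrounding step runs under mixed precision. The pattern in isolation (a sketch, not the predictor itself):
```python
import torch
from torch.cuda.amp import autocast

def cif_alphas_fp32(hidden):
    with autocast(False):          # leave any enclosing fp16 region
        hidden = hidden.float()    # upcast inputs entering the block
        alphas = torch.sigmoid(hidden.sum(-1))
    return alphas
```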
@@ -370,62 +374,6 @@
        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.cif_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1,256),(1,256,1)
            "{}.cif_output.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1,),(1,)
        }
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                var_dict_torch[
                                                                                                    name].size(),
                                                                                                data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))
        return var_dict_torch_update
class mae_loss(torch.nn.Module):
funasr/tokenizer/hf_tokenizer.py
New file
@@ -0,0 +1,15 @@
try:
    from transformers import AutoTokenizer
except ImportError:
    print("If you want to use Hugging Face tokenizers, please `pip install -U transformers`")
from funasr.register import tables
@tables.register("tokenizer_classes", "HuggingfaceTokenizer")
def HuggingfaceTokenizer(init_param_path, **kwargs):
    tokenizer = AutoTokenizer.from_pretrained(init_param_path)
    return tokenizer
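A hedged usage sketch: the registered factory just forwards to `AutoTokenizer.from_pretrained`, so `init_param_path` may be any Hugging Face repo id or local directory (the id below is illustrative):
```python
tok = HuggingfaceTokenizer(init_param_path="lmsys/vicuna-7b-v1.5")
ids = tok.encode("Transcribe speech to text.")
print(ids, tok.decode(ids))
```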
funasr/train_utils/load_pretrained_model.py
@@ -118,7 +118,7 @@
        else:
            print(f"Warning, miss key in ckpt: {k}, mapped: {k_ddp}")
            
    flag = obj.load_state_dict(dst_state, strict=True)
    flag = obj.load_state_dict(dst_state, strict=False)
    # print(flag)
# def load_pretrained_model(
funasr/train_utils/trainer.py
@@ -5,7 +5,8 @@
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from contextlib import nullcontext
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
@@ -13,6 +14,15 @@
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@contextmanager
def maybe_autocast(enabled):
    if enabled:
        with autocast():
            yield
    else:
        yield
class Trainer:
    """
@@ -36,8 +46,9 @@
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str="./",
                 **kwargs):
        """
@@ -72,6 +83,11 @@
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
        scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler  # the sharded scaler is for FSDP
        self.scaler = scaler
        
    
        try:
@@ -103,6 +119,8 @@
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        if self.scaler:
            state["scaler_state"] = self.scaler.state_dict()
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
@@ -141,6 +159,8 @@
            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            if self.scaler and 'scaler_state' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
@@ -221,9 +241,10 @@
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                with maybe_autocast(self.use_fp16):
                    retval = self.model(**batch)
                if self.disable_gpu_cache: torch.cuda.empty_cache()
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@@ -241,7 +262,10 @@
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                if self.use_fp16:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
            
@@ -264,10 +288,14 @@
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                self.optim.step()
                if self.use_fp16:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad()
                self.optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
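A condensed sketch of the fp16 path the trainer now follows each step: autocast forward, scaled backward, scaler-driven optimizer step (the names here are illustrative):
```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler(enabled=True)

def fp16_step(model, batch, optim):
    with autocast():                     # forward in mixed precision
        loss, stats, weight = model(**batch)
    scaler.scale(loss).backward()        # scale to avoid fp16 underflow
    scaler.step(optim)                   # unscales, then optim.step()
    scaler.update()
    optim.zero_grad(set_to_none=True)
```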
setup.py
@@ -40,11 +40,11 @@
        "umap_learn",
        "jaconv",
        "hydra-core>=1.3.2",
        "tensorboardX",
    ],
    # train: The modules invoked when training only.
    "train": [
        "editdistance",
        "tensorboardX",
    ],
    # all: The modules should be optionally installed for some reason.
    #      Please consider moving them to "install" occasionally