zhifu gao
2024-03-04 d2c1204d91d7c98be7998e3966bd82e22750293b
Revert "Dev yf" (#1418)

10 files modified
11 files deleted
1905 lines changed
examples/industrial_data_pretraining/contextual_paraformer/demo.py
examples/industrial_data_pretraining/contextual_paraformer/demo.sh   2
examples/industrial_data_pretraining/contextual_paraformer/demo2.sh   9
examples/industrial_data_pretraining/contextual_paraformer/path.sh   6
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py   702
examples/industrial_data_pretraining/lcbnet/demo.py   13
examples/industrial_data_pretraining/lcbnet/demo.sh   72
examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh   11
examples/industrial_data_pretraining/lcbnet/utils   1
examples/industrial_data_pretraining/seaco_paraformer/demo.py   8
funasr/auto/auto_model.py   6
funasr/frontends/default.py   20
funasr/models/conformer/encoder.py   2
funasr/models/contextual_paraformer/model.py   15
funasr/models/lcbnet/__init__.py
funasr/models/lcbnet/attention.py   112
funasr/models/lcbnet/encoder.py   392
funasr/models/lcbnet/model.py   495
funasr/models/seaco_paraformer/model.py   12
funasr/train_utils/load_pretrained_model.py   3
funasr/utils/load_utils.py   24
examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100755 new mode 100644
examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100755 new mode 100644
@@ -2,7 +2,7 @@
 model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"
 
-python ../../../funasr/bin/inference.py \
+python funasr/bin/inference.py \
 +model=${model} \
 +model_revision=${model_revision} \
 +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
File was deleted
examples/industrial_data_pretraining/contextual_paraformer/path.sh
File was deleted
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
File was deleted
examples/industrial_data_pretraining/lcbnet/demo.py
File was deleted
examples/industrial_data_pretraining/lcbnet/demo.sh
File was deleted
examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
File was deleted
examples/industrial_data_pretraining/lcbnet/utils
File was deleted
examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,10 +7,10 @@
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # punc_model_revision="v2.0.4",
+                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.4",
+                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.4",
                   # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
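With the VAD and punctuation models re-enabled above, the demo runs the full pipeline again. A minimal usage sketch follows, assuming the AutoModel built above; the hotword argument mirrors FunASR's SeACo-Paraformer examples and is an assumption, not part of this diff.

    # Minimal usage sketch (assumes `model` is the AutoModel constructed above).
    # SeACo-Paraformer accepts a space-separated hotword string in FunASR examples.
    res = model.generate(
        input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
        hotword="达摩院 魔搭",
    )
    print(res)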
funasr/auto/auto_model.py
@@ -28,7 +28,7 @@
     from funasr.models.campplus.cluster_backend import ClusterBackend
 except:
     print("If you want to use the speaker diarization, please `pip install hdbscan`")
-import pdb
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
     """
@@ -46,7 +46,6 @@
     chars = string.ascii_letters + string.digits
     if isinstance(data_in, str) and data_in.startswith('http'): # url
         data_in = download_from_url(data_in)
     if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
         _, file_extension = os.path.splitext(data_in)
         file_extension = file_extension.lower()
@@ -169,6 +168,7 @@
             vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
         else:
             vocab_size = -1
         # build frontend
         frontend = kwargs.get("frontend", None)
         kwargs["input_size"] = None
@@ -181,6 +181,7 @@
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
         model.to(device)
 
         # init_param
@@ -238,7 +239,6 @@
             data_batch = data_list[beg_idx:end_idx]
             key_batch = key_list[beg_idx:end_idx]
             batch = {"data_in": data_batch, "key": key_batch}
             if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
                 batch["data_in"] = data_batch[0]
                 batch["data_lengths"] = input_len
funasr/frontends/default.py
@@ -3,6 +3,7 @@
 from typing import Tuple
 from typing import Union
 import logging
+import humanfriendly
 import numpy as np
 import torch
 import torch.nn as nn
@@ -15,10 +16,8 @@
 from funasr.frontends.utils.stft import Stft
 from funasr.frontends.utils.frontend import Frontend
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.register import tables
 
 @tables.register("frontend_classes", "DefaultFrontend")
 class DefaultFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -26,7 +25,7 @@
     def __init__(
             self,
-            fs: int = 16000,
+            fs: Union[int, str] = 16000,
             n_fft: int = 512,
             win_length: int = None,
             hop_length: int = 128,
@@ -41,14 +40,14 @@
             frontend_conf: Optional[dict] = None,
             apply_stft: bool = True,
             use_channel: int = None,
             **kwargs,
     ):
         super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         self.hop_length = hop_length
         self.fs = fs
         if apply_stft:
             self.stft = Stft(
@@ -85,12 +84,8 @@
         return self.n_mels
 
     def forward(
-            self, input: torch.Tensor, input_lengths:  Union[torch.Tensor, list]
+            self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if isinstance(input_lengths, list):
-            input_lengths = torch.tensor(input_lengths)
-        if  input.dtype == torch.float64:
-            input = input.float()
         # 1. Domain-conversion: e.g. Stft: time -> time-freq
         if self.stft is not None:
             input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -150,7 +145,7 @@
     def __init__(
             self,
-            fs: int = 16000,
+            fs: Union[int, str] = 16000,
             n_fft: int = 512,
             win_length: int = None,
             hop_length: int = None,
@@ -173,6 +168,9 @@
             mc: bool = True
     ):
         super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         if win_length is None and hop_length is None:
funasr/models/conformer/encoder.py
@@ -47,7 +47,7 @@
 from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
 from funasr.models.transformer.utils.subsampling import StreamingConvInput
 from funasr.register import tables
-import pdb
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
funasr/models/contextual_paraformer/model.py
@@ -29,7 +29,7 @@
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -62,6 +62,7 @@
         crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
         crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
         bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
         if bias_encoder_type == 'lstm':
             self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
@@ -112,6 +113,7 @@
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         loss_ctc, cer_ctc = None, None
 
         stats = dict()
@@ -125,6 +127,7 @@
             # Collect CTC branch stats
             stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
             stats["cer_ctc"] = cer_ctc
 
         # 2b. Attention decoder branch
         loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
@@ -168,19 +171,17 @@
     ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
         pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                      ignore_id=self.ignore_id)
         # -1. bias encoder
         if self.use_decoder_embedding:
             hw_embed = self.decoder.embed(hotword_pad)
         else:
             hw_embed = self.bias_embed(hotword_pad)
         hw_embed, (_, _) = self.bias_encoder(hw_embed)
         _ind = np.arange(0, hotword_pad.shape[0]).tolist()
         selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
@@ -291,7 +292,6 @@
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
@@ -305,7 +305,6 @@
                  **kwargs,
                  ):
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         if self.beam_search is None and (is_use_lm or is_use_ctc):
@@ -317,12 +316,9 @@
 
         # extract fbank feats
         time1 = time.perf_counter()
         audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
         speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                frontend=frontend)
         time3 = time.perf_counter()
@@ -349,6 +345,7 @@
         if torch.max(pre_token_length) < 1:
             return []
 
         decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
                                                                  pre_acoustic_embeds,
                                                                  pre_token_length,
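In the _calc_att_clas_loss hunk above, each padded hotword is run through the LSTM bias encoder and its embedding is taken from the output at position length - 1, i.e. the last valid timestep. A self-contained sketch of that gather follows; the sizes are made up for illustration and are not from the diff.

    import torch

    # Sketch of the bias-encoder selection above: run padded hotwords through
    # an LSTM, then gather each sequence's last valid output as its embedding.
    lstm = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=1, batch_first=True)
    hw_embed = torch.randn(3, 5, 8)            # [num_hotwords, max_len, dim], padded
    hotword_lengths = torch.tensor([5, 2, 3])  # true length of each hotword

    out, (_, _) = lstm(hw_embed)               # out: [3, 5, 8]
    idx = torch.arange(out.size(0))
    selected = out[idx, hotword_lengths - 1]   # [3, 8], one vector per hotword
    print(selected.shape)                      # torch.Size([3, 8])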
funasr/models/lcbnet/__init__.py
File was deleted
funasr/models/lcbnet/attention.py
File was deleted
funasr/models/lcbnet/encoder.py
File was deleted
funasr/models/lcbnet/model.py
File was deleted
funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -209,16 +209,13 @@
                                nfilter=50,
                                seaco_weight=1.0):
         # decoder forward
         decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
         decoder_pred = torch.log_softmax(decoder_out, dim=-1)
         if hw_list is not None:
             hw_lengths = [len(i) for i in hw_list]
             hw_list_ = [torch.Tensor(i).long() for i in hw_list]
             hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
             selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
             contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
             num_hot_word = contextual_info.shape[1]
             _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
@@ -256,8 +253,8 @@
                 # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
                 logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                 return logits
             merged_pred = _merge_res(decoder_pred, dha_pred)
             # import pdb; pdb.set_trace()
             return merged_pred
         else:
             return decoder_pred
@@ -307,6 +304,7 @@
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
         meta_data = {}
 
         # extract fbank feats
@@ -332,7 +330,6 @@
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
         # predictor
         predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
         pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -341,14 +338,15 @@
         if torch.max(pre_token_length) < 1:
             return []
         decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                    pre_acoustic_embeds,
                                                    pre_token_length,
                                                    hw_list=self.hotword_list)
         # decoder_out, _ = decoder_outs[0], decoder_outs[1]
         _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                   pre_token_length)
         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
funasr/train_utils/load_pretrained_model.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn
 import torch.optim
-import pdb
 
 def filter_state_dict(
     dst_state: Dict[str, Union[float, torch.Tensor]],
@@ -63,7 +63,6 @@
     dst_state = obj.state_dict()
 
-    print(f"ckpt: {path}")
     if oss_bucket is None:
         src_state = torch.load(path, map_location=map_location)
     else:
funasr/utils/load_utils.py
@@ -13,22 +13,26 @@
     from funasr.download.file import download_from_url
 except:
     print("urllib is not installed, if you infer from url, please install it first.")
-import pdb
 
 def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
     if isinstance(data_or_path_or_list, (list, tuple)):
         if data_type is not None and isinstance(data_type, (list, tuple)):
             data_types = [data_type] * len(data_or_path_or_list)
             data_or_path_or_list_ret = [[] for d in data_type]
             for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
                 for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
                     data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                     data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
             return data_or_path_or_list_ret
         else:
             return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
     if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
         data_or_path_or_list = download_from_url(data_or_path_or_list)
@@ -52,18 +56,6 @@
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
         data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
-    elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
-        data_mat = kaldiio.load_mat(data_or_path_or_list)
-        if isinstance(data_mat, tuple):
-            audio_fs, mat = data_mat
-        else:
-            mat = data_mat
-        if mat.dtype == 'int16' or mat.dtype == 'int32':
-            mat = mat.astype(np.float64)
-            mat = mat / 32768
-        if mat.ndim ==2:
-            mat = mat[:,0]
-        data_or_path_or_list = mat
     else:
         pass
         # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
@@ -89,6 +81,8 @@
     return array
 
 def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
+    # import pdb;
+    # pdb.set_trace()
     if isinstance(data, np.ndarray):
         data = torch.from_numpy(data)
         if len(data.shape) < 2:
@@ -106,7 +100,9 @@
             data_list.append(data_i)
             data_len.append(data_i.shape[0])
         data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
+    # import pdb;
+    # pdb.set_trace()
     # if data_type == "sound":
     data, data_len = frontend(data, data_len, **kwargs)
 
     if isinstance(data_len, (list, tuple)):
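For reference, the pad_sequence call kept in extract_fbank batches variable-length waveforms by right-padding them to a common length before the frontend runs. A standalone sketch with made-up sizes (not part of the diff):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Sketch of the batching step in extract_fbank: right-pad variable-length
    # 1-D waveforms into one [batch, N] tensor and keep the true lengths.
    waves = [torch.randn(16000), torch.randn(8000), torch.randn(12000)]
    data_len = [w.shape[0] for w in waves]
    data = pad_sequence(waves, batch_first=True)  # [3, 16000]
    print(data.shape, data_len)                   # torch.Size([3, 16000]) [16000, 8000, 12000]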