游雁
2024-01-11 c0e72dd1ba86c19205ee633673b2497d18a68077
funasr/models/campplus/utils.py
@@ -2,23 +2,19 @@
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import io
from typing import Union
import librosa as sf
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from torch import nn
import contextlib
import os
import torch
import requests
import tempfile
from abc import ABCMeta, abstractmethod
import contextlib
import numpy as np
import librosa as sf
from typing import Union
from pathlib import Path
from typing import Generator, Union
import requests
from abc import ABCMeta, abstractmethod
import torchaudio.compliance.kaldi as Kaldi
from funasr.models.transformer.utils.nets_utils import pad_list
def check_audio_list(audio: list):
@@ -40,31 +36,31 @@
def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Convert a batch of audio inputs into 1-D float32 torch tensors.

    Each element of ``inputs`` may be:

    * a ``str`` -- an audio location whose bytes are fetched with
      ``File.read`` and decoded with ``sf.load`` (``sf`` is ``librosa`` in
      this module), resampled to 16 kHz and downmixed to mono;
    * a 1-D ``np.ndarray`` -- raw samples. Arrays of dtype int16/int32/int64
      are divided by ``2**15``; any other dtype is cast to float32.

    Args:
        inputs: sequence of audio locations and/or 1-D numpy arrays.

    Returns:
        list of 1-D ``torch.float32`` tensors, one per input, in input order.

    Raises:
        AssertionError: if a numpy input is not 1-D.
        ValueError: if an input is neither a ``str`` nor a ``np.ndarray``.
    """
    output = []
    for item in inputs:
        if isinstance(item, str):
            file_bytes = File.read(item)
            # librosa.load resamples to 22050 Hz unless ``sr`` is given, but the
            # downstream pipeline (sv_chunk fs=16000, Kaldi.fbank default
            # sample_frequency=16000) expects 16 kHz. mono=True guarantees a
            # 1-D result, so no channel selection is needed afterwards.
            data, _fs = sf.load(
                io.BytesIO(file_bytes), sr=16000, mono=True, dtype='float32')
            data = torch.from_numpy(data)
        elif isinstance(item, np.ndarray):
            # NOTE(review): the check requires a 1-D [T] array even though the
            # message mentions [N, T]; message kept unchanged for callers/logs.
            assert item.ndim == 1, 'modelscope error: Input array should be [N, T]'
            data = item
            if data.dtype in ['int16', 'int32', 'int64']:
                # Scales by 2**15, i.e. treats all integer input as int16-range
                # PCM. NOTE(review): int32 PCM would need 2**31 -- confirm callers.
                data = (data / (1 << 15)).astype('float32')
            else:
                # astype copies, so the returned tensor never aliases the input.
                data = data.astype('float32')
            data = torch.from_numpy(data)
        else:
            raise ValueError(
                'modelscope error: The input type is restricted to audio address and nump array.'
            )
        output.append(data)
    return output
def sv_chunk(vad_segments: list, fs = 16000) -> list:
@@ -105,15 +101,19 @@
def extract_feature(audio):
    """Compute mean-normalised 80-bin Kaldi fbank features for a batch of waveforms.

    Args:
        audio: iterable of 1-D waveform tensors (e.g. the output of
            ``sv_preprocess``). Presumably 16 kHz audio, since ``Kaldi.fbank``
            is called with its default ``sample_frequency`` -- TODO confirm.

    Returns:
        tuple ``(features_padded, feature_lengths, feature_times)``:

        * ``features_padded`` -- the per-utterance feature matrices combined
          into one batch by ``pad_list`` with ``pad_value=0``;
        * ``feature_lengths`` -- number of fbank frames for each input;
        * ``feature_times`` -- number of samples (``au.shape[0]``) for each input.
    """
    features = []
    feature_times = []
    feature_lengths = []
    for au in audio:
        # Kaldi.fbank takes a (channel, time) waveform, hence the unsqueeze.
        feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
        # Per-utterance mean normalisation across frames (dim 0).
        feature = feature - feature.mean(dim=0, keepdim=True)
        features.append(feature)
        feature_times.append(au.shape[0])
        feature_lengths.append(feature.shape[0])
    # Utterances differ in length, so they are padded into one batch rather
    # than concatenated, which allows batch inference.
    features_padded = pad_list(features, pad_value=0)
    return features_padded, feature_lengths, feature_times
def postprocess(segments: list, vad_segments: list,
@@ -195,8 +195,8 @@
def distribute_spk(sentence_list, sd_time_list):
    sd_sentence_list = []
    for d in sentence_list:
        sentence_start = d['ts_list'][0][0]
        sentence_end = d['ts_list'][-1][1]
        sentence_start = d['start']
        sentence_end = d['end']
        sentence_spk = 0
        max_overlap = 0
        for sd_time in sd_time_list:
@@ -211,8 +211,6 @@
        d['spk'] = sentence_spk
        sd_sentence_list.append(d)
    return sd_sentence_list
class Storage(metaclass=ABCMeta):