From c0e72dd1ba86c19205ee633673b2497d18a68077 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 11 一月 2024 17:36:59 +0800
Subject: [PATCH] Merge branch 'funasr1.0' of github.com:alibaba-damo-academy/FunASR into funasr1.0 add
---
funasr/models/campplus/utils.py | 90 ++++++++++++++++++++++-----------------------
1 files changed, 44 insertions(+), 46 deletions(-)
diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py
index c86a9f0..9964356 100644
--- a/funasr/models/campplus/utils.py
+++ b/funasr/models/campplus/utils.py
@@ -2,23 +2,19 @@
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import io
-from typing import Union
-
-import librosa as sf
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torchaudio.compliance.kaldi as Kaldi
-from torch import nn
-
-import contextlib
import os
+import torch
+import requests
import tempfile
-from abc import ABCMeta, abstractmethod
+import contextlib
+import numpy as np
+import librosa as sf
+from typing import Union
from pathlib import Path
from typing import Generator, Union
-
-import requests
+from abc import ABCMeta, abstractmethod
+import torchaudio.compliance.kaldi as Kaldi
+from funasr.models.transformer.utils.nets_utils import pad_list
def check_audio_list(audio: list):
@@ -40,31 +36,31 @@
def sv_preprocess(inputs: Union[np.ndarray, list]):
- output = []
- for i in range(len(inputs)):
- if isinstance(inputs[i], str):
- file_bytes = File.read(inputs[i])
- data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
- if len(data.shape) == 2:
- data = data[:, 0]
- data = torch.from_numpy(data).unsqueeze(0)
- data = data.squeeze(0)
- elif isinstance(inputs[i], np.ndarray):
- assert len(
- inputs[i].shape
- ) == 1, 'modelscope error: Input array should be [N, T]'
- data = inputs[i]
- if data.dtype in ['int16', 'int32', 'int64']:
- data = (data / (1 << 15)).astype('float32')
- else:
- data = data.astype('float32')
- data = torch.from_numpy(data)
- else:
- raise ValueError(
- 'modelscope error: The input type is restricted to audio address and nump array.'
- )
- output.append(data)
- return output
+ output = []
+ for i in range(len(inputs)):
+ if isinstance(inputs[i], str):
+ file_bytes = File.read(inputs[i])
+ data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
+ if len(data.shape) == 2:
+ data = data[:, 0]
+ data = torch.from_numpy(data).unsqueeze(0)
+ data = data.squeeze(0)
+ elif isinstance(inputs[i], np.ndarray):
+ assert len(
+ inputs[i].shape
+ ) == 1, 'modelscope error: Input array should be [N, T]'
+ data = inputs[i]
+ if data.dtype in ['int16', 'int32', 'int64']:
+ data = (data / (1 << 15)).astype('float32')
+ else:
+ data = data.astype('float32')
+ data = torch.from_numpy(data)
+ else:
+ raise ValueError(
+ 'modelscope error: The input type is restricted to audio address and nump array.'
+ )
+ output.append(data)
+ return output
def sv_chunk(vad_segments: list, fs = 16000) -> list:
@@ -105,15 +101,19 @@
def extract_feature(audio):
features = []
+ feature_times = []
feature_lengths = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=80)
feature = feature - feature.mean(dim=0, keepdim=True)
- features.append(feature.unsqueeze(0))
- feature_lengths.append(au.shape[0])
- features = torch.cat(features)
- return features, feature_lengths
+ features.append(feature)
+ feature_times.append(au.shape[0])
+ feature_lengths.append(feature.shape[0])
+ # padding for batch inference
+ features_padded = pad_list(features, pad_value=0)
+ # features = torch.cat(features)
+ return features_padded, feature_lengths, feature_times
def postprocess(segments: list, vad_segments: list,
@@ -195,8 +195,8 @@
def distribute_spk(sentence_list, sd_time_list):
sd_sentence_list = []
for d in sentence_list:
- sentence_start = d['ts_list'][0][0]
- sentence_end = d['ts_list'][-1][1]
+ sentence_start = d['start']
+ sentence_end = d['end']
sentence_spk = 0
max_overlap = 0
for sd_time in sd_time_list:
@@ -211,8 +211,6 @@
d['spk'] = sentence_spk
sd_sentence_list.append(d)
return sd_sentence_list
-
-
class Storage(metaclass=ABCMeta):
--
Gitblit v1.9.1