From c0e72dd1ba86c19205ee633673b2497d18a68077 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 11 Jan 2024 17:36:59 +0800
Subject: [PATCH] Merge branch 'funasr1.0' of github.com:alibaba-damo-academy/FunASR into funasr1.0 add

---
 funasr/models/campplus/utils.py |   90 ++++++++++++++++++++++-----------------------
 1 file changed, 44 insertions(+), 46 deletions(-)

diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py
index c86a9f0..9964356 100644
--- a/funasr/models/campplus/utils.py
+++ b/funasr/models/campplus/utils.py
@@ -2,23 +2,19 @@
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
 import io
-from typing import Union
-
-import librosa as sf
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torchaudio.compliance.kaldi as Kaldi
-from torch import nn
-
-import contextlib
 import os
+import torch
+import requests
 import tempfile
-from abc import ABCMeta, abstractmethod
+import contextlib
+import numpy as np
+import librosa as sf
+from typing import Union
 from pathlib import Path
 from typing import Generator, Union
-
-import requests
+from abc import ABCMeta, abstractmethod
+import torchaudio.compliance.kaldi as Kaldi
+from funasr.models.transformer.utils.nets_utils import pad_list
 
 
 def check_audio_list(audio: list):
@@ -40,31 +36,31 @@
 
 
 def sv_preprocess(inputs: Union[np.ndarray, list]):
-	output = []
-	for i in range(len(inputs)):
-		if isinstance(inputs[i], str):
-			file_bytes = File.read(inputs[i])
-			data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
-			if len(data.shape) == 2:
-				data = data[:, 0]
-			data = torch.from_numpy(data).unsqueeze(0)
-			data = data.squeeze(0)
-		elif isinstance(inputs[i], np.ndarray):
-			assert len(
-				inputs[i].shape
-			) == 1, 'modelscope error: Input array should be [N, T]'
-			data = inputs[i]
-			if data.dtype in ['int16', 'int32', 'int64']:
-				data = (data / (1 << 15)).astype('float32')
-			else:
-				data = data.astype('float32')
-			data = torch.from_numpy(data)
-		else:
-			raise ValueError(
-				'modelscope error: The input type is restricted to audio address and nump array.'
-			)
-		output.append(data)
-	return output
+    output = []
+    for i in range(len(inputs)):
+        if isinstance(inputs[i], str):
+            file_bytes = File.read(inputs[i])
+            data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
+            if len(data.shape) == 2:
+                data = data[:, 0]
+            data = torch.from_numpy(data).unsqueeze(0)
+            data = data.squeeze(0)
+        elif isinstance(inputs[i], np.ndarray):
+            assert len(
+                inputs[i].shape
+            ) == 1, 'modelscope error: Input array should be [N, T]'
+            data = inputs[i]
+            if data.dtype in ['int16', 'int32', 'int64']:
+                data = (data / (1 << 15)).astype('float32')
+            else:
+                data = data.astype('float32')
+            data = torch.from_numpy(data)
+        else:
+            raise ValueError(
+                'modelscope error: The input type is restricted to audio address and nump array.'
+            )
+        output.append(data)
+    return output
 
 
 def sv_chunk(vad_segments: list, fs = 16000) -> list:
@@ -105,15 +101,19 @@
 
 def extract_feature(audio):
     features = []
+    feature_times = []
     feature_lengths = []
     for au in audio:
         feature = Kaldi.fbank(
             au.unsqueeze(0), num_mel_bins=80)
         feature = feature - feature.mean(dim=0, keepdim=True)
-        features.append(feature.unsqueeze(0))
-        feature_lengths.append(au.shape[0])
-    features = torch.cat(features)
-    return features, feature_lengths
+        features.append(feature)
+        feature_times.append(au.shape[0])
+        feature_lengths.append(feature.shape[0])
+    # padding for batch inference
+    features_padded = pad_list(features, pad_value=0)
+    # features = torch.cat(features)
+    return features_padded, feature_lengths, feature_times
 
 
 def postprocess(segments: list, vad_segments: list,
@@ -195,8 +195,8 @@
 def distribute_spk(sentence_list, sd_time_list):
     sd_sentence_list = []
     for d in sentence_list:
-        sentence_start = d['ts_list'][0][0]
-        sentence_end = d['ts_list'][-1][1]
+        sentence_start = d['start']
+        sentence_end = d['end']
         sentence_spk = 0
         max_overlap = 0
         for sd_time in sd_time_list:
@@ -211,8 +211,6 @@
         d['spk'] = sentence_spk
         sd_sentence_list.append(d)
     return sd_sentence_list
-
-
 
 
 class Storage(metaclass=ABCMeta):

--
Gitblit v1.9.1