From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/campplus/cluster_backend.py | 42 +++++++++++++++++++-----------------------
1 file changed, 19 insertions(+), 23 deletions(-)
diff --git a/funasr/models/campplus/cluster_backend.py b/funasr/models/campplus/cluster_backend.py
index d667f6b..b98721e 100644
--- a/funasr/models/campplus/cluster_backend.py
+++ b/funasr/models/campplus/cluster_backend.py
@@ -7,10 +7,10 @@
import scipy
import torch
import sklearn
-import hdbscan
import numpy as np
from sklearn.cluster._kmeans import k_means
+from sklearn.cluster import HDBSCAN
class SpectralCluster:
@@ -51,7 +51,7 @@
def p_pruning(self, A):
if A.shape[0] * self.pval < 6:
- pval = 6. / A.shape[0]
+ pval = 6.0 / A.shape[0]
else:
pval = self.pval
@@ -80,7 +80,8 @@
num_of_spk = k_oracle
else:
lambda_gap_list = self.getEigenGaps(
- lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
+ lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
+ )
num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
emb = eig_vecs[:, :num_of_spk]
@@ -105,12 +106,9 @@
Emphasis On Topological Structure. ICASSP2022
"""
- def __init__(self,
- n_neighbors=20,
- n_components=60,
- min_samples=10,
- min_cluster_size=10,
- metric='cosine'):
+ def __init__(
+ self, n_neighbors=20, n_components=60, min_samples=10, min_cluster_size=10, metric="cosine"
+ ):
self.n_neighbors = n_neighbors
self.n_components = n_components
self.min_samples = min_samples
@@ -118,17 +116,19 @@
self.metric = metric
def __call__(self, X):
- from umap.umap_ import UMAP
+ import umap.umap_ as umap
+
umap_X = umap.UMAP(
n_neighbors=self.n_neighbors,
min_dist=0.0,
n_components=min(self.n_components, X.shape[0] - 2),
metric=self.metric,
).fit_transform(X)
- labels = hdbscan.HDBSCAN(
+ labels = HDBSCAN(
min_samples=self.min_samples,
min_cluster_size=self.min_cluster_size,
- allow_single_cluster=True).fit_predict(umap_X)
+ allow_single_cluster=True,
+ ).fit_predict(umap_X)
return labels
@@ -141,7 +141,7 @@
def __init__(self):
super().__init__()
- self.model_config = {'merge_thr':0.78}
+ self.model_config = {"merge_thr": 0.78}
# self.other_config = kwargs
self.spectral_cluster = SpectralCluster()
@@ -149,21 +149,18 @@
def forward(self, X, **params):
# clustering and return the labels
- k = params['oracle_num'] if 'oracle_num' in params else None
- assert len(
- X.shape
- ) == 2, 'modelscope error: the shape of input should be [N, C]'
+ k = params["oracle_num"] if "oracle_num" in params else None
+ assert len(X.shape) == 2, "modelscope error: the shape of input should be [N, C]"
if X.shape[0] < 20:
- return np.zeros(X.shape[0], dtype='int')
+ return np.zeros(X.shape[0], dtype="int")
if X.shape[0] < 2048 or k is not None:
# unexpected corner case
labels = self.spectral_cluster(X, k)
else:
labels = self.umap_hdbscan_cluster(X)
- if k is None and 'merge_thr' in self.model_config:
- labels = self.merge_by_cos(labels, X,
- self.model_config['merge_thr'])
+ if k is None and "merge_thr" in self.model_config:
+ labels = self.merge_by_cos(labels, X, self.model_config["merge_thr"])
return labels
@@ -180,8 +177,7 @@
spk_center.append(spk_emb)
assert len(spk_center) > 0
spk_center = np.stack(spk_center, axis=0)
- norm_spk_center = spk_center / np.linalg.norm(
- spk_center, axis=1, keepdims=True)
+ norm_spk_center = spk_center / np.linalg.norm(spk_center, axis=1, keepdims=True)
affinity = np.matmul(norm_spk_center, norm_spk_center.T)
affinity = np.triu(affinity, 1)
spks = np.unravel_index(np.argmax(affinity), affinity.shape)
--
Gitblit v1.9.1