From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/campplus/cluster_backend.py |   42 +++++++++++++++++++-----------------------
 1 files changed, 19 insertions(+), 23 deletions(-)

diff --git a/funasr/models/campplus/cluster_backend.py b/funasr/models/campplus/cluster_backend.py
index d667f6b..b98721e 100644
--- a/funasr/models/campplus/cluster_backend.py
+++ b/funasr/models/campplus/cluster_backend.py
@@ -7,10 +7,10 @@
 import scipy
 import torch
 import sklearn
-import hdbscan
 import numpy as np
 
 from sklearn.cluster._kmeans import k_means
+from sklearn.cluster import HDBSCAN
 
 
 class SpectralCluster:
@@ -51,7 +51,7 @@
 
     def p_pruning(self, A):
         if A.shape[0] * self.pval < 6:
-            pval = 6. / A.shape[0]
+            pval = 6.0 / A.shape[0]
         else:
             pval = self.pval
 
@@ -80,7 +80,8 @@
             num_of_spk = k_oracle
         else:
             lambda_gap_list = self.getEigenGaps(
-                lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
+                lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
+            )
             num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
 
         emb = eig_vecs[:, :num_of_spk]
@@ -105,12 +106,9 @@
       Emphasis On Topological Structure. ICASSP2022
     """
 
-    def __init__(self,
-                 n_neighbors=20,
-                 n_components=60,
-                 min_samples=10,
-                 min_cluster_size=10,
-                 metric='cosine'):
+    def __init__(
+        self, n_neighbors=20, n_components=60, min_samples=10, min_cluster_size=10, metric="cosine"
+    ):
         self.n_neighbors = n_neighbors
         self.n_components = n_components
         self.min_samples = min_samples
@@ -118,17 +116,19 @@
         self.metric = metric
 
     def __call__(self, X):
-        from umap.umap_ import UMAP
+        import umap.umap_ as umap
+
         umap_X = umap.UMAP(
             n_neighbors=self.n_neighbors,
             min_dist=0.0,
             n_components=min(self.n_components, X.shape[0] - 2),
             metric=self.metric,
         ).fit_transform(X)
-        labels = hdbscan.HDBSCAN(
+        labels = HDBSCAN(
             min_samples=self.min_samples,
             min_cluster_size=self.min_cluster_size,
-            allow_single_cluster=True).fit_predict(umap_X)
+            allow_single_cluster=True,
+        ).fit_predict(umap_X)
         return labels
 
 
@@ -141,7 +141,7 @@
 
     def __init__(self):
         super().__init__()
-        self.model_config = {'merge_thr':0.78}
+        self.model_config = {"merge_thr": 0.78}
         # self.other_config = kwargs
 
         self.spectral_cluster = SpectralCluster()
@@ -149,21 +149,18 @@
 
     def forward(self, X, **params):
         # clustering and return the labels
-        k = params['oracle_num'] if 'oracle_num' in params else None
-        assert len(
-            X.shape
-        ) == 2, 'modelscope error: the shape of input should be [N, C]'
+        k = params["oracle_num"] if "oracle_num" in params else None
+        assert len(X.shape) == 2, "modelscope error: the shape of input should be [N, C]"
         if X.shape[0] < 20:
-            return np.zeros(X.shape[0], dtype='int')
+            return np.zeros(X.shape[0], dtype="int")
         if X.shape[0] < 2048 or k is not None:
             # unexpected corner case
             labels = self.spectral_cluster(X, k)
         else:
             labels = self.umap_hdbscan_cluster(X)
 
-        if k is None and 'merge_thr' in self.model_config:
-            labels = self.merge_by_cos(labels, X,
-                                       self.model_config['merge_thr'])
+        if k is None and "merge_thr" in self.model_config:
+            labels = self.merge_by_cos(labels, X, self.model_config["merge_thr"])
 
         return labels
 
@@ -180,8 +177,7 @@
                 spk_center.append(spk_emb)
             assert len(spk_center) > 0
             spk_center = np.stack(spk_center, axis=0)
-            norm_spk_center = spk_center / np.linalg.norm(
-                spk_center, axis=1, keepdims=True)
+            norm_spk_center = spk_center / np.linalg.norm(spk_center, axis=1, keepdims=True)
             affinity = np.matmul(norm_spk_center, norm_spk_center.T)
             affinity = np.triu(affinity, 1)
             spks = np.unravel_index(np.argmax(affinity), affinity.shape)

--
Gitblit v1.9.1