From 65d1005fd2cd5566ac819aa6f41e43dff9c8a691 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Mon, 27 Feb 2023 14:46:01 +0800
Subject: [PATCH] Fix bug for sd and sv
---
funasr/models/e2e_diar_sond.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index f55bbf6..ad54723 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -90,6 +90,7 @@
self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
+ self.forward_steps = 0
def generate_pse_embedding(self):
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
@@ -123,7 +124,7 @@
"""
assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
batch_size = speech.shape[0]
-
+ self.forward_steps = self.forward_steps + 1
# 1. Network forward
pred, inter_outputs = self.prediction_forward(
speech, speech_lengths,
@@ -198,6 +199,7 @@
cf=cf,
acc=acc,
der=der,
+ forward_steps=self.forward_steps,
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
@@ -262,8 +264,10 @@
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
- spk_labels: torch.Tensor = None,
- spk_labels_lengths: torch.Tensor = None,
+ profile: torch.Tensor = None,
+ profile_lengths: torch.Tensor = None,
+ binary_labels: torch.Tensor = None,
+ binary_labels_lengths: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
--
Gitblit v1.9.1