From f33ebfd1c70859f38eaac22673ab0ee9682ea7c3 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 15 三月 2023 16:11:44 +0800
Subject: [PATCH] update
---
funasr/modules/eend_ola/encoder_decoder_attractor.py | 11 ++++-------
funasr/models/e2e_diar_eend_ola.py | 14 ++++++++++++--
2 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 79cb614..097b23a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -76,7 +76,7 @@
def forward_post_net(self, logits, ilens):
maxlen = torch.max(ilens).to(torch.int).item()
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
- logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
+ logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
outputs, (_, _) = self.postnet(logits)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -231,7 +231,7 @@
pred[i] = pred[i - 1]
else:
pred[i] = 0
- pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
+ pred = [self.inv_mapping_func(i) for i in pred]
decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
decisions = torch.from_numpy(
np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
@@ -239,5 +239,15 @@
decisions = decisions[:, :n_speaker]
return decisions
+ def inv_mapping_func(self, label):
+
+ if not isinstance(label, int):
+ label = int(label)
+ if label in self.mapping_dict['label2dec'].keys():
+ num = self.mapping_dict['label2dec'][label]
+ else:
+ num = -1
+ return num
+
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
pass
\ No newline at end of file
diff --git a/funasr/modules/eend_ola/encoder_decoder_attractor.py b/funasr/modules/eend_ola/encoder_decoder_attractor.py
index 4e599ab..45ac982 100644
--- a/funasr/modules/eend_ola/encoder_decoder_attractor.py
+++ b/funasr/modules/eend_ola/encoder_decoder_attractor.py
@@ -2,8 +2,7 @@
import torch
import torch.nn.functional as F
from torch import nn
-from modelscope.utils.logger import get_logger
-logger = get_logger()
+
class EncoderDecoderAttractor(nn.Module):
@@ -17,14 +16,12 @@
self.n_units = n_units
def forward_core(self, xs, zeros):
- logger.info("xs: ".format(xs))
- ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.float32).to(xs[0].device)
- logger.info("ilens: ".format(ilens))
+ ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.int64)
xs = [self.enc0_dropout(x) for x in xs]
xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
xs = nn.utils.rnn.pack_padded_sequence(xs, ilens, batch_first=True, enforce_sorted=False)
_, (hx, cx) = self.encoder(xs)
- zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.float32).to(zeros[0].device)
+ zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.int64)
max_zlen = torch.max(zlens).to(torch.int).item()
zeros = [self.enc0_dropout(z) for z in zeros]
zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
@@ -50,4 +47,4 @@
zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
attractors = self.forward_core(xs, zeros)
probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
- return attractors, probs
\ No newline at end of file
+ return attractors, probs
--
Gitblit v1.9.1