From 26b81480a88cc2868639c5160989394199acdcdd Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 15 三月 2023 11:35:18 +0800
Subject: [PATCH] update
---
tests/test_asr_inference_pipeline.py | 2 +-
funasr/models/e2e_diar_eend_ola.py | 16 ++++++++--------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 6835a64..f3e34bc 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -52,15 +52,15 @@
super().__init__()
self.frontend = frontend
- self.encoder = encoder
- self.encoder_decoder_attractor = encoder_decoder_attractor
+ self.enc = encoder
+ self.eda = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
self.mapping_dict = mapping_dict
# PostNet
- self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+ self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
def forward_encoder(self, xs, ilens):
@@ -68,7 +68,7 @@
pad_shape = xs.shape
xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
- emb = self.encoder(xs, xs_mask)
+ emb = self.enc(xs, xs_mask)
emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
return emb
@@ -77,7 +77,7 @@
maxlen = torch.max(ilens).to(torch.int).item()
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
- outputs, (_, _) = self.PostNet(logits)
+ outputs, (_, _) = self.postnet(logits)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
outputs = [self.output_layer(output) for output in outputs]
@@ -112,7 +112,7 @@
text = text[:, : text_lengths.max()]
# 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@@ -198,10 +198,10 @@
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
- attractors, probs = self.encoder_decoder_attractor.estimate(
+ attractors, probs = self.eda.estimate(
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
- attractors, probs = self.encoder_decoder_attractor.estimate(emb)
+ attractors, probs = self.eda.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
if n_speakers and n_speakers >= 0:
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index 70dbe89..32b8af5 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -451,7 +451,7 @@
def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
inference_pipeline = pipeline(
- task=Tasks.,
+ task=Tasks.auto_speech_recognition,
model='damo/speech_UniASauto_speech_recognitionR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
--
Gitblit v1.9.1