From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky
---
funasr/bin/sond_inference.py | 40 ++++++++++++++++++++++++++++++++--------
1 files changed, 32 insertions(+), 8 deletions(-)
diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
index 4767577..5a0a8e2 100755
--- a/funasr/bin/sond_inference.py
+++ b/funasr/bin/sond_inference.py
@@ -42,7 +42,7 @@
Examples:
>>> import soundfile
>>> import numpy as np
- >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
+ >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2diar(audio, profile)
@@ -54,7 +54,7 @@
self,
diar_train_config: Union[Path, str] = None,
diar_model_file: Union[Path, str] = None,
- device: str = "cpu",
+ device: Union[str, torch.device] = "cpu",
batch_size: int = 1,
dtype: str = "float32",
streaming: bool = False,
@@ -114,9 +114,19 @@
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
- return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+ # process oov
+ seq = np.array([int(x) for x in seq])
+ new_seq = []
+ for i, x in enumerate(seq):
+ if x < 2 ** vec_dim:
+ new_seq.append(x)
+ else:
+ idx_list = np.where(seq < 2 ** vec_dim)[0]
+ idx = np.abs(idx_list - i).argmin()
+ new_seq.append(seq[idx_list[idx]])
+ return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
- def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
+ def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
# upsampling outputs to match inputs
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
@@ -127,8 +137,14 @@
).squeeze(1).long()
logits_idx = logits_idx[0].tolist()
pse_labels = [self.token_list[x] for x in logits_idx]
+ if output_format == "pse_labels":
+ return pse_labels, None
+
multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
multi_labels = self.smooth_multi_labels(multi_labels)
+ if output_format == "binary_labels":
+ return multi_labels, None
+
spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
spk_turns = self.calc_spk_turns(multi_labels, spk_list)
results = OrderedDict()
@@ -149,6 +165,7 @@
self,
speech: Union[torch.Tensor, np.ndarray],
profile: Union[torch.Tensor, np.ndarray],
+ output_format: str = "speaker_turn"
):
"""Inference
@@ -178,7 +195,7 @@
batch = to_device(batch, device=self.device)
logits = self.diar_model.prediction_forward(**batch)
- results, pse_labels = self.post_processing(logits, profile.shape[1])
+ results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
return results, pse_labels
@@ -231,6 +248,7 @@
dur_threshold: int = 10,
out_format: str = "vad",
param_dict: Optional[dict] = None,
+ mode: str = "sond",
**kwargs,
):
assert check_argument_types()
@@ -254,7 +272,7 @@
set_all_random_seed(seed)
# 2a. Build speech2xvec [Optional]
- if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+ if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
sv_train_config = param_dict["sv_train_config"]
@@ -312,13 +330,16 @@
def _forward(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
+ raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
output_dir_v2: Optional[str] = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, (list, tuple)):
+ if not isinstance(raw_inputs[0], List):
+ raw_inputs = [raw_inputs]
+
assert all([len(example) >= 2 for example in raw_inputs]), \
"The length of test case in raw_inputs must larger than 1 (>=2)."
@@ -363,7 +384,7 @@
pse_label_writer = open("{}/labels.txt".format(output_path), "w")
logging.info("Start to diarize...")
result_list = []
- for keys, batch in loader:
+ for idx, (keys, batch) in enumerate(loader):
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
@@ -381,6 +402,9 @@
pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
pse_label_writer.flush()
+ if idx % 100 == 0:
+ logging.info("Processing {:5d}: {}".format(idx, key))
+
if output_path is not None:
output_writer.close()
pse_label_writer.close()
--
Gitblit v1.9.1