From fdafd3f6bc2f04d16e7cab5afcdb1257e87a8a78 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 17 十二月 2024 11:15:53 +0800
Subject: [PATCH] emotion2vec
---
funasr/models/paraformer/search.py | 48 +++++++++++++++++++++++-------------------------
1 files changed, 23 insertions(+), 25 deletions(-)
diff --git a/funasr/models/paraformer/search.py b/funasr/models/paraformer/search.py
index 8789025..16e13dd 100644
--- a/funasr/models/paraformer/search.py
+++ b/funasr/models/paraformer/search.py
@@ -1,17 +1,19 @@
-from itertools import chain
-import logging
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import NamedTuple
-from typing import Tuple
-from typing import Union
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import torch
+import logging
+from itertools import chain
+from typing import Any, Dict, List, NamedTuple, Tuple, Union
-from funasr.metrics import end_detect
-from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
-from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
+from funasr.metrics.common import end_detect
+from funasr.models.transformer.scorers.scorer_interface import (
+ PartialScorerInterface,
+ ScorerInterface,
+)
+
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
@@ -318,9 +320,7 @@
Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
- scores=self.merge_scores(
- hyp.scores, scores, j, part_scores, part_j
- ),
+ scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
states=self.merge_states(states, part_states, part_j),
)
)
@@ -332,7 +332,11 @@
return best_hyps
def forward(
- self, x: torch.Tensor, am_scores: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+ self,
+ x: torch.Tensor,
+ am_scores: torch.Tensor,
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
) -> List[Hypothesis]:
"""Perform beam search.
@@ -376,8 +380,7 @@
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
- "there is no N-best results, perform recognition "
- "again with smaller minlenratio."
+ "there is no N-best results, perform recognition " "again with smaller minlenratio."
)
return (
[]
@@ -388,17 +391,13 @@
# report the best result
best = nbest_hyps[0]
for k, v in best.scores.items():
- logging.info(
- f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
- )
+ logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
- "best hypo: "
- + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]])
- + "\n"
+ "best hypo: " + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) + "\n"
)
return nbest_hyps
@@ -433,8 +432,7 @@
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [
- h._replace(yseq=self.append_token(h.yseq, self.eos))
- for h in running_hyps
+ h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
]
# add ended hypotheses to a final list, and removed them from current hypotheses
--
Gitblit v1.9.1