From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 runtime/python/libtorch/funasr_torch/utils/utils.py |   54 ++++++++++++++++++++++++++++--------------------------
 1 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/runtime/python/libtorch/funasr_torch/utils/utils.py b/runtime/python/libtorch/funasr_torch/utils/utils.py
index 913ddc1..ee43852 100644
--- a/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,40 +1,43 @@
 # -*- encoding: utf-8 -*-
-
-import functools
+import yaml
 import logging
-import pickle
+import functools
+import numpy as np
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
-import numpy as np
-import yaml
-
-
-import warnings
-
 root_dir = Path(__file__).resolve().parent
-
 logger_initialized = {}
 
+def pad_list(xs, pad_value, max_len=None):
+    n_batch = len(xs)
+    if max_len is None:
+        max_len = max(x.size(0) for x in xs)
+    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+    # numpy format
+    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
+    for i in range(n_batch):
+        pad[i, : xs[i].shape[0]] = xs[i]
 
-class TokenIDConverter():
-    def __init__(self, token_list: Union[List, str],
-                 ):
+    return pad
+
+class TokenIDConverter:
+    def __init__(
+        self,
+        token_list: Union[List, str],
+    ):
 
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
         self.token2id = {v: i for i, v in enumerate(self.token_list)}
         self.unk_id = self.token2id[self.unk_symbol]
 
-
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
 
-    def ids2tokens(self,
-                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
         if isinstance(integers, np.ndarray) and integers.ndim != 1:
-            raise TokenIDConverterError(
-                f"Must be 1 dim ndarray, but got {integers.ndim}")
+            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
         return [self.token_list[i] for i in integers]
 
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -42,7 +45,7 @@
         return [self.token2id.get(i, self.unk_id) for i in tokens]
 
 
-class CharTokenizer():
+class CharTokenizer:
     def __init__(
         self,
         symbol_value: Union[Path, str, Iterable[str]] = None,
@@ -77,7 +80,7 @@
                 if line.startswith(w):
                     if not self.remove_non_linguistic_symbols:
                         tokens.append(line[: len(w)])
-                    line = line[len(w):]
+                    line = line[len(w) :]
                     break
             else:
                 t = line[0]
@@ -100,7 +103,6 @@
         )
 
 
-
 class Hypothesis(NamedTuple):
     """Hypothesis data type."""
 
@@ -120,15 +122,15 @@
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():
-        raise FileExistsError(f'The {yaml_path} does not exist.')
+        raise FileExistsError(f"The {yaml_path} does not exist.")
 
-    with open(str(yaml_path), 'rb') as f:
+    with open(str(yaml_path), "rb") as f:
         data = yaml.load(f, Loader=yaml.Loader)
     return data
 
 
 @functools.lru_cache()
-def get_logger(name='funasr_torch'):
+def get_logger(name="funasr_torch"):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
@@ -148,8 +150,8 @@
             return logger
 
     formatter = logging.Formatter(
-        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
-        datefmt="%Y/%m/%d %H:%M:%S")
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+    )
 
     sh = logging.StreamHandler()
     sh.setFormatter(formatter)

--
Gitblit v1.9.1