From 668b830cb2a8f69c1cfb131ec9542d27f91b7283 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 10 一月 2024 19:10:26 +0800
Subject: [PATCH] update cam++ for embed extract

---
 funasr/bin/inference.py                                       |    6 
 funasr/models/paraformer/model.py                             |    2 
 funasr/models/campplus/template.yaml                          |   23 +
 funasr/models/campplus/__init__.py                            |    1 
 funasr/models/campplus/model.py                               |   88 +++---
 examples/industrial_data_pretraining/spk_verification/demo.py |   11 
 funasr/models/campplus/components.py                          |  112 +++++--
 funasr/models/campplus/utils.py                               |  533 +++++++++++++++++++++++++++++++++++++++++
 8 files changed, 689 insertions(+), 87 deletions(-)

diff --git a/examples/industrial_data_pretraining/spk_verification/demo.py b/examples/industrial_data_pretraining/spk_verification/demo.py
new file mode 100644
index 0000000..0b5588f
--- /dev/null
+++ b/examples/industrial_data_pretraining/spk_verification/demo.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common")
+
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+print(res)
\ No newline at end of file
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index c4ff69b..2d94e70 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -159,6 +159,9 @@
 			tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
 			kwargs["tokenizer"] = tokenizer
 			kwargs["token_list"] = tokenizer.token_list
+			vocab_size = len(tokenizer.token_list)
+		else:
+			vocab_size = -1
 		
 		# build frontend
 		frontend = kwargs.get("frontend", None)
@@ -170,8 +173,7 @@
 		
 		# build model
 		model_class = tables.model_classes.get(kwargs["model"].lower())
-		model = model_class(**kwargs, **kwargs["model_conf"],
-		                    vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
+		model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
 		model.eval()
 		model.to(device)
 		
diff --git a/funasr/models/campplus/__init__.py b/funasr/models/campplus/__init__.py
index ff44fed..e69de29 100644
--- a/funasr/models/campplus/__init__.py
+++ b/funasr/models/campplus/__init__.py
@@ -1 +0,0 @@
-from .campplus import CAMPPlus
diff --git a/funasr/models/campplus/layers.py b/funasr/models/campplus/components.py
similarity index 86%
rename from funasr/models/campplus/layers.py
rename to funasr/models/campplus/components.py
index 0475612..43d366e 100644
--- a/funasr/models/campplus/layers.py
+++ b/funasr/models/campplus/components.py
@@ -7,6 +7,82 @@
 from torch import nn
 
 
+class BasicResBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicResBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes,
+                               planes,
+                               kernel_size=3,
+                               stride=(stride, 1),
+                               padding=1,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes,
+                               planes,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1,
+                               bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes,
+                          self.expansion * planes,
+                          kernel_size=1,
+                          stride=(stride, 1),
+                          bias=False),
+                nn.BatchNorm2d(self.expansion * planes))
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class FCM(nn.Module):
+    def __init__(self,
+                 block=BasicResBlock,
+                 num_blocks=[2, 2],
+                 m_channels=32,
+                 feat_dim=80):
+        super(FCM, self).__init__()
+        self.in_planes = m_channels
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
+        self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
+
+        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(m_channels)
+        self.out_channels = m_channels * (feat_dim // 8)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.unsqueeze(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = F.relu(self.bn2(self.conv2(out)))
+
+        shape = out.shape
+        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
+        return out
+
+
 def get_nonlinear(config_str, channels):
     nonlinear = nn.Sequential()
     for name in config_str.split('-'):
@@ -216,39 +292,3 @@
         return x
 
 
-class BasicResBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, in_planes, planes, stride=1):
-        super(BasicResBlock, self).__init__()
-        self.conv1 = nn.Conv2d(in_planes,
-                               planes,
-                               kernel_size=3,
-                               stride=(stride, 1),
-                               padding=1,
-                               bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes,
-                               planes,
-                               kernel_size=3,
-                               stride=1,
-                               padding=1,
-                               bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-
-        self.shortcut = nn.Sequential()
-        if stride != 1 or in_planes != self.expansion * planes:
-            self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=(stride, 1),
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
-
-    def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out += self.shortcut(x)
-        out = F.relu(out)
-        return out
diff --git a/funasr/models/campplus/campplus.py b/funasr/models/campplus/model.py
similarity index 64%
rename from funasr/models/campplus/campplus.py
rename to funasr/models/campplus/model.py
index 88113ec..84938cc 100644
--- a/funasr/models/campplus/campplus.py
+++ b/funasr/models/campplus/model.py
@@ -1,54 +1,24 @@
 # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
+import os
+import time
+import torch
+import logging
+import numpy as np
+import torch.nn as nn
 from collections import OrderedDict
+from typing import Union, Dict, List, Tuple, Optional
 
-import torch.nn.functional as F
-from torch import nn
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
+    BasicResBlock, get_nonlinear, FCM
+from funasr.models.campplus.utils import extract_feature
 
 
-from funasr.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
-    BasicResBlock, get_nonlinear
-
-
-class FCM(nn.Module):
-    def __init__(self,
-                 block=BasicResBlock,
-                 num_blocks=[2, 2],
-                 m_channels=32,
-                 feat_dim=80):
-        super(FCM, self).__init__()
-        self.in_planes = m_channels
-        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(m_channels)
-
-        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-        self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-
-        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(m_channels)
-        self.out_channels = m_channels * (feat_dim // 8)
-
-    def _make_layer(self, block, planes, num_blocks, stride):
-        strides = [stride] + [1] * (num_blocks - 1)
-        layers = []
-        for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = x.unsqueeze(1)
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = F.relu(self.bn2(self.conv2(out)))
-
-        shape = out.shape
-        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
-        return out
-
-
+@tables.register("model_classes", "CAMPPlus")
 class CAMPPlus(nn.Module):
     def __init__(self,
                  feat_dim=80,
@@ -58,8 +28,9 @@
                  init_channels=128,
                  config_str='batchnorm-relu',
                  memory_efficient=True,
-                 output_level='segment'):
-        super(CAMPPlus, self).__init__()
+                 output_level='segment',
+                 **kwargs,):
+        super().__init__()
 
         self.head = FCM(feat_dim=feat_dim)
         channels = self.head.out_channels
@@ -123,3 +94,28 @@
         if self.output_level == 'frame':
             x = x.transpose(1, 2)
         return x
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list=None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        # extract fbank feats
+        meta_data = {}
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_feature(audio_sample_list)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
+        # import pdb; pdb.set_trace()
+        results = []
+        embeddings = self.forward(speech)
+        for embedding in embeddings:
+            results.append({"spk_embedding":embedding})
+        return results, meta_data
\ No newline at end of file
diff --git a/funasr/models/campplus/template.yaml b/funasr/models/campplus/template.yaml
new file mode 100644
index 0000000..38dcfde
--- /dev/null
+++ b/funasr/models/campplus/template.yaml
@@ -0,0 +1,23 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: CAMPPlus
+model_conf:
+    feat_dim: 80
+    embedding_size: 192
+    growth_rate: 32
+    bn_size: 4
+    init_channels: 128
+    config_str: 'batchnorm-relu'
+    memory_efficient: True
+    output_level: 'segment'
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py
new file mode 100644
index 0000000..c86a9f0
--- /dev/null
+++ b/funasr/models/campplus/utils.py
@@ -0,0 +1,533 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import io
+from typing import Union
+
+import librosa as sf
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchaudio.compliance.kaldi as Kaldi
+from torch import nn
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+
+
+def check_audio_list(audio: list):
+    audio_dur = 0
+    for i in range(len(audio)):
+        seg = audio[i]
+        assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
+        assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
+        assert int(seg[1] * 16000) - int(
+            seg[0] * 16000
+        ) == seg[2].shape[
+            0], 'modelscope error: audio data in list is inconsistent with time length.'
+        if i > 0:
+            assert seg[0] >= audio[
+                i - 1][1], 'modelscope error: Wrong time stamps.'
+        audio_dur += seg[1] - seg[0]
+    return audio_dur
+    # assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
+
+
+def sv_preprocess(inputs: Union[np.ndarray, list]):
+	output = []
+	for i in range(len(inputs)):
+		if isinstance(inputs[i], str):
+			file_bytes = File.read(inputs[i])
+			data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
+			if len(data.shape) == 2:
+				data = data[:, 0]
+			data = torch.from_numpy(data).unsqueeze(0)
+			data = data.squeeze(0)
+		elif isinstance(inputs[i], np.ndarray):
+			assert len(
+				inputs[i].shape
+			) == 1, 'modelscope error: Input array should be [N, T]'
+			data = inputs[i]
+			if data.dtype in ['int16', 'int32', 'int64']:
+				data = (data / (1 << 15)).astype('float32')
+			else:
+				data = data.astype('float32')
+			data = torch.from_numpy(data)
+		else:
+			raise ValueError(
+				'modelscope error: The input type is restricted to audio address and nump array.'
+			)
+		output.append(data)
+	return output
+
+
+def sv_chunk(vad_segments: list, fs = 16000) -> list:
+    config = {
+            'seg_dur': 1.5,
+            'seg_shift': 0.75,
+        }
+    def seg_chunk(seg_data):
+        seg_st = seg_data[0]
+        data = seg_data[2]
+        chunk_len = int(config['seg_dur'] * fs)
+        chunk_shift = int(config['seg_shift'] * fs)
+        last_chunk_ed = 0
+        seg_res = []
+        for chunk_st in range(0, data.shape[0], chunk_shift):
+            chunk_ed = min(chunk_st + chunk_len, data.shape[0])
+            if chunk_ed <= last_chunk_ed:
+                break
+            last_chunk_ed = chunk_ed
+            chunk_st = max(0, chunk_ed - chunk_len)
+            chunk_data = data[chunk_st:chunk_ed]
+            if chunk_data.shape[0] < chunk_len:
+                chunk_data = np.pad(chunk_data,
+                                    (0, chunk_len - chunk_data.shape[0]),
+                                    'constant')
+            seg_res.append([
+                chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
+                chunk_data
+            ])
+        return seg_res
+
+    segs = []
+    for i, s in enumerate(vad_segments):
+        segs.extend(seg_chunk(s))
+
+    return segs
+
+
+def extract_feature(audio):
+    features = []
+    feature_lengths = []
+    for au in audio:
+        feature = Kaldi.fbank(
+            au.unsqueeze(0), num_mel_bins=80)
+        feature = feature - feature.mean(dim=0, keepdim=True)
+        features.append(feature.unsqueeze(0))
+        feature_lengths.append(au.shape[0])
+    features = torch.cat(features)
+    return features, feature_lengths
+
+
+def postprocess(segments: list, vad_segments: list,
+                labels: np.ndarray, embeddings: np.ndarray) -> list:
+    assert len(segments) == len(labels)
+    labels = correct_labels(labels)
+    distribute_res = []
+    for i in range(len(segments)):
+        distribute_res.append([segments[i][0], segments[i][1], labels[i]])
+    # merge the same speakers chronologically
+    distribute_res = merge_seque(distribute_res)
+
+    # accquire speaker center
+    spk_embs = []
+    for i in range(labels.max() + 1):
+        spk_emb = embeddings[labels == i].mean(0)
+        spk_embs.append(spk_emb)
+    spk_embs = np.stack(spk_embs)
+
+    def is_overlapped(t1, t2):
+        if t1 > t2 + 1e-4:
+            return True
+        return False
+
+    # distribute the overlap region
+    for i in range(1, len(distribute_res)):
+        if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
+            p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
+            distribute_res[i][0] = p
+            distribute_res[i - 1][1] = p
+
+    # smooth the result
+    distribute_res = smooth(distribute_res)
+
+    return distribute_res
+
+
+def correct_labels(labels):
+    labels_id = 0
+    id2id = {}
+    new_labels = []
+    for i in labels:
+        if i not in id2id:
+            id2id[i] = labels_id
+            labels_id += 1
+        new_labels.append(id2id[i])
+    return np.array(new_labels)
+
+def merge_seque(distribute_res):
+    res = [distribute_res[0]]
+    for i in range(1, len(distribute_res)):
+        if distribute_res[i][2] != res[-1][2] or distribute_res[i][
+                0] > res[-1][1]:
+            res.append(distribute_res[i])
+        else:
+            res[-1][1] = distribute_res[i][1]
+    return res
+
+def smooth(res, mindur=1):
+    # short segments are assigned to nearest speakers.
+    for i in range(len(res)):
+        res[i][0] = round(res[i][0], 2)
+        res[i][1] = round(res[i][1], 2)
+        if res[i][1] - res[i][0] < mindur:
+            if i == 0:
+                res[i][2] = res[i + 1][2]
+            elif i == len(res) - 1:
+                res[i][2] = res[i - 1][2]
+            elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
+                res[i][2] = res[i - 1][2]
+            else:
+                res[i][2] = res[i + 1][2]
+    # merge the speakers
+    res = merge_seque(res)
+
+    return res
+
+
+def distribute_spk(sentence_list, sd_time_list):
+    sd_sentence_list = []
+    for d in sentence_list:
+        sentence_start = d['ts_list'][0][0]
+        sentence_end = d['ts_list'][-1][1]
+        sentence_spk = 0
+        max_overlap = 0
+        for sd_time in sd_time_list:
+            spk_st, spk_ed, spk = sd_time
+            spk_st = spk_st*1000
+            spk_ed = spk_ed*1000
+            overlap = max(
+                min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
+            if overlap > max_overlap:
+                max_overlap = overlap
+                sentence_spk = spk
+        d['spk'] = sentence_spk
+        sd_sentence_list.append(d)
+    return sd_sentence_list
+
+
+
+
+class Storage(metaclass=ABCMeta):
+    """Abstract class of storage.
+
+    All backends need to implement two apis: ``read()`` and ``read_text()``.
+    ``read()`` reads the file as a byte stream and ``read_text()`` reads
+    the file as texts.
+    """
+
+    @abstractmethod
+    def read(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def read_text(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        pass
+
+    @abstractmethod
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        pass
+
+
+class LocalStorage(Storage):
+    """Local hard disk storage"""
+
+    def read(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, 'rb') as f:
+            content = f.read()
+        return content
+
+    def read_text(self,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        with open(filepath, 'r', encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'w', encoding=encoding) as f:
+            f.write(obj)
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        yield filepath
+
+
+class HTTPStorage(Storage):
+    """HTTP and HTTPS storage."""
+
+    def read(self, url):
+        # TODO @wenmeng.zwm add progress bar if file is too large
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.content
+
+    def read_text(self, url):
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.text
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exists from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = HTTPStorage()
+            >>> # After existing from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.get_local_path('http://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, url: Union[str, Path]) -> None:
+        raise NotImplementedError('write is not supported by HTTP Storage')
+
+    def write_text(self,
+                   obj: str,
+                   url: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'write_text is not supported by HTTP Storage')
+
+
+class OSSStorage(Storage):
+    """OSS storage."""
+
+    def __init__(self, oss_config_file=None):
+        # read from config file or env var
+        raise NotImplementedError(
+            'OSSStorage.__init__ to be implemented in the future')
+
+    def read(self, filepath):
+        raise NotImplementedError(
+            'OSSStorage.read to be implemented in the future')
+
+    def read_text(self, filepath, encoding='utf-8'):
+        raise NotImplementedError(
+            'OSSStorage.read_text to be implemented in the future')
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exists from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = OSSStorage()
+            >>> # After existing from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.get_local_path('http://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        raise NotImplementedError(
+            'OSSStorage.write to be implemented in the future')
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'OSSStorage.write_text to be implemented in the future')
+
+
+G_STORAGES = {}
+
+
+class File(object):
+    _prefix_to_storage: dict = {
+        'oss': OSSStorage,
+        'http': HTTPStorage,
+        'https': HTTPStorage,
+        'local': LocalStorage,
+    }
+
+    @staticmethod
+    def _get_storage(uri):
+        assert isinstance(uri,
+                          str), f'uri should be str type, but got {type(uri)}'
+
+        if '://' not in uri:
+            # local path
+            storage_type = 'local'
+        else:
+            prefix, _ = uri.split('://')
+            storage_type = prefix
+
+        assert storage_type in File._prefix_to_storage, \
+            f'Unsupported uri {uri}, valid prefixs: '\
+            f'{list(File._prefix_to_storage.keys())}'
+
+        if storage_type not in G_STORAGES:
+            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+        return G_STORAGES[storage_type]
+
+    @staticmethod
+    def read(uri: str) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        storage = File._get_storage(uri)
+        return storage.read(uri)
+
+    @staticmethod
+    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        storage = File._get_storage(uri)
+        return storage.read_text(uri)
+
+    @staticmethod
+    def write(obj: bytes, uri: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        storage = File._get_storage(uri)
+        return storage.write(obj, uri)
+
+    @staticmethod
+    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        storage = File._get_storage(uri)
+        return storage.write_text(obj, uri)
+
+    @contextlib.contextmanager
+    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        storage = File._get_storage(uri)
+        with storage.as_local_path(uri) as local_path:
+            yield local_path
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 9ee4dfc..78a72ec 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -447,7 +447,6 @@
              frontend=None,
              **kwargs,
              ):
-		
 		# init beamsearch
 		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
 		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
@@ -475,7 +474,6 @@
 			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
 			
 		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
 		# Encoder
 		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 		if isinstance(encoder_out, tuple):

--
Gitblit v1.9.1