lzr265946
2023-02-03 1d97d628f2f19674fa50495e984db8185604ca8e
Merge branch 'main' into dev
14个文件已修改
603 ■■■■■ 已修改文件
docs/images/dingding.jpg 补丁 | 查看 | 原始文档 | blame | 历史
docs/images/wechat.png 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_timestamp.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py 107 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_uniasr.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 18 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punctuation_infer.py 328 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/abs_model.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/espnet_model.py 41 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/target_delay_transformer.py 30 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/text_preprocessor.py 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/grpc/Readme.md 38 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/images/dingding.jpg

docs/images/wechat.png

funasr/bin/asr_inference.py
@@ -368,7 +368,7 @@
#         except TooShortUttError as e:
#             logging.warning(f"Utterance {keys} {e}")
#             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
#             results = [[" ", ["<space>"], [2], hyp]] * nbest
#             results = [[" ", ["sil"], [2], hyp]] * nbest
#
#         # Only supporting batch_size==1
#         key = keys[0]
@@ -577,7 +577,7 @@
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest
                results = [[" ", ["sil"], [2], hyp]] * nbest
            
            # Only supporting batch_size==1
            key = keys[0]
funasr/bin/asr_inference_paraformer.py
@@ -227,6 +227,8 @@
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
@@ -394,7 +396,7 @@
#         results = speech2text(**batch)
#         if len(results) < 1:
#             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
#             results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
#             results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
#         time_end = time.time()
#         forward_time = time_end - time_beg
#         lfr_factor = results[0][-1]
@@ -623,7 +625,7 @@
            results = speech2text(**batch)
            if len(results) < 1:
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
            time_end = time.time()
            forward_time = time_end - time_beg
            lfr_factor = results[0][-1]
funasr/bin/asr_inference_paraformer_timestamp.py
@@ -410,7 +410,7 @@
        results = speech2text(**batch)
        if len(results) < 1:
            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
            results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
            results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
        time_end = time.time()
        forward_time = time_end - time_beg
        lfr_factor = results[0][-1]
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -1,9 +1,10 @@
#!/usr/bin/env python3
import json
import argparse
import logging
import sys
import time
import json
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -38,10 +39,10 @@
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.tasks.punctuation import PunctuationTask
from funasr.bin.punctuation_infer import Text2Punc
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
# ANSI terminal escape sequences: '\033[95m' switches the foreground to bright
# magenta and '\033[0m' resets all text attributes.
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -236,6 +237,8 @@
        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
@@ -604,7 +607,7 @@
                    results = speech2text(**batch)
                    if len(results) < 1:
                        hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                        results = [[" ", ["<space>"], [2], 0, 1, 6]] * nbest
                        results = [[" ", ["sil"], [2], 0, 1, 6]] * nbest
                    time_end = time.time()
                    forward_time = time_end - time_beg
                    lfr_factor = results[0][-1]
@@ -678,102 +681,6 @@
        logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
        return asr_result_list
    return _forward
def Text2Punc(
    train_config: Optional[str],
    model_file: Optional[str],
    device: str = "cpu",
    dtype: str = "float32",
):
    """Build a punctuation model and return a closure that punctuates word lists.

    NOTE(review): this function has the same name as the ``Text2Punc`` class
    imported from ``funasr.bin.punctuation_infer`` in this module's imports, so
    once defined it shadows that import — confirm which one callers here expect.

    Args:
        train_config: Training config path passed to
            ``PunctuationTask.build_model_from_file``.
        model_file: Model checkpoint path passed to the same builder.
        device: Torch device the model and input batches are moved to.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``) the
            wrapped model is cast to.

    Returns:
        ``_forward(words, split_size=20)``, which returns
        ``(punctuated_text, punc_ids)`` where ``punc_ids`` holds one index into
        the model's ``punc_list`` per word kept in the output.
    """
    # Build the model.
    model, train_args = PunctuationTask.build_model_from_file(
        train_config, model_file, device)
    # ForwardAdaptor is given the "inference" method name; presumably calls are
    # routed to model.inference — verify against ForwardAdaptor.
    wrapped_model = ForwardAdaptor(model, "inference")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
    punc_list = train_args.punc_list
    period = 0
    for i, mark in enumerate(punc_list):
        # NOTE(review): the two replacements below map each mark to itself and
        # are no-ops as written; they may have been meant to swap half-width
        # marks for full-width ones — confirm against the model's punc_list.
        if mark == ",":
            punc_list[i] = ","
        elif mark == "?":
            punc_list[i] = "?"
        elif mark == "。":
            # Remember the period's index; it is forced onto over-long sentences.
            period = i
    preprocessor = CommonPreprocessor(
        train=False,
        token_type="word",
        token_list=train_args.token_list,
        bpemodel=train_args.bpemodel,
        text_cleaner=train_args.cleaner,
        g2p_type=train_args.g2p,
        text_name="text",
        non_linguistic_symbols=train_args.non_linguistic_symbols,
    )
    print("start decoding!!!")

    def _forward(words, split_size=20):
        """Punctuate ``words`` in chunks of ``split_size`` words.

        Text after the last sentence end ("。" or "?") of every non-final chunk
        is carried over and re-punctuated together with the next chunk.

        Returns:
            ``(punctuated_text, punc_ids)``.
        """
        cache_sent = []
        mini_sentences = split_to_mini_sentence(words, split_size)
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        # Carried-over text longer than this is force-cut at its last comma.
        cache_pop_trigger_limit = 200
        last_index = len(mini_sentences) - 1
        for mini_sentence_i, mini_sentence in enumerate(mini_sentences):
            # Prepend the unfinished tail carried over from the previous chunk.
            mini_sentence = cache_sent + mini_sentence
            data = {"text": " ".join(mini_sentence)}
            batch = preprocessor(data=data, uid="12938712838719")
            batch["text_lengths"] = torch.from_numpy(np.array([len(batch["text"])], dtype='int32'))
            # Add a leading batch dimension of size 1.
            batch["text"] = torch.unsqueeze(torch.from_numpy(batch["text"]), 0)
            batch = to_device(batch, device)
            y, _ = wrapped_model(**batch)
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            # Best punctuation id per word as plain ints; flattening keeps a
            # single-word chunk 1-D instead of leaving it shaped (1, 1).
            punctuations = indices.reshape(-1).tolist()
            assert len(punctuations) == len(mini_sentence)
            if mini_sentence_i < last_index:
                # Search backwards for the last period / question mark; the
                # text after it is cached for the next chunk.
                sentence_end = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    mark = punc_list[punctuations[i]]
                    if mark == "。" or mark == "?":
                        sentence_end = i
                        break
                    if last_comma_index < 0 and mark == ",":
                        last_comma_index = i
                if sentence_end < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long: cut it at the last comma and
                    # turn that comma into a period.
                    sentence_end = last_comma_index
                    punctuations[sentence_end] = period
                cache_sent = mini_sentence[sentence_end + 1:]
                mini_sentence = mini_sentence[:sentence_end + 1]
                punctuations = punctuations[:sentence_end + 1]
            new_mini_sentence_punc += punctuations
            words_with_punc = []
            for i, word in enumerate(mini_sentence):
                # Separate two consecutive words that both start with a
                # single-byte (ASCII) character.
                if i > 0 and len(word[0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                    word = " " + word
                words_with_punc.append(word)
                mark = punc_list[punctuations[i]]
                # "_" is the "no punctuation" label.
                if mark != "_":
                    words_with_punc.append(mark)
            new_mini_sentence += "".join(words_with_punc)
        return new_mini_sentence, new_mini_sentence_punc

    return _forward
def get_parser():
funasr/bin/asr_inference_uniasr.py
@@ -391,7 +391,7 @@
#         except TooShortUttError as e:
#             logging.warning(f"Utterance {keys} {e}")
#             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
#             results = [[" ", ["<space>"], [2], hyp]] * nbest
#             results = [[" ", ["sil"], [2], hyp]] * nbest
#
#         # Only supporting batch_size==1
#         key = keys[0]
@@ -618,7 +618,7 @@
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest
                results = [[" ", ["sil"], [2], hyp]] * nbest
    
            # Only supporting batch_size==1
            key = keys[0]
funasr/bin/punc_inference_launch.py
@@ -59,26 +59,18 @@
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        required=False
    )
    group.add_argument(
        "--raw_inputs",
        type=str,
        required=False
    )
    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
    group.add_argument("--raw_inputs", type=str, required=False)
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--cache", type=list, required=False)
    group.add_argument("--param_dict", type=dict, required=False)
    group = parser.add_argument_group("The model configuration related")
    group.add_argument("--train_config", type=str)
    group.add_argument("--model_file", type=str)
    group.add_argument("--mode", type=str, default="punc")
    return parser
def inference_launch(mode, **kwargs):
    if mode == "punc":
        from funasr.bin.punctuation_infer import inference_modelscope
funasr/bin/punctuation_infer.py
@@ -3,33 +3,141 @@
import logging
from pathlib import Path
import sys
import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
class Text2Punc:
    """Punctuation-restoration inference wrapper around a PunctuationTask model.

    The model and preprocessor are built once in ``__init__``; calling the
    instance punctuates text and returns ``(punctuated_text, punc_ids)``.
    """

    def __init__(
        self,
        train_config: Optional[str],
        model_file: Optional[str],
        device: str = "cpu",
        dtype: str = "float32",
    ):
        """Build the punctuation model and its text preprocessor.

        Args:
            train_config: Training config path passed to
                ``PunctuationTask.build_model_from_file``.
            model_file: Model checkpoint path passed to the same builder.
            device: Torch device the model and input batches are moved to.
            dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``)
                the wrapped model is cast to.
        """
        # Build the model.
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        self.device = device
        # ForwardAdaptor is given the "inference" method name; presumably calls
        # are routed to model.inference — verify against ForwardAdaptor.
        self.wrapped_model = ForwardAdaptor(model, "inference")
        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
        self.punc_list = train_args.punc_list
        self.period = 0
        for i in range(len(self.punc_list)):
            # NOTE(review): the two replacements below map each mark to itself
            # and are no-ops as written; they may have been meant to swap
            # half-width marks for full-width ones — confirm against punc_list.
            if self.punc_list[i] == ",":
                self.punc_list[i] = ","
            elif self.punc_list[i] == "?":
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                # Index of the period; forced onto over-long sentences and
                # appended at the very end of the output.
                self.period = i
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=train_args.token_type,
            token_list=train_args.token_list,
            bpemodel=train_args.bpemodel,
            text_cleaner=train_args.cleaner,
            g2p_type=train_args.g2p,
            text_name="text",
            non_linguistic_symbols=train_args.non_linguistic_symbols,
        )
        print("start decoding!!!")

    @staticmethod
    def _join_with_punc(words, punc_ids, punc_list):
        """Concatenate ``words`` with the marks ``punc_list[punc_ids[i]]``.

        Two consecutive words whose first characters are single-byte (ASCII)
        are separated by a space; the "_" label means "no punctuation".
        """
        pieces = []
        for i, word in enumerate(words):
            if i > 0 and len(word[0].encode()) == 1 and len(words[i - 1][0].encode()) == 1:
                word = " " + word
            pieces.append(word)
            mark = punc_list[punc_ids[i]]
            if mark != "_":
                pieces.append(mark)
        return "".join(pieces)

    @staticmethod
    def _terminate_sentence(text, punc_ids, period):
        """Make sure ``text`` ends with "。" or "?".

        A trailing "," or "、" is replaced by "。"; any other last character
        gets "。" appended. In both cases the last id of ``punc_ids`` becomes
        ``period``. Empty text is returned unchanged.
        """
        if not text:
            return text, punc_ids
        last_char = text[-1]
        if last_char == "," or last_char == "、":
            return text[:-1] + "。", punc_ids[:-1] + [period]
        if last_char != "。" and last_char != "?":
            return text + "。", punc_ids[:-1] + [period]
        return text, punc_ids

    @torch.no_grad()
    def __call__(self, text: Union[list, str], split_size=20):
        """Punctuate ``text``.

        Args:
            text: Input text (string or list), handed to the preprocessor under
                the ``"text"`` key.
            split_size: Number of tokens per mini-sentence fed to the model.

        Returns:
            ``(punctuated_text, punc_ids)`` where ``punc_ids`` holds one index
            into ``self.punc_list`` per output token. Input that produces no
            mini-sentences yields ``("", [])``.
        """
        data = {"text": text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        # NOTE(review): reads data["text"] after preprocessing, which assumes
        # the preprocessor replaced it with token ids in place — confirm.
        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        if not mini_sentences:
            # Nothing to punctuate (previously an UnboundLocalError).
            return "", []

        cache_sent = []
        # Keep the carried-over ids as numpy throughout; they are concatenated
        # with numpy chunks below.
        cache_sent_id = np.array([], dtype='int32')
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        # Carried-over text longer than this is force-cut at its last comma.
        cache_pop_trigger_limit = 200
        last_index = len(mini_sentences) - 1
        for mini_sentence_i, (mini_sentence, mini_sentence_id) in enumerate(
                zip(mini_sentences, mini_sentences_id)):
            # Prepend the unfinished tail carried over from the previous chunk.
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
            batch = {
                # Leading batch dimension of size 1.
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
            }
            batch = to_device(batch, self.device)
            y, _ = self.wrapped_model(**batch)
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            # Best punctuation id per token as plain ints; flattening keeps a
            # single-token chunk 1-D instead of leaving it shaped (1, 1).
            punctuations = indices.reshape(-1).tolist()
            assert len(punctuations) == len(mini_sentence)
            if mini_sentence_i < last_index:
                # Search backwards for the last period / question mark; the
                # text after it is cached and re-punctuated with the next chunk.
                sentence_end = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    mark = self.punc_list[punctuations[i]]
                    if mark == "。" or mark == "?":
                        sentence_end = i
                        break
                    if last_comma_index < 0 and mark == ",":
                        last_comma_index = i
                if sentence_end < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long: cut it at the last comma and
                    # turn that comma into a period.
                    sentence_end = last_comma_index
                    punctuations[sentence_end] = self.period
                cache_sent = mini_sentence[sentence_end + 1:]
                cache_sent_id = mini_sentence_id[sentence_end + 1:]
                mini_sentence = mini_sentence[:sentence_end + 1]
                punctuations = punctuations[:sentence_end + 1]
            new_mini_sentence_punc += punctuations
            new_mini_sentence += self._join_with_punc(mini_sentence, punctuations, self.punc_list)
        # Add a period at the end of the whole text if needed.
        return self._terminate_sentence(new_mini_sentence, new_mini_sentence_punc, self.period)
def inference(
@@ -45,12 +153,12 @@
    key_file: Optional[str] = None,
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
    raw_inputs: Union[List[Any], bytes, str] = None,
    cache: List[Any] = None,
    param_dict: dict = None,
    **kwargs,
):
    inference_pipeline = inference_modelscope(
        output_dir=output_dir,
        raw_inputs=raw_inputs,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
@@ -60,6 +168,7 @@
        key_file=key_file,
        train_config=train_config,
        model_file=model_file,
        param_dict=param_dict,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -76,6 +185,7 @@
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
@@ -91,41 +201,14 @@
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build Model
    model, train_args = PunctuationTask.build_model_from_file(
        train_config, model_file, device)
    # Wrape model to make model.nll() data-parallel
    wrapped_model = ForwardAdaptor(model, "inference")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
    logging.info(f"Model:\n{model}")
    punc_list = train_args.punc_list
    period = 0
    for i in range(len(punc_list)):
        if punc_list[i] == ",":
            punc_list[i] = ","
        elif punc_list[i] == "?":
            punc_list[i] = "?"
        elif punc_list[i] == "。":
            period = i
    preprocessor = CommonPreprocessor(
        train=False,
        token_type="word",
        token_list=train_args.token_list,
        bpemodel=train_args.bpemodel,
        text_cleaner=train_args.cleaner,
        g2p_type=train_args.g2p,
        text_name="text",
        non_linguistic_symbols=train_args.non_linguistic_symbols,
    )
    print("start decoding!!!")
    text2punc = Text2Punc(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
    ):
        results = []
        split_size = 20
@@ -133,77 +216,14 @@
        if raw_inputs != None:
            line = raw_inputs.strip()
            key = "demo"
            if line=="":
            if line == "":
                item = {'key': key, 'value': ""}
                results.append(item)
                return results
            cache_sent = []
            words = split_words(line)
            new_mini_sentence = ""
            new_mini_sentence_punc = ""
            cache_pop_trigger_limit = 200
            mini_sentences = split_to_mini_sentence(words, split_size)
            for mini_sentence_i in range(len(mini_sentences)):
                mini_sentence = mini_sentences[mini_sentence_i]
                mini_sentence = cache_sent + mini_sentence
                data = {"text": " ".join(mini_sentence)}
                batch = preprocessor(data=data, uid="12938712838719")
                batch["text_lengths"] = torch.from_numpy(
                    np.array([len(batch["text"])], dtype='int32'))
                batch["text"] = torch.from_numpy(batch["text"])
                # Extend one dimension to fake a batch dim.
                batch["text"] = torch.unsqueeze(batch["text"], 0)
                batch = to_device(batch, device)
                y, _ = wrapped_model(**batch)
                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
                punctuations = indices
                if indices.size()[0] != 1:
                    punctuations = torch.squeeze(indices)
                assert punctuations.size()[0] == len(mini_sentence)
                # Search for the last Period/QuestionMark as cache
                if mini_sentence_i < len(mini_sentences)-1:
                    sentenceEnd = -1
                    last_comma_index = -1
                    for i in range(len(punctuations)-2,1,-1):
                        if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "?":
                            sentenceEnd = i
                            break
                        if last_comma_index < 0 and punc_list[punctuations[i]] == ",":
                            last_comma_index = i
                    if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                        # The sentence it too long, cut off at a comma.
                        sentenceEnd = last_comma_index
                        punctuations[sentenceEnd] = period
                    cache_sent = mini_sentence[sentenceEnd+1:]
                    mini_sentence = mini_sentence[0:sentenceEnd+1]
                    punctuations = punctuations[0:sentenceEnd+1]
                punctuations_np = punctuations.cpu().numpy()
                new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
                words_with_punc = []
                for i in range(len(mini_sentence)):
                    if i>0:
                        if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
                            mini_sentence[i] = " "+ mini_sentence[i]
                    words_with_punc.append(mini_sentence[i])
                    if punc_list[punctuations[i]] != "_":
                        words_with_punc.append(punc_list[punctuations[i]])
                new_mini_sentence += "".join(words_with_punc)
                # Add Period for the end of the sentence
                new_mini_sentence_out = new_mini_sentence
                new_mini_sentence_punc_out = new_mini_sentence_punc
                if mini_sentence_i == len(mini_sentences)-1:
                    if new_mini_sentence[-1]=="," or new_mini_sentence[-1]=="、":
                        new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                        new_mini_sentence_punc_out  = new_mini_sentence_punc[:-1] + str(period)
                    elif new_mini_sentence[-1]!="。" and new_mini_sentence[-1]!="?":
                        new_mini_sentence_out=new_mini_sentence+"。"
                        new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
                    item = {'key': key, 'value': new_mini_sentence_out}
                    results.append(item)
            result, _ = text2punc(line)
            item = {'key': key, 'value': result}
            results.append(item)
            print(results)
            return results
        for inference_text, _, _ in data_path_and_name_and_type:
@@ -216,72 +236,9 @@
                    key = segs[0]
                    if len(segs[1]) == 0:
                        continue
                    cache_sent = []
                    words = split_words(segs[1])
                    new_mini_sentence = ""
                    new_mini_sentence_punc = ""
                    cache_pop_trigger_limit = 200
                    mini_sentences = split_to_mini_sentence(words, split_size)
                    for mini_sentence_i in range(len(mini_sentences)):
                        mini_sentence = mini_sentences[mini_sentence_i]
                        mini_sentence = cache_sent + mini_sentence
                        data = {"text": " ".join(mini_sentence)}
                        batch = preprocessor(data=data, uid="12938712838719")
                        batch["text_lengths"] = torch.from_numpy(
                            np.array([len(batch["text"])], dtype='int32'))
                        batch["text"] = torch.from_numpy(batch["text"])
                        # Extend one dimension to fake a batch dim.
                        batch["text"] = torch.unsqueeze(batch["text"], 0)
                        batch = to_device(batch, device)
                        y, _ = wrapped_model(**batch)
                        _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
                        punctuations = indices
                        if indices.size()[0] != 1:
                            punctuations = torch.squeeze(indices)
                        assert punctuations.size()[0] == len(mini_sentence)
                        # Search for the last Period/QuestionMark as cache
                        if mini_sentence_i < len(mini_sentences)-1:
                            sentenceEnd = -1
                            last_comma_index = -1
                            for i in range(len(punctuations)-2,1,-1):
                                if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "?":
                                    sentenceEnd = i
                                    break
                                if last_comma_index < 0 and punc_list[punctuations[i]] == ",":
                                    last_comma_index = i
                            if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                                # The sentence it too long, cut off at a comma.
                                sentenceEnd = last_comma_index
                                punctuations[sentenceEnd] = period
                            cache_sent = mini_sentence[sentenceEnd+1:]
                            mini_sentence = mini_sentence[0:sentenceEnd+1]
                            punctuations = punctuations[0:sentenceEnd+1]
                        punctuations_np = punctuations.cpu().numpy()
                        new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
                        words_with_punc = []
                        for i in range(len(mini_sentence)):
                            if i>0:
                                if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
                                    mini_sentence[i] = " "+ mini_sentence[i]
                            words_with_punc.append(mini_sentence[i])
                            if punc_list[punctuations[i]] != "_":
                                words_with_punc.append(punc_list[punctuations[i]])
                        new_mini_sentence += "".join(words_with_punc)
                        # Add Period for the end of the sentence
                        new_mini_sentence_out = new_mini_sentence
                        new_mini_sentence_punc_out = new_mini_sentence_punc
                        if mini_sentence_i == len(mini_sentences)-1:
                            if new_mini_sentence[-1]=="," or new_mini_sentence[-1]=="、":
                                new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                                new_mini_sentence_punc_out  = new_mini_sentence_punc[:-1] + str(period)
                            elif new_mini_sentence[-1]!="。" and new_mini_sentence[-1]!="?":
                                new_mini_sentence_out=new_mini_sentence+"。"
                                new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
                            item = {'key': key, 'value': new_mini_sentence_out}
                            results.append(item)
                    result, _ = text2punc(segs[1])
                    item = {'key': key, 'value': result}
                    results.append(item)
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path != None:
            output_file_name = "infer.out"
@@ -293,6 +250,7 @@
                    value_out = item_i["value"]
                    fout.write(f"{key_out}\t{value_out}\n")
        return results
    return _forward
@@ -338,19 +296,11 @@
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        required=False
    )
    group.add_argument(
        "--raw_inputs",
        type=str,
        required=False
    )
    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
    group.add_argument("--raw_inputs", type=str, required=False)
    group.add_argument("--cache", type=list, required=False)
    group.add_argument("--param_dict", type=dict, required=False)
    group.add_argument("--key_file", type=str_or_none)
    group = parser.add_argument_group("The model configuration related")
    group.add_argument("--train_config", type=str)
@@ -364,11 +314,9 @@
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
   # kwargs.pop("config", None)
    # kwargs.pop("config", None)
    inference(**kwargs)
if __name__ == "__main__":
    main()
funasr/punctuation/abs_model.py
@@ -23,7 +23,5 @@
    """
    @abstractmethod
    def forward(
        self, input: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
funasr/punctuation/espnet_model.py
@@ -13,6 +13,7 @@
class ESPnetPunctuationModel(AbsESPnetModel):
    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
        assert check_argument_types()
        super().__init__()
@@ -43,8 +44,8 @@
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
            punc = punc[:, : text_lengths.max()]
            text = text[:, :text_lengths.max()]
            punc = punc[:, :text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]
@@ -63,9 +64,11 @@
        # 3. Calc negative log likelihood
        # nll: (BxL,)
        if self.training == False:
            _, indices = y.view(-1, y.shape[-1]).topk(1,dim=1)
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            from sklearn.metrics import f1_score
            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(), indices.squeeze(-1).detach().cpu().numpy(), average='micro')
            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
                                indices.squeeze(-1).detach().cpu().numpy(),
                                average='micro')
            nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
            return nll, text_lengths
        else:
@@ -82,14 +85,12 @@
        nll = nll.view(batch_size, -1)
        return nll, text_lengths
    def batchify_nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        batch_size: int = 100
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    def batchify_nll(self,
                     text: torch.Tensor,
                     punc: torch.Tensor,
                     text_lengths: torch.Tensor,
                     punc_lengths: torch.Tensor,
                     batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll) from transformer language model
        To avoid OOM, this fuction seperate the input into batches.
@@ -117,9 +118,7 @@
                batch_punc = punc[start_idx:end_idx, :]
                batch_text_lengths = text_lengths[start_idx:end_idx]
                # batch_nll: [B * T]
                batch_nll, batch_x_lengths = self.nll(
                    batch_text, batch_punc, batch_text_lengths, max_length=max_length
                )
                batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
                nlls.append(batch_nll)
                x_lengths.append(batch_x_lengths)
                start_idx = end_idx
@@ -131,21 +130,19 @@
        assert x_lengths.size(0) == total_num
        return nll, x_lengths
    def forward(
        self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor, punc_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
                punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Compute the per-token averaged loss for one batch.

        Args:
            text: Input token ids, forwarded to ``self.nll``.
            punc: Target punctuation ids, forwarded to ``self.nll``.
            text_lengths: Per-utterance text lengths, forwarded to ``self.nll``.
            punc_lengths: Forwarded positionally as the 4th argument of
                ``self.nll``. NOTE(review): confirm this matches the 4th
                parameter of ``nll``'s signature.

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``:
            ``loss`` is ``nll.sum()`` divided by the total of the returned
            lengths, ``stats`` is ``{"loss": loss.detach()}``, and
            ``weight`` is that token count.
        """
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
        # Total number of tokens; used both to normalise the loss and as the
        # weight reported for gathering across replicas.
        ntokens = y_lengths.sum()
        # When not training, ``nll`` returns the micro-F1 score repeated once
        # per token, so this quotient then equals that F1 value rather than
        # a true negative log likelihood.
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    def collect_feats(
        self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
    def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return an empty feature dict.

        All arguments are ignored; this model exposes no extra features
        through this hook.
        """
        return {}
    def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
funasr/punctuation/target_delay_transformer.py
@@ -14,6 +14,7 @@
class TargetDelayTransformer(AbsPunctuation):
    def __init__(
        self,
        vocab_size: int,
@@ -28,7 +29,7 @@
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
#            pos_enc_class = PositionalEncoding
            #            pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:
@@ -47,16 +48,16 @@
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
           # pos_enc_class=pos_enc_class,
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
        )
        self.decoder = nn.Linear(att_unit, punc_size)
#    def _target_mask(self, ys_in_pad):
#        ys_mask = ys_in_pad != 0
#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
#        return ys_mask.unsqueeze(-2) & m
    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute loss value from buffer sequences.
@@ -67,14 +68,12 @@
        """
        x = self.embed(input)
       # mask = self._target_mask(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None
    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.
        Args:
@@ -89,16 +88,12 @@
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
        Args:
@@ -120,15 +115,10 @@
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)
funasr/punctuation/text_preprocessor.py
@@ -1,24 +1,3 @@
def split_words(text: str) -> list:
    """Split text into word tokens, keeping ASCII runs together.

    The text is first split on whitespace. Within each whitespace-free
    segment, consecutive ASCII characters are grouped into a single word,
    while every non-ASCII character (e.g. a CJK character, but also accented
    Latin letters, symbols or emoji) becomes a word of its own.

    Args:
        text: Input text, possibly mixing ASCII words and CJK characters.

    Returns:
        The list of word strings in their original order. It is empty for
        empty or whitespace-only input.
    """
    words = []
    for seg in text.split():
        # There is no whitespace inside seg.
        current_word = ""
        for c in seg:
            # str.isascii() is equivalent to the old UTF-8 single-byte test
            # (len(c.encode()) == 1). It avoids allocating a bytes object per
            # character and does not raise on lone surrogates.
            if c.isascii():
                # ASCII characters accumulate into a multi-character word.
                current_word += c
            else:
                # Any non-ASCII character is emitted as its own word. Flush
                # the pending ASCII run first so the original order is kept.
                if current_word:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if current_word:
            words.append(current_word)
    return words
def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
funasr/runtime/python/grpc/Readme.md
@@ -5,30 +5,52 @@
## Steps
Step 1) Prepare server environment (on server).
Step 1) Prepare server environment (on server).
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Install modelscope and funasr with pip or with cuda-docker image.
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Option 1: Install modelscope and funasr with [pip](https://github.com/alibaba-damo-academy/FunASR#installation)
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Option 2: Install with the modelscope CUDA docker image:
```
# Optional, modelscope cuda docker is preferred.
CID=`docker run --network host -d -it --gpus '"device=0"' registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.2.0`
echo $CID
docker exec -it $CID /bin/bash
cd /opt/conda/lib/python3.7/site-packages/funasr/runtime/python/grpc
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Get the FunASR source code and enter the grpc directory.
```
git clone https://github.com/alibaba-damo-academy/FunASR
cd FunASR/funasr/runtime/python/grpc/
```
Step 2) Generate protobuf file (for server and client).
Step 2) (Optional) Generate the protobuf files (run on the server; the two generated pb files are used by both the server and the client).
```
# Optional, paraformer_pb2.py and paraformer_pb2_grpc.py are already generated.
# Optional, Install dependency.
python -m pip install grpcio grpcio-tools
```
```
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
python -m grpc_tools.protoc  --proto_path=./proto -I ./proto    --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
Step 3) Start grpc server (on server).
```
# Optional, Install dependency.
python -m pip install grpcio grpcio-tools
```
```
# Start server.
python grpc_main_server.py --port 10095
```
Step 4) Start grpc client (on client with microphone).
```
# Install dependency. Optional.
python -m pip install pyaudio webrtcvad
# Optional, Install dependency.
python -m pip install pyaudio webrtcvad grpcio grpcio-tools
```
```
# Start client.
@@ -41,7 +63,7 @@
## Reference
We borrow or refer to some code from:
We borrow from or refer to code in the following projects:
1)https://github.com/wenet-e2e/wenet/tree/main/runtime/core/grpc