zhifu gao
2024-04-26 e971e000ad582c767ae44c9650470899f5bb46d0
Dev gzf exp (#1663)

* rwkv 5

* rwkv v4

* rwkv v4

* rwkv

* rwkv

* update

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step
3 files changed, 41 lines modified:
funasr/datasets/sense_voice_datasets/datasets.py (32 lines)
funasr/models/conformer_rwkv/decoder.py (8 lines)
funasr/models/sense_voice/model.py (1 line)
funasr/datasets/sense_voice_datasets/datasets.py
@@ -1,3 +1,5 @@
import logging
import torch
import random
@@ -46,6 +48,8 @@
        self.float_pad_value = float_pad_value
        self.sos = kwargs.get("sos", "<|startoftranscript|>")
        self.eos = kwargs.get("eos", "<|endoftext|>")
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
    def get_source_len(self, index):
        item = self.index_ds[index]
@@ -124,4 +128,32 @@
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        if self.batch_type != "example":
            b, t, _ = outputs["speech"].shape
            if b * t > self.batch_size:
                beg = torch.randint(0, 2, ()).item()
                logging.info(
                    f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 1st, beg:{beg}"
                )
                for key, data_list in outputs.items():
                    outputs[key] = outputs[key][beg : beg + b : 2]
            b, t, _ = outputs["speech"].shape
            if b * t > self.batch_size:
                beg = torch.randint(0, 2, ()).item()
                logging.info(
                    f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 2nd, beg:{beg}"
                )
                for key, data_list in outputs.items():
                    outputs[key] = outputs[key][beg : beg + b : 2]
            b, t, _ = outputs["speech"].shape
            if b * t > self.batch_size:
                beg = torch.randint(0, 2, ()).item()
                logging.info(
                    f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 3th, beg:{beg}"
                )
                for key, data_list in outputs.items():
                    outputs[key] = outputs[key][beg : beg + b : 2]
        return outputs
funasr/models/conformer_rwkv/decoder.py
@@ -97,9 +97,7 @@
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
        self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
        if args.get("datatype", "bf16") == "bf16":
            self.self_attn.to(torch.bfloat16)
            # self.norm1.to(torch.bfloat16)
        self.args = args
        self.ln0 = None
        if self.layer_id == 0 and not args.get("ln0", True):
@@ -125,6 +123,10 @@
            nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
            nn.init.zeros_(self.self_attn.output.weight)
        if args.get("datatype", "bf16") == "bf16":
            self.self_attn.to(torch.bfloat16)
            # self.norm1.to(torch.bfloat16)
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.
funasr/models/sense_voice/model.py
@@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional