From e971e000ad582c767ae44c9650470899f5bb46d0 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 26 Apr 2024 01:11:18 +0800
Subject: [PATCH] Dev gzf exp (#1663)
---
funasr/datasets/sense_voice_datasets/datasets.py | 32 ++++++++++++++++++++++++++++++++
funasr/models/sense_voice/model.py | 1 +
funasr/models/conformer_rwkv/decoder.py | 8 +++++---
3 files changed, 38 insertions(+), 3 deletions(-)
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 5468ea6..4f14b35 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -1,3 +1,5 @@
+import logging
+
import torch
import random
@@ -46,6 +48,8 @@
self.float_pad_value = float_pad_value
self.sos = kwargs.get("sos", "<|startoftranscript|>")
self.eos = kwargs.get("eos", "<|endoftext|>")
+ self.batch_size = kwargs.get("batch_size")
+ self.batch_type = kwargs.get("batch_type")
def get_source_len(self, index):
item = self.index_ds[index]
@@ -124,4 +128,32 @@
outputs[key] = torch.nn.utils.rnn.pad_sequence(
data_list, batch_first=True, padding_value=pad_value
)
+
+ if self.batch_type != "example":
+ b, t, _ = outputs["speech"].shape
+ if b * t > self.batch_size:
+ beg = torch.randint(0, 2, ()).item()
+ logging.info(
+ f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 1st, beg:{beg}"
+ )
+ for key, data_list in outputs.items():
+ outputs[key] = outputs[key][beg : beg + b : 2]
+
+ b, t, _ = outputs["speech"].shape
+ if b * t > self.batch_size:
+ beg = torch.randint(0, 2, ()).item()
+ logging.info(
+ f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 2nd, beg:{beg}"
+ )
+ for key, data_list in outputs.items():
+ outputs[key] = outputs[key][beg : beg + b : 2]
+
+ b, t, _ = outputs["speech"].shape
+ if b * t > self.batch_size:
+ beg = torch.randint(0, 2, ()).item()
+ logging.info(
+ f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 3th, beg:{beg}"
+ )
+ for key, data_list in outputs.items():
+ outputs[key] = outputs[key][beg : beg + b : 2]
return outputs
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
index 5e2ac12..4c41049 100644
--- a/funasr/models/conformer_rwkv/decoder.py
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -97,9 +97,7 @@
from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
# self.attn = RWKVLayer(args=args, layer_id=layer_id)
self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
- if args.get("datatype", "bf16") == "bf16":
- self.self_attn.to(torch.bfloat16)
- # self.norm1.to(torch.bfloat16)
+
self.args = args
self.ln0 = None
if self.layer_id == 0 and not args.get("ln0", True):
@@ -125,6 +123,10 @@
nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
nn.init.zeros_(self.self_attn.output.weight)
+ if args.get("datatype", "bf16") == "bf16":
+ self.self_attn.to(torch.bfloat16)
+ # self.norm1.to(torch.bfloat16)
+
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 4cbb490..b731bb6 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1,3 +1,4 @@
+import logging
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
--
Gitblit v1.9.1