Dev gzf exp (#1663)
* rwkv 5
* rwkv v4
* rwkv v4
* rwkv
* rwkv
* update
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
* resume from step
| | |
| | | import logging |
| | | |
| | | import torch |
| | | import random |
| | | |
| | |
| | | self.float_pad_value = float_pad_value |
| | | self.sos = kwargs.get("sos", "<|startoftranscript|>") |
| | | self.eos = kwargs.get("eos", "<|endoftext|>") |
| | | self.batch_size = kwargs.get("batch_size") |
| | | self.batch_type = kwargs.get("batch_type") |
| | | |
| | | def get_source_len(self, index): |
| | | item = self.index_ds[index] |
| | |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence( |
| | | data_list, batch_first=True, padding_value=pad_value |
| | | ) |
| | | |
| | | if self.batch_type != "example": |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 1st, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 2nd, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 3th, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | return outputs |
| | |
| | | from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix |
| | | # self.attn = RWKVLayer(args=args, layer_id=layer_id) |
| | | self.self_attn = RWKV_Tmix(args, layer_id=layer_id) |
| | | if args.get("datatype", "bf16") == "bf16": |
| | | self.self_attn.to(torch.bfloat16) |
| | | # self.norm1.to(torch.bfloat16) |
| | | |
| | | self.args = args |
| | | self.ln0 = None |
| | | if self.layer_id == 0 and not args.get("ln0", True): |
| | |
| | | nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1) |
| | | nn.init.zeros_(self.self_attn.output.weight) |
| | | |
| | | if args.get("datatype", "bf16") == "bf16": |
| | | self.self_attn.to(torch.bfloat16) |
| | | # self.norm1.to(torch.bfloat16) |
| | | |
| | | def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): |
| | | """Compute decoded features. |
| | | |
| | |
| | | import logging |
| | | from dataclasses import dataclass |
| | | from typing import Dict |
| | | from typing import Iterable, Optional |