From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)
---
funasr/models/conformer_rwkv/decoder.py | 46 +
funasr/models/sense_voice/decoder.py | 37 +
wandb/run-20240425_211446-lkqptn01/logs/debug.log | 29 +
wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb | 0
funasr/datasets/audio_datasets/espnet_samplers.py | 9
funasr/models/sense_voice/rwkv_v4.py | 412 +++++++++++++++
funasr/models/sense_voice/rwkv_v5.py | 597 ++++++++++++++++++++++
funasr/models/sense_voice/rwkv_v6.py | 19
funasr/train_utils/trainer.py | 55 +
wandb/debug.log | 1
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py | 3
funasr/datasets/audio_datasets/index_ds.py | 4
funasr/datasets/dataloader_entry.py | 20
funasr/train_utils/model_summary.py | 2
funasr/models/sense_voice/cuda/wkv_cuda.cu | 125 ++++
wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log | 146 +++++
funasr/bin/train.py | 24
wandb/run-20240425_211446-lkqptn01/files/config.yaml | 26
wandb/latest-run | 1
funasr/models/sense_voice/cuda/wkv_op.cpp | 21
wandb/debug-internal.log | 1
21 files changed, 1,527 insertions(+), 51 deletions(-)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index b21348f..0f30a37 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,6 +9,9 @@
model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
+mm = model.model
+for p in mm.parameters():
+ print(f"{p.numel()}")
res = model.generate(input=wav_file)
print(res)
# [[beg1, end1], [beg2, end2], .., [begN, endN]]
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index eb1611a..448e464 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -99,7 +99,7 @@
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
- if isinstance(freeze_param, Sequence):
+    if not isinstance(freeze_param, Sequence) or isinstance(freeze_param, str):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
@@ -107,8 +107,9 @@
if k.startswith(t + ".") or k == t:
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
+ if local_rank == 0:
+ logging.info(f"{model_summary(model)}")
- logging.info(f"model info: {model_summary(model)}")
if use_ddp:
model = model.cuda(local_rank)
model = DDP(
@@ -145,8 +146,6 @@
else:
model = model.to(device=kwargs.get("device", "cuda"))
- if local_rank == 0:
- logging.info(f"{model}")
kwargs["device"] = next(model.parameters()).device
# optim
@@ -182,7 +181,12 @@
scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
- trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+ trainer.resume_checkpoint(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ scaler=scaler,
+ )
tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)
@@ -197,8 +201,11 @@
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
time1 = time.perf_counter()
- for data_split_i in range(dataloader.data_split_num):
- dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+ for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+ dataloader_tr, dataloader_val = dataloader.build_iter(
+ epoch, data_split_i=data_split_i, start_step=trainer.start_step
+ )
+ trainer.start_step = 0
trainer.train_epoch(
model=model,
optim=optim,
@@ -211,9 +218,8 @@
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
)
-
- torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index 3f14d09..cb30a28 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -41,6 +41,7 @@
drop_last=False,
is_training: bool = True,
sort_size: int = 1024,
+ start_step: int = 0,
**kwargs,
):
@@ -70,7 +71,9 @@
self.max_token_length = kwargs.get("max_token_length", 2048)
self.min_token_length = kwargs.get("min_token_length", 0)
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
+        self.start_step = start_step
+ if self.start_step > 0:
+ logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
# super().__init__(dataset, num_replicas=num_replicas, rank=rank,
# shuffle=shuffle, drop_last=drop_last)
@@ -92,7 +95,7 @@
max_len_in_batch = 0 # Tracks the max sample length within the current batch
for idx in sorted_indices:
-
+
# original_sample_length = self.dataset.get_source_len(idx)
# if (
# original_sample_length < self.min_token_length
@@ -142,7 +145,7 @@
# Allocate the batches to the current rank
start_idx = self.rank * batches_per_rank
end_idx = start_idx + batches_per_rank
- rank_batches = buffer_batches[start_idx:end_idx]
+ rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
# Return an iterator over the batches for the current rank
return iter(rank_batches)
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index da008b4..70581e8 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -104,10 +104,10 @@
or target_len > self.max_target_length
):
continue
-
+
if (source_len + target_len) > self.max_token_length:
continue
-
+
contents_i = {
"source": source,
"prompt": prompt,
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index 9595805..925b1d3 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -14,14 +14,14 @@
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
- **kwargs.get("dataset_conf")
+ **kwargs.get("dataset_conf"),
)
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=False,
- **kwargs.get("dataset_conf")
+ **kwargs.get("dataset_conf"),
)
# dataloader
@@ -55,14 +55,14 @@
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
- **kwargs.get("dataset_conf")
+ **kwargs.get("dataset_conf"),
)
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=False,
- **kwargs.get("dataset_conf")
+ **kwargs.get("dataset_conf"),
)
self.dataset_tr = dataset_tr
@@ -76,7 +76,7 @@
self.tokenizer = tokenizer
self.kwargs = kwargs
- def build_iter(self, epoch=0, data_split_i=0, **kwargs):
+ def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):
# reload dataset slice
if self.data_split_num > 1:
@@ -87,7 +87,7 @@
tokenizer=self.tokenizer,
is_training=True,
**self.kwargs.get("dataset_conf"),
- data_split_i=data_split_i
+ data_split_i=data_split_i,
)
# dataloader
@@ -95,7 +95,9 @@
batch_sampler_val = None
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
- batch_sampler = batch_sampler_class(self.dataset_tr, **self.kwargs.get("dataset_conf"))
+ batch_sampler = batch_sampler_class(
+ self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
+ )
batch_sampler_val = batch_sampler_class(
self.dataset_val, is_training=False, **self.kwargs.get("dataset_conf")
)
@@ -121,14 +123,14 @@
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
- **kwargs.get("dataset_conf")
+ **kwargs.get("dataset_conf"),
)
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=False,
- **kwargs.get("dataset_conf")
+ **kwargs.get("dataset_conf"),
)
return dataset_tr, dataset_val
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
index 90e56e5..5e2ac12 100644
--- a/funasr/models/conformer_rwkv/decoder.py
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -29,6 +29,11 @@
from funasr.register import tables
+class LayerNorm(nn.LayerNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -54,7 +59,7 @@
def __init__(
self,
size,
- self_attn,
+ # self_attn,
src_attn,
feed_forward,
dropout_rate,
@@ -62,11 +67,12 @@
concat_after=False,
layer_id=None,
args={},
+ **kwargs,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
- self.self_attn = self_attn.to(torch.bfloat16)
+ # self.self_attn = self_attn.to(torch.bfloat16)
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
@@ -79,6 +85,22 @@
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
self.layer_id = layer_id
+
+ if args.get("version", "v4") == "v4":
+ from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
+ elif args.get("version", "v5") == "v5":
+ from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
+ else:
+ from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
+ # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+ self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
+ if args.get("datatype", "bf16") == "bf16":
+ self.self_attn.to(torch.bfloat16)
+ # self.norm1.to(torch.bfloat16)
+ self.args = args
self.ln0 = None
if self.layer_id == 0 and not args.get("ln0", True):
self.ln0 = LayerNorm(args.n_embd)
@@ -93,7 +115,15 @@
print("init_rwkv")
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.norm1.weight, scale)
- nn.init.constant_(self.self_attn.ln2.weight, scale)
+ # nn.init.constant_(self.self_attn.ln2.weight, scale)
+
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.orthogonal_(self.self_attn.receptance.weight, gain=1)
+ nn.init.orthogonal_(self.self_attn.key.weight, gain=0.1)
+ nn.init.orthogonal_(self.self_attn.value.weight, gain=1)
+            if hasattr(self.self_attn, "gate"): nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
+ nn.init.zeros_(self.self_attn.output.weight)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
@@ -117,6 +147,8 @@
if self.layer_id == 0 and self.ln0 is not None:
tgt = self.ln0(tgt)
+ if self.args.get("datatype", "bf16") == "bf16":
+ tgt = tgt.bfloat16()
residual = tgt
tgt = self.norm1(tgt)
@@ -132,7 +164,8 @@
x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
x = x[:, -1, :]
-
+ if self.args.get("datatype", "bf16") == "bf16":
+ x = x.to(torch.float32)
# x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
residual = x
@@ -370,17 +403,16 @@
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
- from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ # from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
rwkv_cfg = kwargs.get("rwkv_cfg", {})
args = OmegaConf.create(rwkv_cfg)
- # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
- RWKVLayer(args=args, layer_id=lnum),
MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
diff --git a/funasr/models/sense_voice/cuda/wkv_cuda.cu b/funasr/models/sense_voice/cuda/wkv_cuda.cu
new file mode 100644
index 0000000..6acd0f3
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv_cuda.cu
@@ -0,0 +1,125 @@
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+ const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+ F *__restrict__ const _y) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+
+ F p = 0, q = 0, o = MIN_VALUE;
+ // p and q are running sums divided by exp(o) (to avoid overflows)
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+
+ F no = max(o, u + k[ii]);
+ F A = exp(o - no);
+ F B = exp(u + k[ii] - no);
+ y[ii] = (A * p + B * v[ii]) / (A * q + B);
+
+ no = max(w + o, k[ii]);
+ A = exp(w + o - no);
+ B = exp(k[ii] - no);
+ p = A * p + B * v[ii];
+ q = A * q + B;
+ o = no;
+ }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+ const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
+ F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ const F *__restrict__ const gy = _gy + _offset;
+
+ F *__restrict__ const gk = _gk + _offset;
+ F *__restrict__ const gv = _gv + _offset;
+
+ F y[Tmax], z[Tmax], zexp[Tmax];
+
+ F gw = 0, gu = 0;
+ F p = 0, q = 0;
+ F dpdw = 0, dqdw = 0;
+ F o = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ F no = max(o, k[ii] + u);
+ F A = exp(o - no);
+ F B = exp(k[ii] + u - no);
+
+ F num = A * p + B * v[ii];
+ F iden = 1 / (A * q + B);
+
+ y[i] = num * iden;
+ z[i] = iden;
+ zexp[i] = k[ii] + u - no;
+
+ gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
+ gu += gy[ii] * (v[ii] - y[i]) * B * iden;
+
+ no = max(w + o, k[ii]);
+ A = exp(w + o - no);
+ B = exp(k[ii] - no);
+ dpdw = A * (p + dpdw);
+ dqdw = A * (q + dqdw);
+ p = A * p + B * v[ii];
+ q = A * q + B;
+ o = no;
+ }
+
+ F gp = 0, gq = 0;
+ o = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ F A = gy[ii] * z[i] * exp(zexp[i]);
+ F B = exp(k[ii] + o);
+ gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
+ gv[ii] = A + B * gp;
+
+ F no = max(w + o, zexp[i] - k[ii] - u);
+ A = exp(w + o - no);
+ B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
+ gp = A * gp + B;
+ gq = A * gq - B * y[i];
+ o = no;
+ }
+
+ // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] += gw * _w[_c];
+ _gu[_offsetBC] += gu;
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
+}
diff --git a/funasr/models/sense_voice/cuda/wkv_op.cpp b/funasr/models/sense_voice/cuda/wkv_op.cpp
new file mode 100644
index 0000000..efe56d8
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv_op.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+ cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+ cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &forward, "wkv forward");
+ m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv, m) {
+ m.def("forward", forward);
+ m.def("backward", backward);
+}
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 5d34ff3..133508f 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -156,7 +156,6 @@
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
-from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
from omegaconf import OmegaConf
@@ -168,9 +167,25 @@
rwkv_cfg = kwargs.get("rwkv_cfg", {})
args = OmegaConf.create(rwkv_cfg)
- self.attn = RWKVLayer(args=args, layer_id=layer_id)
- if args.get("datatype", "bf16") == "bf16":
- self.attn.to(torch.bfloat16)
+ if args.get("version", "v4") == "v4":
+ from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
+ elif args.get("version", "v5") == "v5":
+ from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
+ else:
+ from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
+ # self.att = RWKVLayer(args=args, layer_id=layer_id)
+ self.att = RWKV_Tmix(args, layer_id=layer_id)
+
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.orthogonal_(self.att.receptance.weight, gain=1)
+ nn.init.orthogonal_(self.att.key.weight, gain=0.1)
+ nn.init.orthogonal_(self.att.value.weight, gain=1)
+            if hasattr(self.att, "gate"): nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
+ nn.init.zeros_(self.att.output.weight)
self.ln0 = None
if layer_id == 0 and not args.get("ln0", True):
@@ -180,6 +195,7 @@
layer_id = 0
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.ln0.weight, scale)
+
self.layer_id = layer_id
self.args = args
@@ -191,6 +207,11 @@
print("init_rwkv")
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.ln1.weight, scale)
+
+ if args.get("datatype", "bf16") == "bf16":
+ self.att.to(torch.bfloat16)
+ # if self.ln1 is not None:
+ # self.ln1.to(torch.bfloat16)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
@@ -213,10 +234,14 @@
if self.layer_id == 0 and self.ln0 is not None:
x = self.ln0(x)
+ if self.args.get("datatype", "bf16") == "bf16":
+ x = x.bfloat16()
if self.ln1 is None:
- x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+ x = x + self.att(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
else:
- x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+ x = x + self.att(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+ if self.args.get("datatype", "bf16") == "bf16":
+ x = x.to(torch.float32)
if self.cross_attn:
x = (
diff --git a/funasr/models/sense_voice/rwkv_v4.py b/funasr/models/sense_voice/rwkv_v4.py
new file mode 100644
index 0000000..c154ac0
--- /dev/null
+++ b/funasr/models/sense_voice/rwkv_v4.py
@@ -0,0 +1,412 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def __nop(ob):
+ return ob
+
+
+MyModule = nn.Module
+MyFunction = __nop
+if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
+ MyModule = torch.jit.ScriptModule
+ MyFunction = torch.jit.script_method
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+wkv_cuda = None
+
+
+def load_rwkv_kernel(
+ HEAD_SIZE: int = 64,
+ RWKV_CTXLEN: int = 512,
+ T_MAX: int = 512,
+):
+ from torch.utils.cpp_extension import load
+
+ global wkv_cuda
+
+ if wkv_cuda is not None:
+ return
+
+ absolute_file_path = os.path.abspath(__file__)
+ cur_dir = os.path.dirname(absolute_file_path)
+ wkv_cuda = load(
+ name="wkv",
+ sources=[f"{cur_dir}/cuda/wkv_op.cpp", f"{cur_dir}/cuda/wkv_cuda.cu"],
+ verbose=True,
+ extra_cuda_cflags=[
+ "-res-usage",
+ "--maxrregcount 60",
+ "--use_fast_math",
+ "-O3",
+ "-Xptxas -O3",
+ f"-DTmax={T_MAX}",
+ ],
+ )
+
+
+class WKV(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, B, T, C, w, u, k, v):
+ ctx.B = B
+ ctx.T = T
+ ctx.C = C
+ # assert T <= T_MAX
+ assert B * C % min(C, 1024) == 0
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
+ w = -torch.exp(w.contiguous())
+ u = u.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ else:
+ w = -torch.exp(w.float().contiguous())
+ u = u.float().contiguous()
+ k = k.float().contiguous()
+ v = v.float().contiguous()
+ ctx.save_for_backward(w, u, k, v)
+ y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
+ wkv_cuda.forward(B, T, C, w, u, k, v, y)
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
+ return y
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+ return y.half()
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+ return y.bfloat16()
+
+ @staticmethod
+ def backward(ctx, gy):
+ B = ctx.B
+ T = ctx.T
+ C = ctx.C
+        # assert T <= T_MAX  # NOTE(review): T_MAX is not defined at module scope (NameError); forward() has it commented out too
+ assert B * C % min(C, 1024) == 0
+ w, u, k, v = ctx.saved_tensors
+ gw = torch.zeros((B, C), device="cuda").contiguous()
+ gu = torch.zeros((B, C), device="cuda").contiguous()
+ gk = torch.zeros((B, T, C), device="cuda").contiguous()
+ gv = torch.zeros((B, T, C), device="cuda").contiguous()
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
+ wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+ else:
+ wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+ gw = torch.sum(gw, dim=0)
+ gu = torch.sum(gu, dim=0)
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
+ return (None, None, None, gw, gu, gk, gv)
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+ return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+ return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+
+
+def RUN_CUDA(B, T, C, w, u, k, v):
+ return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+
+
+class RWKV_TimeMix(torch.jit.ScriptModule):
+ def __init__(self, config, layer_id):
+ super().__init__()
+ load_rwkv_kernel()
+ self.layer_id = layer_id
+ self.ctx_len = config.ctx_len
+ self.n_embd = config.n_embd
+
+ attn_sz = config.n_embd
+
+ with torch.no_grad(): # fancy init
+ ratio_0_to_1 = layer_id / (config.n_layer - 1) # 0 to 1
+ ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
+
+ # fancy time_decay
+ decay_speed = torch.ones(attn_sz)
+ for h in range(attn_sz):
+ decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+ self.time_decay = nn.Parameter(decay_speed)
+ # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
+
+ # fancy time_first
+ zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
+ self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
+
+ # fancy time_mix
+ x = torch.ones(1, 1, config.n_embd)
+ for i in range(config.n_embd):
+ x[0, 0, i] = i / config.n_embd
+ self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+ self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+ self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
+ self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
+ self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
+
+ self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
+
+ self.key.scale_init = 0
+ self.receptance.scale_init = 0
+ self.output.scale_init = 0
+
+ @torch.jit.script_method
+ def jit_func(self, x):
+
+ # Mix x with the previous timestep to produce xk, xv, xr
+ xx = self.time_shift(x)
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+ # Use xk, xv, xr to produce k, v, r
+ k = self.key(xk)
+ v = self.value(xv)
+ r = self.receptance(xr)
+ sr = torch.sigmoid(r)
+
+ return sr, k, v
+
+ def forward(self, x):
+ B, T, C = x.size() # x = (Batch,Time,Channel)
+
+ sr, k, v = self.jit_func(x)
+
+ rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+ rwkv = self.output(rwkv)
+ return rwkv
+
+
+class RWKV_ChannelMix(torch.jit.ScriptModule):
+ def __init__(self, config, layer_id):
+ super().__init__()
+ self.layer_id = layer_id
+
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+ with torch.no_grad(): # fancy init of time_mix
+ ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
+
+ x = torch.ones(1, 1, config.n_embd)
+ for i in range(config.n_embd):
+ x[0, 0, i] = i / config.n_embd
+
+ self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+
+ hidden_sz = 4 * config.n_embd
+ self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
+ self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
+ self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
+
+ self.value.scale_init = 0
+ self.receptance.scale_init = 0
+
+ @torch.jit.script_method
+ def forward(self, x):
+ xx = self.time_shift(x)
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+ k = self.key(xk)
+ k = torch.square(torch.relu(k))
+ kv = self.value(k)
+
+ rkv = torch.sigmoid(self.receptance(xr)) * kv
+ return rkv
+
+
+# class Block(nn.Module):
+# def __init__(self, args, layer_id):
+# super().__init__()
+# self.args = args
+# self.layer_id = layer_id
+#
+# self.ln1 = nn.LayerNorm(args.n_embd)
+# self.ln2 = nn.LayerNorm(args.n_embd)
+#
+# if self.layer_id == 0:
+# self.ln0 = nn.LayerNorm(args.n_embd)
+#
+# self.att = RWKV_Tmix_x060(args, layer_id)
+#
+# self.ffn = RWKV_CMix_x060(args, layer_id)
+#
+# if args.dropout > 0:
+# self.drop0 = nn.Dropout(p=args.dropout)
+# self.drop1 = nn.Dropout(p=args.dropout)
+#
+# def forward(self, x, x_emb=None):
+# args = self.args
+# B, T, C = x.size()
+# if self.layer_id == 0:
+# x = self.ln0(x)
+#
+# if self.args.dropout == 0:
+# if self.layer_id == 0 and args.pre_ffn > 0:
+# x = x + self.ffnPre(self.ln1(x))
+# else:
+# x = x + self.att(self.ln1(x))
+# x = x + self.ffn(self.ln2(x))
+# else:
+# if self.layer_id == 0 and args.pre_ffn > 0:
+# x = self.drop0(x + self.ffnPre(self.ln1(x)))
+# else:
+# x = self.drop0(x + self.att(self.ln1(x)))
+# x = self.drop1(x + self.ffn(self.ln2(x)))
+#
+# return x
+
+
+class RWKVLayer(nn.Module):
+ def __init__(self, args, layer_id):
+ super().__init__()
+ self.args = args
+ self.layer_id = layer_id
+ if args.dim_ffn is None:
+ args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
+ self.ln0 = None
+ if self.layer_id == 0 and args.get("ln0", True):
+ self.ln0 = nn.LayerNorm(args.n_embd)
+
+ self.ln1 = None
+ if args.get("ln1", True):
+ self.ln1 = nn.LayerNorm(args.n_embd)
+ self.ln2 = nn.LayerNorm(args.n_embd)
+
+ self.att = RWKV_TimeMix(args, layer_id)
+
+ self.ffn = RWKV_ChannelMix(args, layer_id)
+
+ if args.dropout > 0:
+ self.drop0 = nn.Dropout(p=args.dropout)
+ self.drop1 = nn.Dropout(p=args.dropout)
+
+ # init
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.orthogonal_(self.att.receptance.weight, gain=1)
+ nn.init.orthogonal_(self.att.key.weight, gain=0.1)
+ nn.init.orthogonal_(self.att.value.weight, gain=1)
+            # NOTE(review): v4 RWKV_TimeMix has no "gate" projection (only key/value/receptance/output); initializing it would raise AttributeError
+ nn.init.zeros_(self.att.output.weight)
+
+ nn.init.orthogonal_(self.ffn.key.weight, gain=1)
+ nn.init.zeros_(self.ffn.value.weight)
+ nn.init.zeros_(self.ffn.receptance.weight)
+ scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+ nn.init.constant_(self.ln2.weight, scale)
+ if self.ln0 is not None:
+ nn.init.constant_(self.ln0.weight, scale)
+ if self.ln1 is not None:
+ nn.init.constant_(self.ln1.weight, scale)
+
+ def forward(self, x, x_emb=None, mask=None, **kwargs):
+
+ args = self.args
+ if args.get("datatype", "bf16") == "bf16":
+ x = x.bfloat16()
+ B, T, C = x.size()
+ if self.layer_id == 0 and self.ln0 is not None:
+ x = self.ln0(x)
+
+ if self.args.dropout == 0:
+ if self.ln1 is None:
+ x = x + self.att(x)
+ else:
+ x = x + self.att(self.ln1(x))
+ x = x + self.ffn(self.ln2(x))
+ else:
+ if self.ln1 is None:
+ x = self.drop0(x + self.att(x))
+ else:
+ x = self.drop0(x + self.att(self.ln1(x)))
+ x = self.drop1(x + self.ffn(self.ln2(x)))
+
+ if args.get("datatype", "bf16") == "bf16":
+ x = x.to(torch.float32)
+ return x
+
+
+class RWKV(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ if not hasattr(args, "dim_att"):
+ args.dim_att = args.n_embd
+ if not hasattr(args, "dim_ffn"):
+ if "-f4" in os.environ["RWKV_MY_TESTING"]:
+ args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
+ else:
+ args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
+ if not hasattr(args, "tiny_att_layer"):
+ args.tiny_att_layer = -1
+ if not hasattr(args, "tiny_att_dim"):
+ args.tiny_att_dim = -1
+ assert args.n_embd % 32 == 0
+ assert args.dim_att % 32 == 0
+ assert args.dim_ffn % 32 == 0
+
+ self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+
+ self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
+
+ self.ln_out = nn.LayerNorm(args.n_embd)
+ self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
+
+ if args.dropout > 0:
+ self.drop0 = nn.Dropout(p=args.dropout)
+
+ def forward(self, idx):
+ args = self.args
+ B, T = idx.size()
+ assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+
+ x = self.emb(idx)
+ x_emb = x
+
+ if args.dropout > 0:
+ x = self.drop0(x)
+ if args.tiny_att_dim > 0:
+ for block in self.blocks:
+ if args.grad_cp == 1:
+ x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+ else:
+ x = block(x, x_emb)
+ else:
+ for block in self.blocks:
+ if args.grad_cp == 1:
+ x = deepspeed.checkpointing.checkpoint(block, x)
+ else:
+ x = block(x)
+
+ x = self.ln_out(x)
+
+ if args.head_qk > 0:
+ q = self.head_q(x)[:, :T, :]
+ k = self.head_k(x)[:, :T, :]
+ c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
+ c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
+
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size)
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
+
+ x = self.head(x) + c
+ else:
+ x = self.head(x)
+
+ return x
diff --git a/funasr/models/sense_voice/rwkv_v5.py b/funasr/models/sense_voice/rwkv_v5.py
new file mode 100644
index 0000000..f19ca79
--- /dev/null
+++ b/funasr/models/sense_voice/rwkv_v5.py
@@ -0,0 +1,597 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def __nop(ob):
+ return ob
+
+
+MyModule = nn.Module
+MyFunction = __nop
+if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
+ MyModule = torch.jit.ScriptModule
+ MyFunction = torch.jit.script_method
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+wkv5_cuda = None
+
+
+def load_rwkv_kernel(
+ HEAD_SIZE: int = 64,
+ RWKV_CTXLEN: int = 512,
+):
+ from torch.utils.cpp_extension import load
+
+ global wkv5_cuda
+
+ if wkv5_cuda is not None:
+ return
+
+ absolute_file_path = os.path.abspath(__file__)
+ cur_dir = os.path.dirname(absolute_file_path)
+
+ wkv5_cuda = load(
+ name="wkv5",
+ sources=[f"{cur_dir}/cuda/wkv5_op.cpp", f"{cur_dir}/cuda/wkv5_cuda.cu"],
+ verbose=True,
+ extra_cuda_cflags=[
+ "-res-usage",
+ "--use_fast_math",
+ "-O3",
+ "-Xptxas -O3",
+ "--extra-device-vectorization",
+ f"-D_N_={HEAD_SIZE}",
+ ],
+ )
+
+
+# dtype = torch.float
+dtype = torch.bfloat16
+
+
+class WKV_5(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, B, T, C, H, r, k, v, w, u):
+ with torch.no_grad():
+ # assert r.dtype == torch.bfloat16
+ # assert k.dtype == torch.bfloat16
+ # assert v.dtype == torch.bfloat16
+ # assert w.dtype == torch.bfloat16
+ # assert u.dtype == torch.bfloat16
+ # assert HEAD_SIZE == C // H
+ ctx.B = B
+ ctx.T = T
+ ctx.C = C
+ ctx.H = H
+ assert r.is_contiguous()
+ assert k.is_contiguous()
+ assert v.is_contiguous()
+ assert w.is_contiguous()
+ assert u.is_contiguous()
+ ew = (-torch.exp(w.float())).contiguous()
+ eew = (torch.exp(ew)).contiguous()
+ ctx.save_for_backward(r, k, v, eew, ew, u)
+ y = torch.empty(
+ (B, T, C),
+ device=r.device,
+ dtype=torch.bfloat16,
+ memory_format=torch.contiguous_format,
+ ) # .uniform_(-1, 1)
+ wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
+ return y
+
+ @staticmethod
+ def backward(ctx, gy):
+ with torch.no_grad():
+ assert gy.dtype == torch.bfloat16
+ B = ctx.B
+ T = ctx.T
+ C = ctx.C
+ H = ctx.H
+ assert gy.is_contiguous()
+ r, k, v, eew, ew, u = ctx.saved_tensors
+ gr = torch.empty(
+ (B, T, C),
+ device=gy.device,
+ requires_grad=False,
+ dtype=torch.bfloat16,
+ memory_format=torch.contiguous_format,
+ ) # .uniform_(-1, 1)
+ gk = torch.empty(
+ (B, T, C),
+ device=gy.device,
+ requires_grad=False,
+ dtype=torch.bfloat16,
+ memory_format=torch.contiguous_format,
+ ) # .uniform_(-1, 1)
+ gv = torch.empty(
+ (B, T, C),
+ device=gy.device,
+ requires_grad=False,
+ dtype=torch.bfloat16,
+ memory_format=torch.contiguous_format,
+ ) # .uniform_(-1, 1)
+ gw = torch.empty(
+ (B, C),
+ device=gy.device,
+ requires_grad=False,
+ dtype=torch.bfloat16,
+ memory_format=torch.contiguous_format,
+ ) # .uniform_(-1, 1)
+ gu = torch.empty(
+ (B, C),
+ device=gy.device,
+ requires_grad=False,
+ dtype=torch.bfloat16,
+ memory_format=torch.contiguous_format,
+ ) # .uniform_(-1, 1)
+ wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
+ gw = torch.sum(gw, 0).view(H, C // H)
+ gu = torch.sum(gu, 0).view(H, C // H)
+ return (None, None, None, None, gr, gk, gv, gw, gu)
+
+
+def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
+ return WKV_5.apply(B, T, C, H, r, k, v, w, u)
+
+
+class RWKV_Tmix_x052(MyModule):
+ def __init__(self, args, layer_id):
+ super().__init__()
+ self.args = args
+ load_rwkv_kernel(args.head_size_a, args.ctx_len)
+ self.layer_id = layer_id
+
+ self.head_size = args.head_size_a
+ # assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
+ self.n_head = args.dim_att // self.head_size
+ assert args.dim_att % self.n_head == 0
+ self.head_size_divisor = args.head_size_divisor
+
+ with torch.no_grad():
+ ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
+ ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
+ ddd = torch.ones(1, 1, args.n_embd)
+ for i in range(args.n_embd):
+ ddd[0, 0, i] = i / args.n_embd
+
+ # fancy time_mix
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+ self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+
+ # fancy time_decay
+ decay_speed = torch.ones(args.dim_att)
+ for n in range(args.dim_att):
+ decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+ self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))
+ # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
+
+ tmp = torch.zeros(args.dim_att)
+ for n in range(args.dim_att):
+ zigzag = ((n + 1) % 3 - 1) * 0.1
+ tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
+
+ self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+ self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+ self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+
+ self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+ self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+ self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+ self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
+
+ @MyFunction
+ def jit_func(self, x):
+ B, T, C = x.size()
+
+ xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+ xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
+
+ r = self.receptance(xr)
+ k = self.key(xk)
+ v = self.value(xv)
+ g = F.silu(self.gate(xg))
+
+ return r, k, v, g
+
+ @MyFunction
+ def jit_func_2(self, x, g):
+ B, T, C = x.size()
+ x = x.view(B * T, C)
+
+ x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
+ x = self.output(x * g)
+ return x
+
+ def forward(self, x, **kwargs):
+ B, T, C = x.size()
+ H = self.n_head
+
+ r, k, v, g = self.jit_func(x)
+
+ x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
+
+ return self.jit_func_2(x, g)
+
+
+#
+# class RWKV_Tmix_x060(MyModule):
+# def __init__(self, args, layer_id):
+# super().__init__()
+# self.args = args
+#
+# load_rwkv_kernel(args.head_size_a, args.ctx_len)
+#
+# self.layer_id = layer_id
+#
+# self.head_size = args.head_size_a
+# self.n_head = args.dim_att // self.head_size
+# assert args.dim_att % self.n_head == 0
+#
+# with torch.no_grad():
+# ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
+# ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
+# ddd = torch.ones(1, 1, args.n_embd)
+# for i in range(args.n_embd):
+# ddd[0, 0, i] = i / args.n_embd
+#
+# # fancy time_mix
+# self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+# self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+# self.time_maa_v = nn.Parameter(
+# 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+# )
+# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+# self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+#
+# D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
+# self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
+# self.time_maa_w2 = nn.Parameter(
+# torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
+# )
+#
+# # fancy time_decay
+# decay_speed = torch.ones(args.dim_att)
+# for n in range(args.dim_att):
+# decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+# self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
+#
+# D_DECAY_LORA = 64
+# self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
+# self.time_decay_w2 = nn.Parameter(
+# torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
+# )
+#
+# tmp = torch.zeros(args.dim_att)
+# for n in range(args.dim_att):
+# zigzag = ((n + 1) % 3 - 1) * 0.1
+# tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
+#
+# self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+#
+# self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+# self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+# self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+#
+# self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+# self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+# self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+# self.ln_x = nn.GroupNorm(
+# self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
+# )
+#
+# @MyFunction
+# def jit_func(self, x):
+# B, T, C = x.size()
+#
+# xx = self.time_shift(x) - x
+#
+# xxx = x + xx * self.time_maa_x
+# xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
+# xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
+# mw, mk, mv, mr, mg = xxx.unbind(dim=0)
+#
+# xw = x + xx * (self.time_maa_w + mw)
+# xk = x + xx * (self.time_maa_k + mk)
+# xv = x + xx * (self.time_maa_v + mv)
+# xr = x + xx * (self.time_maa_r + mr)
+# xg = x + xx * (self.time_maa_g + mg)
+#
+# r = self.receptance(xr)
+# k = self.key(xk)
+# v = self.value(xv)
+# g = F.silu(self.gate(xg))
+#
+# ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
+# w = self.time_decay + ww
+#
+# return r, k, v, g, w
+#
+# @MyFunction
+# def jit_func_2(self, x, g):
+# B, T, C = x.size()
+# x = x.view(B * T, C)
+#
+# x = self.ln_x(x).view(B, T, C)
+# x = self.output(x * g)
+# return x
+#
+# def forward(self, x):
+# B, T, C = x.size()
+# H = self.n_head
+#
+# r, k, v, g, w = self.jit_func(x)
+# x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
+#
+# return self.jit_func_2(x, g)
+
+
+class RWKV_CMix_x052(MyModule):
+ def __init__(self, args, layer_id):
+ super().__init__()
+ self.args = args
+ self.layer_id = layer_id
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+ with torch.no_grad(): # fancy init of time_mix
+ ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
+ ddd = torch.ones(1, 1, args.n_embd)
+ for i in range(args.n_embd):
+ ddd[0, 0, i] = i / args.n_embd
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+
+ self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+ self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+ self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+
+ @MyFunction
+ def forward(self, x):
+ xx = self.time_shift(x)
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+ k = self.key(xk)
+ k = torch.relu(k) ** 2
+ kv = self.value(k)
+ return torch.sigmoid(self.receptance(xr)) * kv
+
+
+# class RWKV_CMix_x060(MyModule):
+# def __init__(self, args, layer_id):
+# super().__init__()
+# self.args = args
+# self.layer_id = layer_id
+# self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+#
+# with torch.no_grad(): # fancy init of time_mix
+# ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
+# ddd = torch.ones(1, 1, args.n_embd)
+# for i in range(args.n_embd):
+# ddd[0, 0, i] = i / args.n_embd
+# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+#
+# self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+# self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+# self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+#
+# @MyFunction
+# def forward(self, x):
+# xx = self.time_shift(x) - x
+# xk = x + xx * self.time_maa_k
+# xr = x + xx * self.time_maa_r
+#
+# k = self.key(xk)
+# k = torch.relu(k) ** 2
+# kv = self.value(k)
+# return torch.sigmoid(self.receptance(xr)) * kv
+
+
+# class Block(nn.Module):
+# def __init__(self, args, layer_id):
+# super().__init__()
+# self.args = args
+# self.layer_id = layer_id
+#
+# self.ln1 = nn.LayerNorm(args.n_embd)
+# self.ln2 = nn.LayerNorm(args.n_embd)
+#
+# if self.layer_id == 0:
+# self.ln0 = nn.LayerNorm(args.n_embd)
+#
+# self.att = RWKV_Tmix_x060(args, layer_id)
+#
+# self.ffn = RWKV_CMix_x060(args, layer_id)
+#
+# if args.dropout > 0:
+# self.drop0 = nn.Dropout(p=args.dropout)
+# self.drop1 = nn.Dropout(p=args.dropout)
+#
+# def forward(self, x, x_emb=None):
+# args = self.args
+# B, T, C = x.size()
+# if self.layer_id == 0:
+# x = self.ln0(x)
+#
+# if self.args.dropout == 0:
+# if self.layer_id == 0 and args.pre_ffn > 0:
+# x = x + self.ffnPre(self.ln1(x))
+# else:
+# x = x + self.att(self.ln1(x))
+# x = x + self.ffn(self.ln2(x))
+# else:
+# if self.layer_id == 0 and args.pre_ffn > 0:
+# x = self.drop0(x + self.ffnPre(self.ln1(x)))
+# else:
+# x = self.drop0(x + self.att(self.ln1(x)))
+# x = self.drop1(x + self.ffn(self.ln2(x)))
+#
+# return x
+
+
+class RWKVLayer(nn.Module):
+ def __init__(self, args, layer_id):
+ super().__init__()
+ self.args = args
+ self.layer_id = layer_id
+ if args.dim_ffn is None:
+ args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
+ self.ln0 = None
+ if self.layer_id == 0 and args.get("ln0", True):
+ self.ln0 = nn.LayerNorm(args.n_embd)
+
+ self.ln1 = None
+ if args.get("ln1", True):
+ self.ln1 = nn.LayerNorm(args.n_embd)
+
+ self.att = RWKV_Tmix_x052(args, layer_id)
+ self.ln2 = None
+ self.ffn = None
+ if args.get("use_rwkv_ffn", True):
+ self.ln2 = nn.LayerNorm(args.n_embd)
+ self.ffn = RWKV_CMix_x052(args, layer_id)
+
+ if args.dropout > 0:
+ self.drop0 = nn.Dropout(p=args.dropout)
+ self.drop1 = nn.Dropout(p=args.dropout)
+
+ # init
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.orthogonal_(self.att.receptance.weight, gain=1)
+ nn.init.orthogonal_(self.att.key.weight, gain=0.1)
+ nn.init.orthogonal_(self.att.value.weight, gain=1)
+ nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
+ nn.init.zeros_(self.att.output.weight)
+
+            if self.ffn is not None:
+                nn.init.orthogonal_(self.ffn.key.weight, gain=1)
+                nn.init.zeros_(self.ffn.value.weight)
+                nn.init.zeros_(self.ffn.receptance.weight)
+            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+            # ffn/ln0/ln1/ln2 may be None depending on config flags (see rwkv_v6.py)
+            if self.ln2 is not None: nn.init.constant_(self.ln2.weight, scale)
+            if self.ln0 is not None: nn.init.constant_(self.ln0.weight, scale)
+            if self.ln1 is not None: nn.init.constant_(self.ln1.weight, scale)
+
+ def forward(self, x, x_emb=None, mask=None, **kwargs):
+
+ args = self.args
+ if args.get("datatype", "bf16") == "bf16":
+ x = x.bfloat16()
+ B, T, C = x.size()
+ if self.layer_id == 0 and self.ln0 is not None:
+ x = self.ln0(x)
+
+ if self.args.dropout == 0:
+ if self.ln1 is None:
+ x = x + self.att(x)
+ else:
+ x = x + self.att(self.ln1(x))
+ if self.ffn is not None:
+ x = x + self.ffn(self.ln2(x))
+ else:
+ if self.ln1 is None:
+ x = self.drop0(x + self.att(x))
+ else:
+ x = self.drop0(x + self.att(self.ln1(x)))
+ if self.ffn is not None:
+ x = self.drop1(x + self.ffn(self.ln2(x)))
+
+ if args.get("datatype", "bf16") == "bf16":
+ x = x.to(torch.float32)
+ return x
+
+
+# class RWKV(nn.Module):
+# def __init__(self, args):
+# super().__init__()
+# self.args = args
+# if not hasattr(args, "dim_att"):
+# args.dim_att = args.n_embd
+# if not hasattr(args, "dim_ffn"):
+# if "-f4" in os.environ["RWKV_MY_TESTING"]:
+# args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
+# else:
+# args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
+# if not hasattr(args, "tiny_att_layer"):
+# args.tiny_att_layer = -1
+# if not hasattr(args, "tiny_att_dim"):
+# args.tiny_att_dim = -1
+# assert args.n_embd % 32 == 0
+# assert args.dim_att % 32 == 0
+# assert args.dim_ffn % 32 == 0
+#
+# self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+#
+# self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
+#
+# self.ln_out = nn.LayerNorm(args.n_embd)
+# self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
+#
+# if args.dropout > 0:
+# self.drop0 = nn.Dropout(p=args.dropout)
+#
+# def forward(self, idx):
+# args = self.args
+# B, T = idx.size()
+# assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+#
+# x = self.emb(idx)
+# x_emb = x
+#
+# if args.dropout > 0:
+# x = self.drop0(x)
+# if args.tiny_att_dim > 0:
+# for block in self.blocks:
+# if args.grad_cp == 1:
+# x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+# else:
+# x = block(x, x_emb)
+# else:
+# for block in self.blocks:
+# if args.grad_cp == 1:
+# x = deepspeed.checkpointing.checkpoint(block, x)
+# else:
+# x = block(x)
+#
+# x = self.ln_out(x)
+#
+# if args.head_qk > 0:
+# q = self.head_q(x)[:, :T, :]
+# k = self.head_k(x)[:, :T, :]
+# c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
+# c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
+#
+# if "32" in os.environ["RWKV_FLOAT_MODE"]:
+# c = c @ F.one_hot(idx, num_classes=args.vocab_size)
+# elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+# c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
+# elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+# c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
+#
+# x = self.head(x) + c
+# else:
+# x = self.head(x)
+#
+# return x
diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
index 36269a1..b91d47a 100644
--- a/funasr/models/sense_voice/rwkv_v6.py
+++ b/funasr/models/sense_voice/rwkv_v6.py
@@ -244,7 +244,7 @@
x = self.output(x * g)
return x
- def forward(self, x):
+ def forward(self, x, **kwargs):
B, T, C = x.size()
H = self.n_head
@@ -341,11 +341,14 @@
self.ln1 = None
if args.get("ln1", True):
self.ln1 = nn.LayerNorm(args.n_embd)
- self.ln2 = nn.LayerNorm(args.n_embd)
self.att = RWKV_Tmix_x060(args, layer_id)
- self.ffn = RWKV_CMix_x060(args, layer_id)
+ self.ln2 = None
+ self.ffn = None
+ if args.get("use_rwkv_ffn", True):
+ self.ln2 = nn.LayerNorm(args.n_embd)
+ self.ffn = RWKV_CMix_x060(args, layer_id)
if args.dropout > 0:
self.drop0 = nn.Dropout(p=args.dropout)
@@ -364,11 +367,13 @@
nn.init.zeros_(self.ffn.value.weight)
nn.init.zeros_(self.ffn.receptance.weight)
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
- nn.init.constant_(self.ln2.weight, scale)
+
if self.ln0 is not None:
nn.init.constant_(self.ln0.weight, scale)
if self.ln1 is not None:
nn.init.constant_(self.ln1.weight, scale)
+ if self.ln2 is not None:
+ nn.init.constant_(self.ln2.weight, scale)
def forward(self, x, x_emb=None, mask=None, **kwargs):
@@ -384,13 +389,15 @@
x = x + self.att(x)
else:
x = x + self.att(self.ln1(x))
- x = x + self.ffn(self.ln2(x))
+ if self.ffn is not None:
+ x = x + self.ffn(self.ln2(x))
else:
if self.ln1 is None:
x = self.drop0(x + self.att(x))
else:
x = self.drop0(x + self.att(self.ln1(x)))
- x = self.drop1(x + self.ffn(self.ln2(x)))
+ if self.ffn is not None:
+ x = self.drop1(x + self.ffn(self.ln2(x)))
if args.get("datatype", "bf16") == "bf16":
x = x.to(torch.float32)
diff --git a/funasr/train_utils/model_summary.py b/funasr/train_utils/model_summary.py
index 2aef88a..4e92a33 100644
--- a/funasr/train_utils/model_summary.py
+++ b/funasr/train_utils/model_summary.py
@@ -47,6 +47,8 @@
def model_summary(model: torch.nn.Module) -> str:
message = "Model structure:\n"
message += str(model)
+ # for p in model.parameters():
+ # print(f"{p.numel()}")
tot_params = sum(p.numel() for p in model.parameters())
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 8f20ba4..66f8778 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -15,6 +15,11 @@
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+try:
+ import wandb
+except:
+ wandb = None
+
@contextmanager
def maybe_autocast(enabled):
@@ -107,9 +112,22 @@
self.best_step_or_epoch = ""
self.val_acc_step_or_eoch = {}
self.val_loss_step_or_eoch = {}
-
- self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+ self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+ self.start_data_split_i = 0
+ self.start_step = 0
+ self.use_wandb = kwargs.get("use_wandb", False)
+        if self.use_wandb and wandb is not None:
+ wandb.login(key=kwargs.get("wandb_token"))
+ wandb.init(
+ config=kwargs,
+ project=kwargs.get("wandb_project", "my_project"),
+ entity=kwargs.get("wandb_team", "my_team"),
+ name=kwargs.get("wandb_exp_name", "my_exp"),
+ dir=output_dir,
+ job_type="training",
+ reinit=True,
+ )
def save_checkpoint(
self,
@@ -142,6 +160,7 @@
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+ "step": step,
}
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
@@ -268,6 +287,13 @@
self.best_step_or_epoch = (
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
)
+ self.start_data_split_i = (
+ checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
+ )
+ self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
+ self.start_step = checkpoint["step"] if "step" in checkpoint else 0
+ self.start_step = 0 if self.start_step is None else self.start_step
+
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
@@ -327,7 +353,7 @@
time2 = time.perf_counter()
with maybe_autocast(self.use_fp16):
retval = model(**batch)
-
+
if (
self.reset_gpu_cache
and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
@@ -598,7 +624,7 @@
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
f"{tag}, "
- f"rank: {self.local_rank}, "
+ f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
@@ -613,18 +639,29 @@
)
logging.info(description)
+ description_dict = {
+ f"rank{self.rank}_loss/{tag}": loss,
+ f"rank{self.rank}_lr/{tag}": lr,
+ }
+
if writer is not None:
- writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
- writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
- writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
+ writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
+ writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
for key, var in stats.items():
writer.add_scalar(
- f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
+ f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
)
+ description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
for key, var in speed_stats.items():
writer.add_scalar(
- f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
+ f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
)
+ description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
+ if self.use_wandb and wandb is not None:
+ wandb.log(
+ description_dict,
+                    step=self.batch_total,
+ )
def close(self, writer=None):
diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
new file mode 120000
index 0000000..e3e213d
--- /dev/null
+++ b/wandb/debug-internal.log
@@ -0,0 +1 @@
+run-20240425_211446-lkqptn01/logs/debug-internal.log
\ No newline at end of file
diff --git a/wandb/debug.log b/wandb/debug.log
new file mode 120000
index 0000000..826b402
--- /dev/null
+++ b/wandb/debug.log
@@ -0,0 +1 @@
+run-20240425_211446-lkqptn01/logs/debug.log
\ No newline at end of file
diff --git a/wandb/latest-run b/wandb/latest-run
new file mode 120000
index 0000000..7e88449
--- /dev/null
+++ b/wandb/latest-run
@@ -0,0 +1 @@
+run-20240425_211446-lkqptn01
\ No newline at end of file
diff --git a/wandb/run-20240425_211446-lkqptn01/files/config.yaml b/wandb/run-20240425_211446-lkqptn01/files/config.yaml
new file mode 100644
index 0000000..b131e2e
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/files/config.yaml
@@ -0,0 +1,26 @@
+wandb_version: 1
+
+a:
+ desc: null
+ value: 1
+_wandb:
+ desc: null
+ value:
+ python_version: 3.8.15
+ cli_version: 0.16.6
+ is_jupyter_run: false
+ is_kaggle_kernel: false
+ start_time: 1714050886.0
+ t:
+ 1:
+ - 55
+ - 105
+ 3:
+ - 13
+ - 16
+ - 23
+ 4: 3.8.15
+ 5: 0.16.6
+ 8:
+ - 5
+ 13: darwin-arm64
diff --git a/wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log b/wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
new file mode 100644
index 0000000..4d63ace
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
@@ -0,0 +1,146 @@
+2024-04-25 21:14:46,163 INFO StreamThr :90024 [internal.py:wandb_internal():86] W&B internal server running at pid: 90024, started at: 2024-04-25 21:14:46.156788
+2024-04-25 21:14:46,164 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status
+2024-04-25 21:14:46,172 INFO WriterThread:90024 [datastore.py:open_for_write():87] open: ./wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb
+2024-04-25 21:14:46,173 DEBUG SenderThread:90024 [sender.py:send():379] send: header
+2024-04-25 21:14:46,284 DEBUG SenderThread:90024 [sender.py:send():379] send: run
+2024-04-25 21:14:46,828 ERROR SenderThread:90024 [internal_api.py:execute():373] 403 response executing GraphQL.
+2024-04-25 21:14:46,828 ERROR SenderThread:90024 [internal_api.py:execute():374] {"errors":[{"message":"permission denied","path":["upsertBucket"],"extensions":{"code":"PERMISSION_ERROR"}}],"data":{"upsertBucket":null}}
+2024-04-25 21:14:46,828 ERROR SenderThread:90024 [sender.py:send_run():971] It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+Traceback (most recent call last):
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/retry.py", line 131, in __call__
+ result = self._call_fn(*args, **kwargs)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 369, in execute
+ return self.client.execute(*args, **kwargs) # type: ignore
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute
+ result = self._get_result(document, *args, **kwargs)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result
+ return self.transport.execute(document, *args, **kwargs)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/gql_request.py", line 59, in execute
+ request.raise_for_status()
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/requests/models.py", line 1021, in raise_for_status
+ raise HTTPError(http_error_msg, response=self)
+requests.exceptions.HTTPError: 403 Client Error: Forbidden for url: https://api.wandb.ai/graphql
+
+During handling of the above exception, another exception occurred:
+
+Traceback (most recent call last):
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/sender.py", line 969, in send_run
+ server_run = self._init_run(run, config_value_dict)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/sender.py", line 1014, in _init_run
+ server_run, inserted, server_messages = self._api.upsert_run(
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/apis/normalize.py", line 73, in wrapper
+ raise err
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/apis/normalize.py", line 41, in wrapper
+ return func(*args, **kwargs)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 2217, in upsert_run
+ response = self.gql(
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 341, in gql
+ ret = self._retry_gql(
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/retry.py", line 147, in __call__
+ retry_timedelta_triggered = check_retry_fn(e)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/util.py", line 968, in check_retry_fn
+ return fallback_retry_fn(e)
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/util.py", line 910, in no_retry_auth
+ raise CommError(
+wandb.errors.CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+2024-04-25 21:14:51,848 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:14:56,865 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:01,882 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:06,910 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:11,930 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:16,954 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:21,968 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:26,992 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:32,007 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:37,026 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:42,047 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:47,060 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:52,088 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:57,107 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:02,125 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:07,154 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:12,184 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:17,193 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:22,218 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:27,236 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:32,248 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:37,267 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:42,285 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:47,302 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:52,323 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:57,345 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:02,359 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:07,370 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:12,391 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:17,410 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:22,427 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:27,449 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:32,472 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:37,492 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:42,514 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:47,529 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:52,547 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:57,565 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:02,584 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:07,605 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:12,630 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:17,649 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:22,670 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:27,689 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:32,706 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:37,721 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:42,738 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:47,760 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:52,776 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:57,797 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:02,816 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:07,831 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:12,849 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:17,886 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:22,903 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:27,914 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:32,935 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:37,956 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:42,977 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:47,997 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:53,013 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:58,037 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:03,054 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:08,074 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:13,095 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:18,115 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:23,136 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:28,158 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:33,175 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:38,195 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:43,216 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:48,237 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:53,254 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:58,276 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:03,293 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:08,306 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:13,328 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:18,349 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:23,372 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:28,393 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:33,418 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:38,428 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:43,452 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:48,475 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:53,493 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:58,512 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:03,530 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:08,546 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:13,563 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:18,586 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:23,604 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:28,621 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:33,641 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:38,688 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:43,704 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:48,722 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:53,742 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:58,756 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:23:03,775 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:23:08,798 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:23:13,815 DEBUG HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
diff --git a/wandb/run-20240425_211446-lkqptn01/logs/debug.log b/wandb/run-20240425_211446-lkqptn01/logs/debug.log
new file mode 100644
index 0000000..0f4231c
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/logs/debug.log
@@ -0,0 +1,29 @@
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Current SDK version is 0.16.6
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Configure stats pid to 89997
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from /Users/zhifu/.config/wandb/settings
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from /Users/zhifu/funasr1.0/wandb/settings
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from environment variables: {}
+2024-04-25 21:14:46,147 WARNING MainThread:89997 [wandb_setup.py:_flush():76] Could not find program at <input>
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Inferring run settings from compute environment: {'program_relpath': None, 'program': '<input>'}
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {'api_key': '***REDACTED***'}
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {'api_key': '***REDACTED***'}
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {}
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_init.py:_log_setup():521] Logging user logs to ./wandb/run-20240425_211446-lkqptn01/logs/debug.log
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_init.py:_log_setup():522] Logging internal logs to ./wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
+2024-04-25 21:14:46,147 INFO MainThread:89997 [wandb_init.py:init():561] calling init triggers
+2024-04-25 21:14:46,148 INFO MainThread:89997 [wandb_init.py:init():568] wandb.init called with sweep_config: {}
+config: {'a': 1}
+2024-04-25 21:14:46,148 INFO MainThread:89997 [wandb_init.py:init():611] starting backend
+2024-04-25 21:14:46,148 INFO MainThread:89997 [wandb_init.py:init():615] setting up manager
+2024-04-25 21:14:46,153 INFO MainThread:89997 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=spawn,fork,forkserver, using: spawn
+2024-04-25 21:14:46,156 INFO MainThread:89997 [wandb_init.py:init():623] backend started and connected
+2024-04-25 21:14:46,171 INFO MainThread:89997 [wandb_init.py:init():715] updated telemetry
+2024-04-25 21:14:46,282 INFO MainThread:89997 [wandb_init.py:init():748] communicating run to backend with 90.0 second timeout
+2024-04-25 21:14:46,836 ERROR MainThread:89997 [wandb_init.py:init():774] encountered error: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+2024-04-25 21:14:48,147 ERROR MainThread:89997 [wandb_init.py:init():1199] It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+Traceback (most recent call last):
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 1181, in init
+ run = wi.init()
+ File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 780, in init
+ raise error
+wandb.errors.CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
diff --git a/wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb b/wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb
--
Gitblit v1.9.1