From d3d2fe73c08ee51d3a44d7ffb7b31eff32b60404 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 18 三月 2024 20:46:23 +0800
Subject: [PATCH] wav fronend
---
funasr/frontends/wav_frontend.py | 1 +
funasr/train_utils/trainer.py | 5 +++--
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py
index c6e03e8..afa7421 100644
--- a/funasr/frontends/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -75,6 +75,7 @@
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
+@tables.register("frontend_classes", "wav_frontend")
@tables.register("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR.
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index a00b3de..14abd6c 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -146,7 +146,7 @@
"""
ckpt = os.path.join(resume_path, "model.pt")
if os.path.isfile(ckpt):
- checkpoint = torch.load(ckpt)
+ checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint['epoch'] + 1
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint['state_dict']
@@ -169,7 +169,8 @@
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
-
+
+ self.model.to(self.device)
if self.use_ddp or self.use_fsdp:
dist.barrier()
--
Gitblit v1.9.1