funasr/models/ct_transformer/model.py
@@ -287,9 +287,7 @@ # y, _ = self.wrapped_model(**data) y, _ = self.punc_forward(**data) _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) punctuations = indices if indices.size()[0] != 1: punctuations = torch.squeeze(indices) punctuations = torch.squeeze(indices, dim=1) assert punctuations.size()[0] == len(mini_sentence) # Search for the last Period/QuestionMark as cache runtime/java/java_http2ws_src/http/src/Readme.md
New file @@ -0,0 +1,14 @@ dependencies { implementation("org.springframework.boot:spring-boot-starter-web") implementation("org.json:json:20240303") implementation("org.springframework.boot:spring-boot-starter-websocket") } 使用接口测试工具 form-data格式传入文件 返回测试成功即运行成功 默认访问路径: io路径: http://localhost:8081/recognition/testIO nio路径: http://localhost:8081/recognition/testNIO application.yml中可根据自身需要修改对应模型参数 runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/FunasrJavaClientApplication.java
New file @@ -0,0 +1,20 @@ package com.example.funasr_java_client; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; /** * * @author Virgil Qiu * @since 2024/04/24 * */ @SpringBootApplication public class FunasrJavaClientApplication { public static void main(String[] args) { SpringApplication.run(FunasrJavaClientApplication.class, args); } } runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/RecognitionController.java
// New file @@ -0,0 +1,36 @@
package com.example.funasr_java_client.Servcvice;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.concurrent.ExecutionException;

/**
 * HTTP endpoints that accept an uploaded audio file and hand it to the
 * recognition service.
 *
 * <p>NOTE(review): this file lives in the {@code funasr_java_client} directory
 * while its package is {@code ...funasr_java_client.Servcvice}; confirm which
 * of the two is intended.
 *
 * <p>NOTE(review): both endpoints delegate to the same {@link RecognitionService};
 * {@code RecognitionService2} is not injected anywhere in this controller.
 * Confirm whether {@code /testIO} was meant to use it.
 *
 * @author Virgil Qiu
 * @since 2024/04/24
 */
@RestController
@RequestMapping("/recognition")
public class RecognitionController {

    private final RecognitionService recognitionService;

    /**
     * @param recognitionService service that forwards the uploaded audio for recognition
     */
    public RecognitionController(RecognitionService recognitionService) {
        this.recognitionService = recognitionService;
    }

    /**
     * Handles {@code POST /recognition/testNIO}.
     *
     * @param file multipart upload sent under the form field {@code file}
     * @return the fixed success text once the service call returns
     * @throws IOException          propagated from the recognition service
     * @throws ExecutionException   propagated from the recognition service
     * @throws InterruptedException propagated from the recognition service
     */
    @PostMapping("/testNIO")
    public String testNIO(@RequestParam("file") MultipartFile file) throws IOException, ExecutionException, InterruptedException {
        // The handler name now matches its route (it was previously named testIO).
        recognitionService.recognition(file);
        return "测试成功";
    }

    /**
     * Handles {@code POST /recognition/testIO}.
     *
     * @param file multipart upload sent under the form field {@code file}
     * @return the fixed success text once the service call returns
     * @throws IOException          propagated from the recognition service
     * @throws ExecutionException   propagated from the recognition service
     * @throws InterruptedException propagated from the recognition service
     */
    @PostMapping("/testIO")
    public String testIO(@RequestParam("file") MultipartFile file) throws IOException, ExecutionException, InterruptedException {
        // The handler name now matches its route (it was previously named testNIO).
        recognitionService.recognition(file);
        return "测试成功";
    }
}
// runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/Servcvice/RecognitionService.java
// New file @@ -0,0 +1,19 @@
package com.example.funasr_java_client.Servcvice;

import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.concurrent.ExecutionException;

/**
 * Contract for submitting an uploaded audio file for speech recognition.
 *
 * @author Virgil Qiu
 * @since 2024/04/24
 */
public interface RecognitionService {

    /**
     * Processes one uploaded audio file.
     *
     * @param file the multipart upload to recognise
     * @return an implementation-defined result object
     * @throws InterruptedException if the calling thread is interrupted while waiting
     * @throws ExecutionException   if an asynchronous step of the implementation fails
     * @throws IOException          if reading, storing or transmitting the audio fails
     */
    Object recognition(MultipartFile file) throws InterruptedException, ExecutionException, IOException;
}
// runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/Servcvice/RecognitionService2.java
// New file @@ -0,0 +1,18 @@
package com.example.funasr_java_client.Servcvice;

import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.concurrent.ExecutionException;

/**
 * Second recognition contract; it declares the same single operation as
 * {@code RecognitionService} under a separate type.
 *
 * @author Virgil Qiu
 * @since 2024/04/24
 */
public interface RecognitionService2 {

    /**
     * Processes one uploaded audio file.
     *
     * @param file the multipart upload to recognise
     * @return an implementation-defined result object
     * @throws InterruptedException if the calling thread is interrupted while waiting
     * @throws ExecutionException   if an asynchronous step of the implementation fails
     * @throws IOException          if reading, storing or transmitting the audio fails
     */
    Object recognition(MultipartFile file) throws InterruptedException, ExecutionException, IOException;
}
// runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/Servcvice/impl/RecognitionServiceImpl.java
// New file @@ -0,0 +1,100 @@
package com.example.funasr_java_client.Servcvice.impl;

import com.example.funasr_java_client.Servcvice.RecognitionService;
import com.example.funasr_java_client.WebSocketClient;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;

import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;

/**
 * Recognition service that stores the upload as a {@code .pcm} file, reads it
 * back with {@link Files#readAllBytes} and streams it to the configured
 * WebSocket endpoint.
 *
 * @author Virgil Qiu
 * @since 2024/04/24
 */
@Service
public class RecognitionServiceImpl implements RecognitionService {

    /** File name stem used when the upload carries no usable original name. */
    private static final String FALLBACK_NAME = "audio";

    /** Directory in which uploaded audio is stored. */
    @Value("${parameters.fileUrl}")
    private String fileUrl;

    /** Value sent to the server as the {@code mode} field of the config message. */
    @Value("${parameters.model}")
    private String model;

    /** Value sent to the server as the {@code hotwords} field of the config message. */
    @Value("${parameters.hotWords}")
    private String hotWords;

    /** WebSocket URI of the recognition server. */
    @Value("${parameters.serverIpPort}")
    private String serverIpPort;

    /**
     * Stores the uploaded audio, then sends a config message, the audio bytes
     * and an end-of-speech marker over a new WebSocket session.
     *
     * @param file the uploaded audio
     * @return {@code "0"} for an empty upload, {@code "Error reading audio file"}
     *         when the stored file cannot be read back, otherwise the
     *         placeholder {@code "test"}
     * @throws IOException          if the file cannot be stored or a message cannot be sent
     * @throws ExecutionException   if the WebSocket handshake fails
     * @throws InterruptedException if interrupted while waiting for the handshake
     */
    @Override
    public Object recognition(MultipartFile file) throws IOException, ExecutionException, InterruptedException {
        if (file.isEmpty()) {
            return "0"; // empty upload: return the sentinel value
        }

        String prefix = baseName(file.getOriginalFilename());
        System.out.println(prefix);

        // Join directory and name through File so a missing trailing separator
        // in the configured directory cannot glue the two together.
        File localFile = new File(fileUrl, prefix + ".pcm");
        File destDir = localFile.getParentFile();
        if (destDir != null && !destDir.exists() && !destDir.mkdirs()) {
            throw new IOException("Unable to create destination directory: " + destDir);
        }
        file.transferTo(localFile);

        // Read the audio before connecting so a read failure never leaves a
        // half-used session behind.
        byte[] audioData;
        try {
            audioData = Files.readAllBytes(localFile.toPath());
        } catch (IOException e) {
            System.err.println("Error reading file: " + e.getMessage());
            e.printStackTrace();
            return "Error reading audio file";
        }

        WebSocketClient client = new WebSocketClient();
        URI uri = URI.create(serverIpPort);
        StandardWebSocketClient standardWebSocketClient = new StandardWebSocketClient();
        WebSocketSession webSocketSession = standardWebSocketClient.execute(client, null, uri).get();

        JSONObject configJson = new JSONObject();
        configJson.put("mode", model);
        configJson.put("wav_name", prefix);
        configJson.put("wav_format", "pcm"); // the audio is sent as pcm
        configJson.put("is_speaking", true);
        configJson.put("hotwords", hotWords);
        configJson.put("itn", true);

        JSONObject endMarkerJson = new JSONObject();
        endMarkerJson.put("is_speaking", false);

        try {
            // Config and meta information first, then the audio, then the end marker.
            webSocketSession.sendMessage(new TextMessage(configJson.toString()));
            // NOTE(review): the whole file goes out as a single binary message;
            // large uploads may need chunking.
            webSocketSession.sendMessage(new BinaryMessage(ByteBuffer.wrap(audioData)));
            webSocketSession.sendMessage(new TextMessage(endMarkerJson.toString()));
        } catch (IOException | RuntimeException e) {
            // A failed send leaves the session useless; close it instead of leaking it.
            try {
                webSocketSession.close();
            } catch (IOException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }

        // The session stays open on success so the handler can still receive
        // the server's reply.
        // TODO: receive and process the recognition result returned by the server
        return "test";
    }

    /**
     * Derives the stored file's stem from the client-supplied file name: any
     * directory part is dropped and everything from the first dot on is cut.
     *
     * @param originalFilename name reported by the client, may be {@code null}
     * @return a non-empty stem without path separators
     */
    private static String baseName(String originalFilename) {
        if (originalFilename == null) {
            return FALLBACK_NAME;
        }
        // The name is client-controlled: keep only the last path segment.
        int lastSeparator = Math.max(originalFilename.lastIndexOf('/'), originalFilename.lastIndexOf('\\'));
        String name = originalFilename.substring(lastSeparator + 1);
        int firstDot = name.indexOf('.');
        String stem = (firstDot >= 0) ? name.substring(0, firstDot) : name;
        return stem.isEmpty() ? FALLBACK_NAME : stem;
    }
}
// runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/Servcvice/impl/RecognitionServiceImpl2.java
// New file @@ -0,0 +1,112 @@
package com.example.funasr_java_client.Servcvice.impl;

import com.example.funasr_java_client.Servcvice.RecognitionService;
import com.example.funasr_java_client.Servcvice.RecognitionService2;
import com.example.funasr_java_client.WebSocketClient;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;

/**
 * Recognition service that stores the upload as a {@code .pcm} file, reads it
 * back through a {@link FileInputStream} and streams it to the configured
 * WebSocket endpoint.
 *
 * @author Virgil Qiu
 * @since 2024/04/24
 */
@Service
public class RecognitionServiceImpl2 implements RecognitionService2 {

    /** File name stem used when the upload carries no usable original name. */
    private static final String FALLBACK_NAME = "audio";

    /** Directory in which uploaded audio is stored. */
    @Value("${parameters.fileUrl}")
    private String fileUrl;

    /** Value sent to the server as the {@code mode} field of the config message. */
    @Value("${parameters.model}")
    private String model;

    /** Value sent to the server as the {@code hotwords} field of the config message. */
    @Value("${parameters.hotWords}")
    private String hotWords;

    /** WebSocket URI of the recognition server. */
    @Value("${parameters.serverIpPort}")
    private String serverIpPort;

    /**
     * Stores the uploaded audio, then sends a config message, the audio bytes
     * and an end-of-speech marker over a new WebSocket session.
     *
     * @param file the uploaded audio
     * @return {@code "0"} for an empty upload, {@code "Error reading audio file"}
     *         when the stored file cannot be read back, otherwise the
     *         placeholder {@code "test"}
     * @throws IOException          if the file cannot be stored or a message cannot be sent
     * @throws ExecutionException   if the WebSocket handshake fails
     * @throws InterruptedException if interrupted while waiting for the handshake
     */
    @Override
    public Object recognition(MultipartFile file) throws IOException, ExecutionException, InterruptedException {
        if (file.isEmpty()) {
            return "0"; // empty upload: return the sentinel value
        }

        String prefix = baseName(file.getOriginalFilename());
        System.out.println(prefix);

        // Join directory and name through File so a missing trailing separator
        // in the configured directory cannot glue the two together.
        File localFile = new File(fileUrl, prefix + ".pcm");
        File destDir = localFile.getParentFile();
        if (destDir != null && !destDir.exists() && !destDir.mkdirs()) {
            throw new IOException("Unable to create destination directory: " + destDir);
        }
        file.transferTo(localFile);

        // Read the audio before connecting. A read failure is reported to the
        // caller instead of being logged and then treated as success.
        byte[] completeData;
        try {
            completeData = readFully(localFile);
        } catch (IOException e) {
            System.err.println("Error reading file: " + e.getMessage());
            e.printStackTrace();
            return "Error reading audio file";
        }

        WebSocketClient client = new WebSocketClient();
        URI uri = URI.create(serverIpPort);
        StandardWebSocketClient standardWebSocketClient = new StandardWebSocketClient();
        WebSocketSession webSocketSession = standardWebSocketClient.execute(client, null, uri).get();

        JSONObject configJson = new JSONObject();
        configJson.put("mode", model);
        configJson.put("wav_name", prefix);
        configJson.put("wav_format", "pcm"); // the audio is sent as pcm
        configJson.put("is_speaking", true);
        configJson.put("hotwords", hotWords);
        configJson.put("itn", true);

        JSONObject endMarkerJson = new JSONObject();
        endMarkerJson.put("is_speaking", false);

        try {
            // Config and meta information first, then the audio, then the end marker.
            webSocketSession.sendMessage(new TextMessage(configJson.toString()));
            // NOTE(review): the whole file goes out as a single binary message;
            // large uploads may need chunking.
            webSocketSession.sendMessage(new BinaryMessage(completeData));
            webSocketSession.sendMessage(new TextMessage(endMarkerJson.toString()));
        } catch (IOException | RuntimeException e) {
            // A failed send leaves the session useless; close it instead of leaking it.
            try {
                webSocketSession.close();
            } catch (IOException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }

        // The session stays open on success so the handler can still receive
        // the server's reply.
        // TODO: receive and process the recognition result returned by the server
        return "test";
    }

    /**
     * Reads the whole file into memory through a plain input stream.
     *
     * @param source file to read
     * @return every byte of {@code source}
     * @throws IOException if the file cannot be opened or read
     */
    private static byte[] readFully(File source) throws IOException {
        try (FileInputStream fis = new FileInputStream(source)) {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            byte[] buffer = new byte[1024];
            int bytesRead;
            while ((bytesRead = fis.read(buffer)) != -1) {
                baos.write(buffer, 0, bytesRead);
            }
            return baos.toByteArray();
        }
    }

    /**
     * Derives the stored file's stem from the client-supplied file name: any
     * directory part is dropped and everything from the first dot on is cut.
     *
     * @param originalFilename name reported by the client, may be {@code null}
     * @return a non-empty stem without path separators
     */
    private static String baseName(String originalFilename) {
        if (originalFilename == null) {
            return FALLBACK_NAME;
        }
        // The name is client-controlled: keep only the last path segment.
        int lastSeparator = Math.max(originalFilename.lastIndexOf('/'), originalFilename.lastIndexOf('\\'));
        String name = originalFilename.substring(lastSeparator + 1);
        int firstDot = name.indexOf('.');
        String stem = (firstDot >= 0) ? name.substring(0, firstDot) : name;
        return stem.isEmpty() ? FALLBACK_NAME : stem;
    }
}
// runtime/java/java_http2ws_src/http/src/main/java/com/example/funasr_java_client/WebSocketClient.java
New file @@ -0,0 +1,63 @@ package com.example.funasr_java_client; import org.springframework.stereotype.Component; import org.springframework.web.socket.*; /** * * @author Virgil Qiu * @since 2024/04/24 * */ @Component public class WebSocketClient implements WebSocketHandler { private WebSocketSession session; @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.session = session; System.out.println("WebSocket connection established."); } @Override public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception { if (message instanceof TextMessage) { String receivedMessage = ((TextMessage) message).getPayload(); System.out.println("Received message from server: " + receivedMessage); // 在这里处理接收到的消息 } } @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { System.err.println("WebSocket transport error: " + exception.getMessage()); session.close(CloseStatus.SERVER_ERROR); } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { System.out.println("WebSocket connection closed with status: " + status); } @Override public boolean supportsPartialMessages() { return false; } public void sendMessage(String message) { if (session != null && session.isOpen()) { try { session.sendMessage(new TextMessage(message)); System.out.println("Sent message to server: " + message); } catch (Exception e) { e.printStackTrace(); } } else { System.err.println("WebSocket session is not open. Cannot send message."); } } } runtime/java/java_http2ws_src/http/src/main/resources/application.properties
# New file @@ -0,0 +1,2 @@
# Spring Boot settings for the FunASR Java HTTP client.
# NOTE(review): application.yml in the same resources directory also sets
# spring.application.name (to a different value) and server.port; Spring Boot
# is documented to prefer .properties over YAML in the same location. Confirm
# which file should own these keys.

# Name of the Spring application.
spring.application.name=funasr_java_client
# HTTP port the embedded server listens on; the Readme URLs use the same port.
server.port=8081
# runtime/java/java_http2ws_src/http/src/main/resources/application.yml
# New file @@ -0,0 +1,21 @@
#/**
# *
# * @author Virgil Qiu
# * @since 2024/04/24
# *
# */
# Spring Boot settings plus the parameters.* values that the recognition
# services inject with @Value.
spring:
  application:
    name: java_http_client
server:
  port: 8081
parameters:
  # Sent to the server as the "mode" field of the config message.
  model: "offline" # the offline model is used as an example
  # JSON text sent to the server as the "hotwords" field of the config message.
  hotWords: "{\"自定义\":20,\"热词\":20,\"设置\":30}"
  # Location where uploaded audio is stored as <name>.pcm.
  # NOTE(review): check how the services join this value with the file name;
  # a trailing separator may be required.
  fileUrl: "E:/EI/Audio"
  # WebSocket URI of the recognition server; replace the placeholder host and port.
  serverIpPort: "ws://your_funasr_ip:port"