Accessing a locally deployed FunASR from Java for speech recognition, with free switching to Alibaba Bailian

In a previous article, 《Docker部署FunASR,简单测试》 (deploying FunASR with Docker, a quick test), we covered how to deploy the FunASR speech recognition model.

Today we will look at how a Spring Boot project can talk to that locally deployed FunASR.

Modifying the FunasrWsClient from the official open-source project

The official FunasrWsClient reads its input from a wavPath. We comment out the wavPath-related code and add an audio field of type MultipartFile:

/** Audio stream to recognize */
private MultipartFile audio;

This lets the client accept the audio stream carried in an HTTP request from the front end.

We also add the following callback field, callBack:

/** Recognition callback */
private ResultCallback<RecognitionResult> callBack;

This callBack is invoked when recognition succeeds or fails. The complete source code follows:

import java.io.*;
import java.net.URI;
import java.util.Arrays;
import java.util.Map;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft;
import org.java_websocket.handshake.ServerHandshake;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.multipart.MultipartFile;

import com.alibaba.dashscope.audio.asr.recognition.RecognitionResult;
import com.alibaba.dashscope.audio.asr.recognition.timestamp.Sentence;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;

import java.util.regex.Matcher;
import java.util.regex.Pattern;
/** WebSocket client that streams audio to a FunASR server and forwards recognition results to a callback. */
public class FunasrWsClient extends WebSocketClient {

  public class RecWavThread extends Thread {
    private FunasrWsClient funasrClient;
    private MultipartFile audio;

    public RecWavThread(FunasrWsClient funasrClient, MultipartFile audio) {
      this.funasrClient = funasrClient;
      this.audio = audio;
    }

    public void run() {
      this.funasrClient.recWav(this.audio);
    }
  }

  private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);

  public FunasrWsClient(URI serverUri, Draft draft, MultipartFile audio, ResultCallback<RecognitionResult> callBack) {
    super(serverUri, draft);
    this.audio = audio;
    this.callBack = callBack;
  }

  public FunasrWsClient(URI serverURI, MultipartFile audio, ResultCallback<RecognitionResult> callBack) {
    super(serverURI);
    this.audio = audio;
    this.callBack = callBack;
  }

  public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders, MultipartFile audio, ResultCallback<RecognitionResult> callBack) {
    super(serverUri, httpHeaders);
    this.audio = audio;
    this.callBack = callBack;
  }

  public void getSslContext(String keyfile, String certfile) {
    // TODO
    return;
  }

  // send the initial configuration JSON before streaming audio
  public void sendJson(
      String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking, String suffix) {
    try {

      JSONObject obj = new JSONObject();
      obj.put("mode", mode);
      JSONArray array = new JSONArray();
      String[] chunkList = strChunkSize.split(",");
      for (int i = 0; i < chunkList.length; i++) {
        array.add(Integer.valueOf(chunkList[i].trim()));
      }

      obj.put("chunk_size", array);
      obj.put("chunk_interval", chunkInterval);
      obj.put("wav_name", wavName);

      // Hotwords are space-separated "phrase weight" pairs; a numeric token
      // closes the phrase accumulated so far
      if (FunasrWsClient.hotwords.trim().length() > 0) {
        JSONObject jsonitems = new JSONObject();
        String[] items = FunasrWsClient.hotwords.trim().split(" ");
        Pattern pattern = Pattern.compile("\\d+");
        String tmpWords = "";
        for (int i = 0; i < items.length; i++) {
          Matcher matcher = pattern.matcher(items[i]);
          if (matcher.matches()) {
            jsonitems.put(tmpWords.trim(), items[i].trim());
            tmpWords = "";
            continue;
          }
          tmpWords = tmpWords + items[i] + " ";
        }
        obj.put("hotwords", jsonitems.toString());
      }

      // the server expects raw PCM for WAV input (the header is skipped when sending)
      if (suffix.equals("wav")) {
        suffix = "pcm";
      }
      obj.put("wav_format", suffix);
      obj.put("is_speaking", isSpeaking);
      logger.info("sendJson: " + obj);
      // return;

      send(obj.toString());

      return;
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  // send json at end of wav
  public void sendEof() {
    try {
      JSONObject obj = new JSONObject();

      obj.put("is_speaking", false);

      logger.info("sendEof: " + obj);
      // return;

      send(obj.toString());
      iseof = true;
      return;
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  // send the uploaded audio to the server chunk by chunk
  public void recWav(MultipartFile audio) {
    // the original client derived these from FunasrWsClient.wavPath:
    //String fileName=FunasrWsClient.wavPath;
    //String suffix=fileName.split("\\.")[fileName.split("\\.").length-1];
    // the upload is assumed to be WAV, so the suffix is fixed to "wav"
    sendJson(mode, strChunkSize, chunkInterval, audio.getName(), true, "wav");

    int chunkSize = sendChunkSize;
    byte[] bytes = new byte[chunkSize];

    int readSize = 0;
    try (InputStream fis = audio.getInputStream()) {
      fis.read(bytes, 0, 44); // skip the 44-byte WAV header so only PCM is sent
      
      readSize = fis.read(bytes, 0, chunkSize);
      while (readSize > 0) {
        // send when it is chunk size
        if (readSize == chunkSize) {
          send(bytes); // send buf to server

        } else {
          // last chunk: send only the bytes actually read
          send(Arrays.copyOf(bytes, readSize));
        }
        // outside offline mode, simulate a live stream by sleeping
        // (chunkSize / 32 ms: 16 kHz, 16-bit mono audio is 32 bytes per ms)
        if (!mode.equals("offline")) {
          Thread.sleep(chunkSize / 32);
        }

        readSize = fis.read(bytes, 0, chunkSize);
      }

      if (!mode.equals("offline")) {
        // if not offline, we send eof and wait for 3 seconds to close
        Thread.sleep(2000);
        sendEof();
        Thread.sleep(3000);
        close();
      } else {
        // if offline, just send eof
        sendEof();
      }

    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  @Override
  public void onOpen(ServerHandshake handshakedata) {
    // start streaming the audio on a separate thread once the connection is open
    RecWavThread thread = new RecWavThread(this, this.audio);
    thread.start();
  }

  @Override
  public void onMessage(String message) {
    JSONObject jsonObject = new JSONObject();
    logger.info("received: " + message);
    try {
    	
      jsonObject = JSON.parseObject(message);
      logger.info("text: " + jsonObject.get("text"));
      if (callBack != null) {
        // wrap the FunASR text in a DashScope RecognitionResult so both
        // backends share a single callback type
        RecognitionResult result = new RecognitionResult();
        Sentence st = new Sentence();
        st.setSentenceEnd(true);
        st.setText((String) jsonObject.get("text"));
        result.setSentence(st);
        callBack.onEvent(result);
      }
      if (jsonObject.containsKey("timestamp")) {
        logger.info("timestamp: " + jsonObject.get("timestamp"));
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
    if (iseof && mode.equals("offline") && !jsonObject.containsKey("is_final")) {
      close();
    }
	 
    if (iseof && mode.equals("offline") && jsonObject.containsKey("is_final") && jsonObject.get("is_final").equals("false")) {
      close();
    }
  }

  @Override
  public void onClose(int code, String reason, boolean remote) {

    logger.info(
        "Connection closed by "
            + (remote ? "remote peer" : "us")
            + " Code: "
            + code
            + " Reason: "
            + reason);
    if(callBack != null) {
    	callBack.onComplete();
    }
  }

  @Override
  public void onError(Exception ex) {
    logger.info("ex: " + ex);
    ex.printStackTrace();
    // if the error is fatal then onClose will be called additionally
    if(callBack != null) {
    	callBack.onError(ex);
    }
  }

  private boolean iseof = false;
  //public static String wavPath;
  static String mode = "offline"; // offline mode gives better accuracy
  static String strChunkSize = "5,10,5";
  static int chunkInterval = 10;
  static int sendChunkSize = 1920;
  static String hotwords = "";
  static String fsthotwords = "";
  /** Audio stream to recognize */
  private MultipartFile audio;
  /** Recognition callback */
  private ResultCallback<RecognitionResult> callBack;
}
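
The static hotwords field holds space-separated "phrase weight" pairs; as the parsing in sendJson shows, a numeric token closes the phrase accumulated so far. A minimal sketch of boosting two phrases before connecting (the phrases and weights below are purely illustrative):

// sendJson will turn this into the JSON {"阿里巴巴":"20","hello world":"40"}
FunasrWsClient.hotwords = "阿里巴巴 20 hello world 40";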

The dependencies required by the code above; argparse4j is carried over from the official client's command-line entry point and is not used by the modified class:

<!-- 阿里巴巴语音 -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>dashscope-sdk-java</artifactId>
    <version>2.22.6</version>
</dependency>
<dependency>
  <groupId>net.sourceforge.argparse4j</groupId>
  <artifactId>argparse4j</artifactId>
  <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>org.java-websocket</groupId>
    <artifactId>Java-WebSocket</artifactId>
    <version>1.6.0</version>
</dependency>

How to invoke it (note that connect() is asynchronous; streaming starts in onOpen once the WebSocket handshake completes):

/**
 * Recognize with the locally deployed FunASR
 * @param audio
 * @param callBack
 */
private void recognizeLocal(MultipartFile audio, ResultCallback<RecognitionResult> callBack) {
	try {
		//String wsAddress = "ws://127.0.0.1:10095";
		FunasrWsClient c = new FunasrWsClient(new URI(funasrWsUrl), audio, callBack);
		c.connect();
	} catch (Exception e) {
		log.error("Local FunASR recognition error", e);
		callBack.onError(e);
	}
}

Switching dynamically between the local deployment and Alibaba Bailian

Testing exclusively against Alibaba Bailian burns through tokens, so it would be nice to switch dynamically between the local deployment and the third-party service. We write a FunasrService and add a few configuration properties:

@Value("${funasr.apiKey:xxx}")//阿里百炼的apiKey配置
private String funasrApiKey;
@Value("${funasr.type:1}")//1本地,2阿里百炼
private String funasrType;
@Value("${funasr.wsUrl:ws}")//本地部署的websocket地址
private String funasrWsUrl;
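
For reference, the matching entries in application.properties might look like this (the address and key below are placeholders; adjust them to your environment):

# 1 = local FunASR, 2 = Alibaba Bailian
funasr.type=1
# WebSocket address of the locally deployed FunASR
funasr.wsUrl=ws://127.0.0.1:10095
# Alibaba Bailian API key, only needed when funasr.type=2
funasr.apiKey=sk-xxx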

Once the HTTP endpoint receives a request, calling the audio2Text method performs the recognition:

/**
 * Speech-to-text recognition
 * @param audio
 * @param callback
 */
public void audio2Text(MultipartFile audio, ResultCallback<RecognitionResult> callback) {
	if ("1".equals(funasrType)) { // local recognition
		recognizeLocal(audio, callback);
	} else { // Alibaba Bailian recognition
		recognize(audio, callback);
	}
}

In audio2Text, callback is the caller-supplied callback. The complete FunasrService code follows:

import java.io.InputStream;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.function.Consumer;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import com.alibaba.dashscope.audio.asr.recognition.Recognition;
import com.alibaba.dashscope.audio.asr.recognition.RecognitionParam;
import com.alibaba.dashscope.audio.asr.recognition.RecognitionResult;
import com.alibaba.dashscope.common.ResultCallback;
import lombok.extern.slf4j.Slf4j;

@Slf4j
@Service
public class FunasrService {
	
	@Value("${funasr.apiKey:xxx}")
	private String funasrApiKey;
	@Value("${funasr.type:1}")//1本地,2阿里百炼
	private String funasrType;
	@Value("${funasr.wsUrl:ws}")
	private String funasrWsUrl;
	
	
	static {
		// Beijing-region URL; for the Singapore region use:
		// wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference
		// A static initializer ensures this runs before any Recognition call
		com.alibaba.dashscope.utils.Constants.baseWebsocketApiUrl =
				"wss://dashscope.aliyuncs.com/api-ws/v1/inference";
	}
	
	/**
	 * Speech-to-text recognition
	 * @param audio
	 * @param callback
	 */
	public void audio2Text(MultipartFile audio, ResultCallback<RecognitionResult> callback) {
		if ("1".equals(funasrType)) { // local recognition
			recognizeLocal(audio, callback);
		} else { // Alibaba Bailian recognition
			recognize(audio, callback);
		}
	}
	
	
	/**
	 * Speech-to-text recognition
	 * @param memoryId conversation id (not used in this method)
	 * @param audio audio stream
	 * @param onSuccess invoked with the recognized text
	 * @param onError invoked on recognition errors
	 */
	public void audio2Text(Long memoryId, MultipartFile audio, Consumer<String> onSuccess, Consumer<Exception> onError) {
		ResultCallback<RecognitionResult> callback = getCallBack(onSuccess, onError);
		if ("1".equals(funasrType)) {
			recognizeLocal(audio, callback);
		} else {
			recognize(audio, callback);
		}
	}
	
	/**
	 * Recognize with Alibaba Bailian's fun-asr-realtime model
	 * @param audio
	 * @param callBack
	 */
	private void recognize(MultipartFile audio, ResultCallback<RecognitionResult> callBack) {
		RecognitionParam param = RecognitionParam.builder()
	            .model("fun-asr-realtime")
	            // API keys differ between the Singapore and Beijing regions.
	            // Getting an API key: https://help.aliyun.com/zh/model-studio/get-api-key
	            .apiKey(funasrApiKey)
	            .format("wav")
	            .sampleRate(16000)
	            .build();
		Recognition recognizer = new Recognition();
		
        try {
            recognizer.call(param, callBack);
            // read the upload and send the audio in chunks
            InputStream fis = audio.getInputStream();
            // 3200 bytes = 100 ms of 16 kHz, 16-bit mono audio
            byte[] buffer = new byte[3200];
            int bytesRead;
            while ((bytesRead = fis.read(buffer)) != -1) {
                ByteBuffer byteBuffer;
                // the last chunk may be smaller than the buffer
                if (bytesRead < buffer.length) {
                    byteBuffer = ByteBuffer.wrap(buffer, 0, bytesRead);
                } else {
                    byteBuffer = ByteBuffer.wrap(buffer);
                }
                recognizer.sendAudioFrame(byteBuffer);
                // allocate a fresh buffer: ByteBuffer.wrap does not copy the data
                buffer = new byte[3200];
                Thread.sleep(100);
            }
            recognizer.stop();
        } catch (Exception e) {
        	log.error("Recognition error", e);
        	callBack.onError(e);
        } finally {
            // close the WebSocket connection when the task finishes
            recognizer.getDuplexApi().close(1000, "bye");
        }

        log.info("[Metric] requestId: "
                        + recognizer.getLastRequestId()
                        + ", first package delay ms: "
                        + recognizer.getFirstPackageDelay()
                        + ", last package delay ms: "
                        + recognizer.getLastPackageDelay());
	}
	
	/**
	 * Recognize with the locally deployed FunASR
	 * @param audio
	 * @param callBack
	 */
	private void recognizeLocal(MultipartFile audio, ResultCallback<RecognitionResult> callBack) {
		try {
			//String wsAddress = "ws://127.0.0.1:10095";
			FunasrWsClient c = new FunasrWsClient(new URI(funasrWsUrl), audio, callBack);
			c.connect();
		} catch (Exception e) {
			log.error("Local FunASR recognition error", e);
			callBack.onError(e);
		}
	}
	
	
	/**
	 * Build a recognition callback
	 * @param onSuccess invoked with the final text on success
	 * @param onError invoked when recognition fails
	 * @return the callback to pass to audio2Text
	 */
	public ResultCallback<RecognitionResult> getCallBack(Consumer<String> onSuccess, Consumer<Exception> onError){
		ResultCallback<RecognitionResult> callback = new ResultCallback<RecognitionResult>() {
            @Override
            public void onEvent(RecognitionResult message) {
                if (message.isSentenceEnd()) {
                    log.info("Recognition result: " + message.getSentence().getText());
                    onSuccess.accept(message.getSentence().getText());
                } else {
                	log.info("Intermediate result: " + message.getSentence().getText());
                }
            }

            @Override
            public void onComplete() {
            	log.info("Recognition complete");
            }

            @Override
            public void onError(Exception e) {
            	log.error("Recognition callback error", e);
            	onError.accept(e);
            }
        };
        return callback;
	}

}
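
The ResultCallback form is used in the controller below, but the Consumer-based overload can also be called directly. A minimal sketch, assuming funasrService has been injected and memoryId identifies the conversation:

funasrService.audio2Text(memoryId, audio,
		text -> log.info("Recognized text: {}", text),
		e -> log.error("Recognition failed", e));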

Receiving the audio stream from the front end

The code that receives the audio carried in the HTTP request:

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import com.alibaba.dashscope.audio.asr.recognition.RecognitionResult;
import com.alibaba.dashscope.common.ResultCallback;

import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Flux;

@Slf4j
@RestController
public class SpeechController {
	
	@Autowired
	private AgentMessageProcessor messageProcessor;
	@Autowired
	private FunasrService funAsrAudioService;
	@Autowired
	private ApplicationEventPublisher eventPublisher;
	
	@PostMapping(value = "/speech", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
	public Flux<String> speech(Long memoryId, @RequestParam("audio") MultipartFile audio) {
		log.info("Received speech request: memoryId={}, audio={}", memoryId, audio.getName());
		Flux<String> flux = messageProcessor.getMessageFlux(memoryId);
		// callback: publish an event on success or on error
		ResultCallback<RecognitionResult> callback = funAsrAudioService.getCallBack(
				m -> eventPublisher.publishEvent(new AgentMessageEvent(this, memoryId, m, true)),
				e -> {
					log.error("Speech recognition error:", e);
					eventPublisher.publishEvent(new AgentMessageEvent(this, memoryId, "Speech recognition error", true));
				});
		funAsrAudioService.audio2Text(audio, callback);
		
		return flux;
	}
}

The AgentMessageProcessor here handles agent messages; its main job is to provide a Flux instance for streaming responses. ApplicationEventPublisher publishes an event when recognition succeeds or fails, and the event listener forwards the message to the client.
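
The AgentMessageEvent class is not listed in this article; a minimal sketch consistent with how it is constructed above and consumed below, as a plain Spring ApplicationEvent (the exact field names are assumptions drawn from that usage):

import org.springframework.context.ApplicationEvent;

import lombok.Getter;

@Getter
public class AgentMessageEvent extends ApplicationEvent {

	private final Long memoryId;  // conversation id used to look up the sink
	private final String message; // text pushed to the client
	private final boolean close;  // complete the stream after this message

	public AgentMessageEvent(Object source, Long memoryId, String message, boolean close) {
		super(source);
		this.memoryId = memoryId;
		this.message = message;
		this.close = close;
	}
}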

Agent message processing: streaming messages to the front end

The AgentMessageProcessor code:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.dromara.common.json.utils.JsonUtils;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;

import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

@Slf4j
@Component
public class AgentMessageProcessor {
    
    // ConcurrentHashMap: events and recognition callbacks arrive on different threads
    private final Map<Object, Sinks.Many<String>> sinkMap = new ConcurrentHashMap<>();
 
    /**
     * Listen for message events
     * @param event
     */
    @EventListener
    public void handleMessageEvent(AgentMessageEvent event) {
    	log.info("Received message event: {}", JsonUtils.toJsonString(event));
    	Sinks.Many<String> sink = sinkMap.get(event.getMemoryId());
    	if (sink == null) {
    		log.warn("No sink registered for memoryId {}", event.getMemoryId());
    		return;
    	}
        // emit the message to the Flux
        Sinks.EmitResult result = sink.tryEmitNext(event.getMessage());
        if (result.isFailure()) {
            log.error("Failed to emit message: " + result);
            sink.tryEmitComplete();
        } else if (event.isClose()) {
        	sink.tryEmitComplete();
        }
    }
    
   
    /**
     * Get the Flux for a given memoryId
     * @param memoryId
     * @return
     */
    public Flux<String> getMessageFlux(Object memoryId) {
    	Sinks.Many<String> sink = Sinks.many().multicast().onBackpressureBuffer();
    	sinkMap.put(memoryId, sink);
    	// doOnCancel/doOnError/doOnComplete return new Flux instances,
    	// so chain them onto the Flux that is actually returned
    	return sink.asFlux()
    			.doOnCancel(() -> remove(memoryId))
    			.doOnError(error -> {
    				log.error("Stream error: " + error.getMessage());
    				remove(memoryId);
    			})
    			.doOnComplete(() -> remove(memoryId));
    }
    
    
    /**
     * Send a message
     * @param memoryId
     * @param message
     */
    public void sendMessage(Object memoryId, String message) {
    	Sinks.Many<String> sink = sinkMap.get(memoryId);
    	if (sink != null) {
    		sink.tryEmitNext(message);
    	}
    }
    
    public void remove(Object memoryId) {
    	sinkMap.remove(memoryId);
    }
}

With that, a complete asynchronous speech recognition event flow is in place, and you can switch freely between FunASR deployment modes (local or third-party).
