• java模拟GPT流式问答


    流式请求 GPT,并将答案流式推送到前端页面

    1)java流式获取gpt答案

    1、读取文件流的方式

    使用 POST 请求数据。由于 GPT 以 event-stream(SSE)的方式返回数据,每一行的格式是 `data: {...}`,需要手动去掉行首的 `data:` 前缀后再解析 JSON,并跳过空行以及结束标记 `data: [DONE]`。

    1. /**
    2. org.apache.http.client.methods
    3. **/
    4. @SneakyThrows
    5. private void chatStream(List messagesBOList) {
    6. CloseableHttpClient httpclient = HttpClients.createDefault();
    7. HttpPost httpPost = new HttpPost("https://api.openai.com/v1/chat/completions");
    8. httpPost.setHeader("Authorization","xxxxxxxxxxxx");
    9. httpPost.setHeader("Content-Type","application/json; charset=UTF-8");
    10. ChatParamBO build = ChatParamBO.builder()
    11. .temperature(0.7)
    12. .model("gpt-3.5-turbo")
    13. .messages(messagesBOList)
    14. .stream(true)
    15. .build();
    16. System.out.println(JsonUtils.toJson(build));
    17. httpPost.setEntity(new StringEntity(JsonUtils.toJson(build),"utf-8"));
    18. CloseableHttpResponse response = httpclient.execute(httpPost);
    19. try {
    20. HttpEntity entity = response.getEntity();
    21. if (entity != null) {
    22. InputStream inputStream = entity.getContent();
    23. BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
    24. String line;
    25. while ((line = reader.readLine()) != null) {
    26. // 处理 event stream 数据
    27. try {
    28. // System.out.println(line);
    29. ChatResultBO chatResultBO = JsonUtils.toObject(line.replace("data:", ""), ChatResultBO.class);
    30. String content = chatResultBO.getChoices().get(0).getDelta().getContent();
    31. log.info(content);
    32. // System.out.println(chatResultBO.getChoices().get(0).getMessage().getContent());
    33. } catch (Exception e) {
    34. // e.printStackTrace();
    35. }
    36. }
    37. }
    38. } finally {
    39. response.close();
    40. }
    41. }

    2、通过 SSE 连接的方式获取数据

    用到了 OkHttp 及其 SSE 扩展 okhttp-sse

    需要先引入相关的 Maven 依赖:

    <!-- No <version> given: presumably managed by a parent BOM (e.g. Spring Boot); otherwise add one
         and keep okhttp and okhttp-sse on the same version. -->
    <dependency>
        <groupId>com.squareup.okhttp3</groupId>
        <artifactId>okhttp</artifactId>
    </dependency>
    <dependency>
        <groupId>com.squareup.okhttp3</groupId>
        <artifactId>okhttp-sse</artifactId>
    </dependency>
    1. // 定义see接口
    2. Request request = new Request.Builder().url("https://api.openai.com/v1/chat/completions")
    3. .header("Authorization","xxx")
    4. .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),param.toJSONString()))
    5. .build();
    6. OkHttpClient okHttpClient = new OkHttpClient.Builder()
    7. .connectTimeout(10, TimeUnit.MINUTES)
    8. .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
    9. .build();
    10. // 实例化EventSource,注册EventSource监听器
    11. RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {
    12. @Override
    13. public void onOpen(EventSource eventSource, Response response) {
    14. log.info("onOpen");
    15. }
    16. @SneakyThrows
    17. @Override
    18. public void onEvent(EventSource eventSource, String id, String type, String data) {
    19. // log.info("onEvent");
    20. log.info(data);//请求到的数据
    21. }
    22. @Override
    23. public void onClosed(EventSource eventSource) {
    24. log.info("onClosed");
    25. // emitter.complete();
    26. }
    27. @Override
    28. public void onFailure(EventSource eventSource, Throwable t, Response response) {
    29. log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
    30. // emitter.complete();
    31. }
    32. });
    33. realEventSource.connect(okHttpClient);//真正开始请求的一步

    2)流式推送答案

    方法一:通过订阅式SSE/WebSocket

    原理是先建立连接,之后服务端通过这条连接不断向客户端推送消息即可

    1、websocket

    创建相关配置:

    1. import javax.websocket.Session;
    2. import lombok.Data;
    3. /**
    4. * @description WebSocket客户端连接
    5. */
    6. @Data
    7. public class WebSocketClient {
    8. // 与某个客户端的连接会话,需要通过它来给客户端发送数据
    9. private Session session;
    10. //连接的uri
    11. private String uri;
    12. }
    1. import org.springframework.context.annotation.Bean;
    2. import org.springframework.context.annotation.Configuration;
    3. import org.springframework.web.socket.server.standard.ServerEndpointExporter;
    4. @Configuration
    5. public class WebSocketConfig {
    6. @Bean
    7. public ServerEndpointExporter serverEndpointExporter() {
    8. return new ServerEndpointExporter();
    9. }
    10. }
    配置相关service
    1. @Slf4j
    2. @Component
    3. @ServerEndpoint("/websocket/chat/{chatId}")
    4. public class ChatWebsocketService {
    5. static final ConcurrentHashMap> webSocketClientMap= new ConcurrentHashMap<>();
    6. private String chatId;
    7. /**
    8. * 连接建立成功时触发,绑定参数
    9. * @param session 与某个客户端的连接会话,需要通过它来给客户端发送数据
    10. * @param chatId 商户ID
    11. */
    12. @OnOpen
    13. public void onOpen(Session session, @PathParam("chatId") String chatId){
    14. WebSocketClient client = new WebSocketClient();
    15. client.setSession(session);
    16. client.setUri(session.getRequestURI().toString());
    17. List webSocketClientList = webSocketClientMap.get(chatId);
    18. if(webSocketClientList == null){
    19. webSocketClientList = new ArrayList<>();
    20. }
    21. webSocketClientList.add(client);
    22. webSocketClientMap.put(chatId, webSocketClientList);
    23. this.chatId = chatId;
    24. }
    25. /**
    26. * 收到客户端消息后调用的方法
    27. *
    28. * @param message 客户端发送过来的消息
    29. */
    30. @OnMessage
    31. public void onMessage(String message) {
    32. log.info("chatId = {},message = {}",chatId,message);
    33. // 回复消息
    34. this.chatStream(BaseUtil.newList(ChatParamMessagesBO.builder().content(message).role("user").build()));
    35. // this.sendMessage(chatId,message+"233");
    36. }
    37. /**
    38. * 连接关闭时触发,注意不能向客户端发送消息了
    39. * @param chatId
    40. */
    41. @OnClose
    42. public void onClose(@PathParam("chatId") String chatId){
    43. webSocketClientMap.remove(chatId);
    44. }
    45. /**
    46. * 通信发生错误时触发
    47. * @param session
    48. * @param error
    49. */
    50. @OnError
    51. public void onError(Session session, Throwable error) {
    52. System.out.println("发生错误");
    53. error.printStackTrace();
    54. }
    55. /**
    56. * 向客户端发送消息
    57. * @param chatId
    58. * @param message
    59. */
    60. public void sendMessage(String chatId,String message){
    61. try {
    62. List webSocketClientList = webSocketClientMap.get(chatId);
    63. if(webSocketClientList!=null){
    64. for(WebSocketClient webSocketServer:webSocketClientList){
    65. webSocketServer.getSession().getBasicRemote().sendText(message);
    66. }
    67. }
    68. } catch (IOException e) {
    69. e.printStackTrace();
    70. throw new RuntimeException(e.getMessage());
    71. }
    72. }
    73. /**
    74. * 流式调用查询gpt
    75. * @param messagesBOList
    76. * @throws IOException
    77. */
    78. @SneakyThrows
    79. private void chatStream(List messagesBOList) {
    80. // TODO 和GPT的访问请求
    81. }
    82. }
    测试:使用 Postman 建立 WebSocket 连接

    2、SSE

    本质也是基于订阅推送方式

    前端:
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>SseEmitter</title>
    </head>
    <body>
    <button onclick="closeSse()">关闭连接</button>
    <div id="message"></div>
    <script>
        let source = null;
        // A fixed id simulates the logged-in user; a timestamp could be used instead:
        //const id = new Date().getTime();
        const id = '7829083B42464C5B9C445A087E873C7D';
        const baseUrl = 'http://172.28.54.27:8902/api/sse';
        if (window.EventSource) {
            // Open the connection.
            source = new EventSource(baseUrl + '/connect?conversationId=' + encodeURIComponent(id));
            setMessageInnerHTML("连接用户=" + id);
            // Fired once the connection is established (alternative: source.onopen = function (event) {}).
            source.addEventListener('open', function (e) {
                setMessageInnerHTML("建立连接。。。");
            }, false);
            // Fired for each message pushed by the server (alternative: source.onmessage = function (event) {}).
            source.addEventListener('message', function (e) {
                setMessageInnerHTML(e.data);
            });
            // Custom "close" event sent by the server: stop the browser from reconnecting.
            source.addEventListener("close", function (event) {
                console.log("Server closed the connection");
                source.close();
            });
            // Fired on communication errors, e.g. a dropped connection (alternative: source.onerror = ...).
            source.addEventListener('error', function (e) {
                console.log(e);
                // readyState belongs to the EventSource, not to the Event object passed in.
                if (source.readyState === EventSource.CLOSED) {
                    setMessageInnerHTML("连接关闭");
                }
            }, false);
        } else {
            setMessageInnerHTML("你的浏览器不支持SSE");
        }
        // On page unload the SSE connection could be closed here; if the server never expires
        // connections, its data must otherwise be cleaned up manually.
        window.onbeforeunload = function () {
            //closeSse();
        };

        // Close the SSE connection and tell the server to drop it.
        function closeSse() {
            if (source === null) {
                // EventSource unsupported: nothing was opened.
                return;
            }
            source.close();
            const httpRequest = new XMLHttpRequest();
            httpRequest.open('GET', baseUrl + '/disconnection?conversationId=' + encodeURIComponent(id), true);
            httpRequest.send();
            console.log("close");
        }

        // Append one line of text to the page. The text is inserted as a text node, not as HTML,
        // so server-pushed content cannot inject markup or scripts.
        function setMessageInnerHTML(innerHTML) {
            const box = document.getElementById('message');
            box.appendChild(document.createTextNode(innerHTML));
            box.appendChild(document.createElement('br'));
        }
    </script>
    </body>
    </html>
    后端:
    controller
    1. import org.springframework.cloud.context.config.annotation.RefreshScope;
    2. import org.springframework.validation.annotation.Validated;
    3. import org.springframework.web.bind.annotation.GetMapping;
    4. import org.springframework.web.bind.annotation.RequestMapping;
    5. import org.springframework.web.bind.annotation.RestController;
    6. import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    7. import java.util.Set;
    8. import java.util.function.Consumer;
    9. import javax.annotation.Resource;
    10. import lombok.SneakyThrows;
    11. import lombok.extern.slf4j.Slf4j;
    12. @Validated
    13. @RestController
    14. @RequestMapping("/api/sse")
    15. @Slf4j
    16. @RefreshScope // 会监听变化实时变化值
    17. public class SseController {
    18. @Resource
    19. private SseBizService sseBizService;
    20. /**
    21. * 创建用户连接并返回 SseEmitter
    22. *
    23. * @param conversationId 用户ID
    24. * @return SseEmitter
    25. */
    26. @SneakyThrows
    27. @GetMapping(value = "/connect", produces = "text/event-stream; charset=utf-8")
    28. public SseEmitter connect(String conversationId) {
    29. // 设置超时时间,0表示不过期。默认30秒,超过时间未完成会抛出异常:AsyncRequestTimeoutException
    30. SseEmitter sseEmitter = new SseEmitter(0L);
    31. // 注册回调
    32. sseEmitter.onCompletion(completionCallBack(conversationId));
    33. sseEmitter.onError(errorCallBack(conversationId));
    34. sseEmitter.onTimeout(timeoutCallBack(conversationId));
    35. log.info("创建新的sse连接,当前用户:{}", conversationId);
    36. sseBizService.addConnect(conversationId,sseEmitter);
    37. sseBizService.sendMsg(conversationId,"链接成功");
    38. // sseCache.get(conversationId).send(SseEmitter.event().reconnectTime(10000).data("链接成功"),MediaType.TEXT_EVENT_STREAM);
    39. return sseEmitter;
    40. }
    41. /**
    42. * 给指定用户发送信息 -- 单播
    43. */
    44. @GetMapping(value = "/send", produces = "text/event-stream; charset=utf-8")
    45. public void sendMessage(String conversationId, String msg) {
    46. sseBizService.sendMsg(conversationId,msg);
    47. }
    48. /**
    49. * 移除用户连接
    50. */
    51. @GetMapping(value = "/disconnection", produces = "text/event-stream; charset=utf-8")
    52. public void removeUser(String conversationId) {
    53. log.info("移除用户:{}", conversationId);
    54. sseBizService.deleteConnect(conversationId);
    55. }
    56. /**
    57. * 向多人发布消息 -- 组播
    58. * @param groupId 开头标识
    59. * @param message 消息内容
    60. */
    61. public void groupSendMessage(String groupId, String message) {
    62. /* if (!BaseUtil.isNullOrEmpty(sseCache)) {
    63. *//*Set ids = sseEmitterMap.keySet().stream().filter(m -> m.startsWith(groupId)).collect(Collectors.toSet());
    64. batchSendMessage(message, ids);*//*
    65. sseCache.forEach((k, v) -> {
    66. try {
    67. if (k.startsWith(groupId)) {
    68. v.send(message, MediaType.APPLICATION_JSON);
    69. }
    70. } catch (IOException e) {
    71. log.error("用户[{}]推送异常:{}", k, e.getMessage());
    72. removeUser(k);
    73. }
    74. });
    75. }*/
    76. }
    77. /**
    78. * 群发所有人 -- 广播
    79. */
    80. public void batchSendMessage(String message) {
    81. /*sseCache.forEach((k, v) -> {
    82. try {
    83. v.send(message, MediaType.APPLICATION_JSON);
    84. } catch (IOException e) {
    85. log.error("用户[{}]推送异常:{}", k, e.getMessage());
    86. removeUser(k);
    87. }
    88. });*/
    89. }
    90. /**
    91. * 群发消息
    92. */
    93. public void batchSendMessage(String message, Set ids) {
    94. ids.forEach(userId -> sendMessage(userId, message));
    95. }
    96. /**
    97. * 获取当前连接信息
    98. */
    99. // public List getIds() {
    100. // return new ArrayList<>(sseCache.keySet());
    101. // }
    102. /**
    103. * 获取当前连接数量
    104. */
    105. // public int getUserCount() {
    106. // return count.intValue();
    107. // }
    108. private Runnable completionCallBack(String userId) {
    109. return () -> {
    110. log.info("结束连接:{}", userId);
    111. removeUser(userId);
    112. };
    113. }
    114. private Runnable timeoutCallBack(String userId) {
    115. return () -> {
    116. log.info("连接超时:{}", userId);
    117. removeUser(userId);
    118. };
    119. }
    120. private Consumer errorCallBack(String userId) {
    121. return throwable -> {
    122. log.info("连接异常:{}", userId);
    123. removeUser(userId);
    124. };
    125. }
    126. }
    service
    1. import org.springframework.cloud.context.config.annotation.RefreshScope;
    2. import org.springframework.http.MediaType;
    3. import org.springframework.stereotype.Component;
    4. import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    5. import java.util.Map;
    6. import java.util.concurrent.ConcurrentHashMap;
    7. import java.util.concurrent.atomic.AtomicInteger;
    8. import lombok.SneakyThrows;
    9. import lombok.extern.slf4j.Slf4j;
    10. @Component
    11. @Slf4j
    12. @RefreshScope // 会监听变化实时变化值
    13. public class SseBizService {
    14. /**
    15. *
    16. * 当前连接数
    17. */
    18. private AtomicInteger count = new AtomicInteger(0);
    19. /**
    20. * 使用map对象,便于根据userId来获取对应的SseEmitter,或者放redis里面
    21. */
    22. private Map sseCache = new ConcurrentHashMap<>();
    23. /**
    24. * 添加用户
    25. * @author pengbin
    26. * @date 2023/9/11 11:37
    27. * @param
    28. * @return
    29. */
    30. public void addConnect(String id,SseEmitter sseEmitter){
    31. sseCache.put(id, sseEmitter);
    32. // 数量+1
    33. count.getAndIncrement();
    34. }
    35. /**
    36. * 删除用户
    37. * @author pengbin
    38. * @date 2023/9/11 11:37
    39. * @param
    40. * @return
    41. */
    42. public void deleteConnect(String id){
    43. sseCache.remove(id);
    44. // 数量+1
    45. count.getAndDecrement();
    46. }
    47. /**
    48. * 发送消息
    49. * @author pengbin
    50. * @date 2023/9/11 11:38
    51. * @param
    52. * @return
    53. */
    54. @SneakyThrows
    55. public void sendMsg(String id, String msg){
    56. if(sseCache.containsKey(id)){
    57. sseCache.get(id).send(msg, MediaType.TEXT_EVENT_STREAM);
    58. }
    59. }
    60. }

    方法二:每次提问时新建 SSE 连接(EventSource),回答完成后立即销毁

    前端:在接收到结束标识后立即销毁

    /**
     * Handles data pushed by the server (alternative: source.onmessage = function (event) {}).
     * Each chunk is JSON like {"content": "..."}; the stream ends with the literal "[DONE]".
     */
    source.addEventListener('message', function (e) {
        if (e.data === '[DONE]') {
            // Close immediately. Otherwise the browser reconnects when the server ends the response,
            // which would send the question again.
            source.close();
            return;
        }
        let text = e.data;
        try {
            const chunk = JSON.parse(e.data);
            if (chunk && typeof chunk.content === 'string') {
                text = chunk.content;
            }
        } catch (ignored) {
            // Not JSON: display the raw payload.
        }
        setMessageInnerHTML(text);
    });

    后端:
     

    1. @SneakyThrows
    2. @GetMapping(value = "/stream/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    3. public SseEmitter completionsStream(@RequestParam String conversationId){
    4. //
    5. List messagesBOList =new ArrayList();
    6. // 获取内容信息
    7. ChatParamBO build = ChatParamBO.builder()
    8. .temperature(0.7)
    9. .stream(true)
    10. .model("xxxx")
    11. .messages(messagesBOList)
    12. .build();
    13. SseEmitter emitter = new SseEmitter();
    14. // 定义see接口
    15. Request request = new Request.Builder().url("xxx")
    16. .header("Authorization","xxxx")
    17. .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),JsonUtils.toJson(build)))
    18. .build();
    19. OkHttpClient okHttpClient = new OkHttpClient.Builder()
    20. .connectTimeout(10, TimeUnit.MINUTES)
    21. .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
    22. .build();
    23. StringBuffer sb = new StringBuffer("");
    24. // 实例化EventSource,注册EventSource监听器
    25. RealEventSource realEventSource = null;
    26. realEventSource = new RealEventSource(request, new EventSourceListener() {
    27. @Override
    28. public void onOpen(EventSource eventSource, Response response) {
    29. log.info("onOpen");
    30. }
    31. @SneakyThrows
    32. @Override
    33. public void onEvent(EventSource eventSource, String id, String type, String data) {
    34. log.info(data);//请求到的数据
    35. try {
    36. ChatResultBO chatResultBO = JsonUtils.toObject(data.replace("data:", ""), ChatResultBO.class);
    37. String content = chatResultBO.getChoices().get(0).getDelta().getContent();
    38. sb.append(content);
    39. emitter.send(SseEmitter.event().data(JsonUtils.toJson(ChatContentBO.builder().content(content).build())));
    40. } catch (Exception e) {
    41. // e.printStackTrace();
    42. }
    43. if("[DONE]".equals(data)){
    44. emitter.send(SseEmitter.event().data(data));
    45. emitter.complete();
    46. log.info("result={}",sb);
    47. }
    48. }
    49. @Override
    50. public void onClosed(EventSource eventSource) {
    51. log.info("onClosed,eventSource={}",eventSource);//这边可以监听并重新打开
    52. // emitter.complete();
    53. }
    54. @Override
    55. public void onFailure(EventSource eventSource, Throwable t, Response response) {
    56. log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
    57. // emitter.complete();
    58. }
    59. });
    60. realEventSource.connect(okHttpClient);//真正开始请求的一步
    61. return emitter;
    62. }

    3)踩坑

    nginx 配置:

    在反向代理到后端的 location 中需要添加:

     # Needed for GPT streaming: pass each chunk to the client immediately instead of buffering the response
      proxy_buffering off;

    location / {
        proxy_pass http://backend;
        proxy_redirect default;
        proxy_connect_timeout 90;
        # Maximum gap between two reads from the backend. A streaming answer resets it with every chunk,
        # but an idle SSE subscription (e.g. an emitter that never expires) is cut after 90 s unless the
        # server sends periodic heartbeats.
        proxy_read_timeout 90;
        proxy_send_timeout 90;
        # Needed for GPT/SSE streaming: forward each chunk immediately instead of buffering the response.
        proxy_buffering off;
        # Talk HTTP/1.1 to the backend (nginx defaults to 1.0) and clear Connection: the usual setup for
        # long-lived streaming responses. WebSocket paths additionally need Upgrade/Connection "upgrade"
        # headers in their own location.
        proxy_http_version 1.1;
        proxy_set_header Connection "";
        #root html;
        #root /opt/project/;
        index index.html index.htm;
        client_max_body_size 1024m;
        # Forward the real host and client IP to the backend
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    }

  • 相关阅读:
    动漫主题dreamweaver作业静态HTML网页设计——仿京东(海贼王)版本
    NPDP怎么报名?考试难度大吗?
    网安学习-应急响应3
    千年TGS服务器日志报错如何解决
    【100个 Unity实用技能】☀️ | Unity 将秒数转化为00:00:00时间格式
    4种方法教你如何查看java对象所占内存大小
    【机器学习】面试题:LSTM长短期记忆网络的理解?LSTM是怎么解决梯度消失的问题的?还有哪些其它的解决梯度消失或梯度爆炸的方法?
    redis的持久化
    【如何】求 矩阵 的 若尔当标准形J,求 初等因子
    桌面运维命令
  • 原文地址:https://blog.csdn.net/pengbin790000/article/details/133684395