1use anyhow::{anyhow, Result};
2use futures::{
3 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
4 Stream, StreamExt,
5};
6use gpui::executor::Background;
7use isahc::{http::StatusCode, Request, RequestExt};
8use serde::{Deserialize, Serialize};
9use std::{
10 fmt::{self, Display},
11 io,
12 sync::Arc,
13};
14
15use crate::completion::{CompletionProvider, CompletionRequest};
16
/// Base URL of the OpenAI REST API (no trailing slash); endpoint paths such as
/// `/chat/completions` are appended to it.
pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
18
19#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
20#[serde(rename_all = "lowercase")]
21pub enum Role {
22 User,
23 Assistant,
24 System,
25}
26
27impl Role {
28 pub fn cycle(&mut self) {
29 *self = match self {
30 Role::User => Role::Assistant,
31 Role::Assistant => Role::System,
32 Role::System => Role::User,
33 }
34 }
35}
36
37impl Display for Role {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 Role::User => write!(f, "User"),
41 Role::Assistant => write!(f, "Assistant"),
42 Role::System => write!(f, "System"),
43 }
44 }
45}
46
47#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
48pub struct RequestMessage {
49 pub role: Role,
50 pub content: String,
51}
52
53#[derive(Debug, Default, Serialize)]
54pub struct OpenAIRequest {
55 pub model: String,
56 pub messages: Vec<RequestMessage>,
57 pub stream: bool,
58 pub stop: Vec<String>,
59 pub temperature: f32,
60}
61
62impl CompletionRequest for OpenAIRequest {
63 fn data(&self) -> serde_json::Result<String> {
64 serde_json::to_string(self)
65 }
66}
67
68#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
69pub struct ResponseMessage {
70 pub role: Option<Role>,
71 pub content: Option<String>,
72}
73
74#[derive(Deserialize, Debug)]
75pub struct OpenAIUsage {
76 pub prompt_tokens: u32,
77 pub completion_tokens: u32,
78 pub total_tokens: u32,
79}
80
81#[derive(Deserialize, Debug)]
82pub struct ChatChoiceDelta {
83 pub index: u32,
84 pub delta: ResponseMessage,
85 pub finish_reason: Option<String>,
86}
87
88#[derive(Deserialize, Debug)]
89pub struct OpenAIResponseStreamEvent {
90 pub id: Option<String>,
91 pub object: String,
92 pub created: u32,
93 pub model: String,
94 pub choices: Vec<ChatChoiceDelta>,
95 pub usage: Option<OpenAIUsage>,
96}
97
98pub async fn stream_completion(
99 api_key: String,
100 executor: Arc<Background>,
101 request: Box<dyn CompletionRequest>,
102) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
103 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
104
105 let json_data = request.data()?;
106 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
107 .header("Content-Type", "application/json")
108 .header("Authorization", format!("Bearer {}", api_key))
109 .body(json_data)?
110 .send_async()
111 .await?;
112
113 let status = response.status();
114 if status == StatusCode::OK {
115 executor
116 .spawn(async move {
117 let mut lines = BufReader::new(response.body_mut()).lines();
118
119 fn parse_line(
120 line: Result<String, io::Error>,
121 ) -> Result<Option<OpenAIResponseStreamEvent>> {
122 if let Some(data) = line?.strip_prefix("data: ") {
123 let event = serde_json::from_str(&data)?;
124 Ok(Some(event))
125 } else {
126 Ok(None)
127 }
128 }
129
130 while let Some(line) = lines.next().await {
131 if let Some(event) = parse_line(line).transpose() {
132 let done = event.as_ref().map_or(false, |event| {
133 event
134 .choices
135 .last()
136 .map_or(false, |choice| choice.finish_reason.is_some())
137 });
138 if tx.unbounded_send(event).is_err() {
139 break;
140 }
141
142 if done {
143 break;
144 }
145 }
146 }
147
148 anyhow::Ok(())
149 })
150 .detach();
151
152 Ok(rx)
153 } else {
154 let mut body = String::new();
155 response.body_mut().read_to_string(&mut body).await?;
156
157 #[derive(Deserialize)]
158 struct OpenAIResponse {
159 error: OpenAIError,
160 }
161
162 #[derive(Deserialize)]
163 struct OpenAIError {
164 message: String,
165 }
166
167 match serde_json::from_str::<OpenAIResponse>(&body) {
168 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
169 "Failed to connect to OpenAI API: {}",
170 response.error.message,
171 )),
172
173 _ => Err(anyhow!(
174 "Failed to connect to OpenAI API: {} {}",
175 response.status(),
176 body,
177 )),
178 }
179 }
180}
181
182pub struct OpenAICompletionProvider {
183 api_key: String,
184 executor: Arc<Background>,
185}
186
187impl OpenAICompletionProvider {
188 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
189 Self { api_key, executor }
190 }
191}
192
193impl CompletionProvider for OpenAICompletionProvider {
194 fn complete(
195 &self,
196 prompt: Box<dyn CompletionRequest>,
197 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
198 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
199 async move {
200 let response = request.await?;
201 let stream = response
202 .filter_map(|response| async move {
203 match response {
204 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
205 Err(error) => Some(Err(error)),
206 }
207 })
208 .boxed();
209 Ok(stream)
210 }
211 .boxed()
212 }
213}