1use anyhow::{anyhow, Result};
2use futures::{
3 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
4 Stream, StreamExt,
5};
6use gpui::executor::Background;
7use isahc::{http::StatusCode, Request, RequestExt};
8use serde::{Deserialize, Serialize};
9use std::{
10 fmt::{self, Display},
11 io,
12 sync::Arc,
13};
14
/// Base URL of the OpenAI REST API; endpoint paths such as `/chat/completions` are appended to it.
pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
16
17#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20 User,
21 Assistant,
22 System,
23}
24
25impl Role {
26 pub fn cycle(&mut self) {
27 *self = match self {
28 Role::User => Role::Assistant,
29 Role::Assistant => Role::System,
30 Role::System => Role::User,
31 }
32 }
33}
34
35impl Display for Role {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Role::User => write!(f, "User"),
39 Role::Assistant => write!(f, "Assistant"),
40 Role::System => write!(f, "System"),
41 }
42 }
43}
44
45#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
46pub struct RequestMessage {
47 pub role: Role,
48 pub content: String,
49}
50
51#[derive(Debug, Default, Serialize)]
52pub struct OpenAIRequest {
53 pub model: String,
54 pub messages: Vec<RequestMessage>,
55 pub stream: bool,
56}
57
58#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
59pub struct ResponseMessage {
60 pub role: Option<Role>,
61 pub content: Option<String>,
62}
63
64#[derive(Deserialize, Debug)]
65pub struct OpenAIUsage {
66 pub prompt_tokens: u32,
67 pub completion_tokens: u32,
68 pub total_tokens: u32,
69}
70
71#[derive(Deserialize, Debug)]
72pub struct ChatChoiceDelta {
73 pub index: u32,
74 pub delta: ResponseMessage,
75 pub finish_reason: Option<String>,
76}
77
78#[derive(Deserialize, Debug)]
79pub struct OpenAIResponseStreamEvent {
80 pub id: Option<String>,
81 pub object: String,
82 pub created: u32,
83 pub model: String,
84 pub choices: Vec<ChatChoiceDelta>,
85 pub usage: Option<OpenAIUsage>,
86}
87
88pub async fn stream_completion(
89 api_key: String,
90 executor: Arc<Background>,
91 mut request: OpenAIRequest,
92) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
93 request.stream = true;
94
95 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
96
97 let json_data = serde_json::to_string(&request)?;
98 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
99 .header("Content-Type", "application/json")
100 .header("Authorization", format!("Bearer {}", api_key))
101 .body(json_data)?
102 .send_async()
103 .await?;
104
105 let status = response.status();
106 if status == StatusCode::OK {
107 executor
108 .spawn(async move {
109 let mut lines = BufReader::new(response.body_mut()).lines();
110
111 fn parse_line(
112 line: Result<String, io::Error>,
113 ) -> Result<Option<OpenAIResponseStreamEvent>> {
114 if let Some(data) = line?.strip_prefix("data: ") {
115 let event = serde_json::from_str(&data)?;
116 Ok(Some(event))
117 } else {
118 Ok(None)
119 }
120 }
121
122 while let Some(line) = lines.next().await {
123 if let Some(event) = parse_line(line).transpose() {
124 let done = event.as_ref().map_or(false, |event| {
125 event
126 .choices
127 .last()
128 .map_or(false, |choice| choice.finish_reason.is_some())
129 });
130 if tx.unbounded_send(event).is_err() {
131 break;
132 }
133
134 if done {
135 break;
136 }
137 }
138 }
139
140 anyhow::Ok(())
141 })
142 .detach();
143
144 Ok(rx)
145 } else {
146 let mut body = String::new();
147 response.body_mut().read_to_string(&mut body).await?;
148
149 #[derive(Deserialize)]
150 struct OpenAIResponse {
151 error: OpenAIError,
152 }
153
154 #[derive(Deserialize)]
155 struct OpenAIError {
156 message: String,
157 }
158
159 match serde_json::from_str::<OpenAIResponse>(&body) {
160 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
161 "Failed to connect to OpenAI API: {}",
162 response.error.message,
163 )),
164
165 _ => Err(anyhow!(
166 "Failed to connect to OpenAI API: {} {}",
167 response.status(),
168 body,
169 )),
170 }
171 }
172}
173
/// A source of streamed text completions for an [`OpenAIRequest`].
pub trait CompletionProvider {
    /// Starts a completion for `prompt`.
    ///
    /// The returned future is `'static`, so it cannot hold a borrow of `self`. The outer
    /// `Result` reports failure to obtain the stream. Each stream item is either a `String`
    /// chunk or an error.
    fn complete(
        &self,
        prompt: OpenAIRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
180
/// [`CompletionProvider`] backed by the OpenAI chat completions API through `stream_completion`.
pub struct OpenAICompletionProvider {
    // Sent as the bearer token on every request.
    api_key: String,
    // Executor that runs the task reading each streamed response.
    executor: Arc<Background>,
}
185
186impl OpenAICompletionProvider {
187 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
188 Self { api_key, executor }
189 }
190}
191
192impl CompletionProvider for OpenAICompletionProvider {
193 fn complete(
194 &self,
195 prompt: OpenAIRequest,
196 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
197 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
198 async move {
199 let response = request.await?;
200 let stream = response
201 .filter_map(|response| async move {
202 match response {
203 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
204 Err(error) => Some(Err(error)),
205 }
206 })
207 .boxed();
208 Ok(stream)
209 }
210 .boxed()
211 }
212}