1use anyhow::{anyhow, Result};
2use futures::{
3 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
4 Stream, StreamExt,
5};
6use gpui::executor::Background;
7use isahc::{http::StatusCode, Request, RequestExt};
8use serde::{Deserialize, Serialize};
9use std::{
10 fmt::{self, Display},
11 io,
12 sync::Arc,
13};
14
/// Base URL of the OpenAI REST API; endpoint paths such as `/chat/completions` are appended to it.
pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
16
17#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20 User,
21 Assistant,
22 System,
23}
24
25impl Role {
26 pub fn cycle(&mut self) {
27 *self = match self {
28 Role::User => Role::Assistant,
29 Role::Assistant => Role::System,
30 Role::System => Role::User,
31 }
32 }
33}
34
35impl Display for Role {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Role::User => write!(f, "User"),
39 Role::Assistant => write!(f, "Assistant"),
40 Role::System => write!(f, "System"),
41 }
42 }
43}
44
45#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
46pub struct RequestMessage {
47 pub role: Role,
48 pub content: String,
49}
50
51#[derive(Debug, Default, Serialize)]
52pub struct OpenAIRequest {
53 pub model: String,
54 pub messages: Vec<RequestMessage>,
55 pub stream: bool,
56 pub stop: Vec<String>,
57 pub temperature: f32,
58}
59
60#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
61pub struct ResponseMessage {
62 pub role: Option<Role>,
63 pub content: Option<String>,
64}
65
66#[derive(Deserialize, Debug)]
67pub struct OpenAIUsage {
68 pub prompt_tokens: u32,
69 pub completion_tokens: u32,
70 pub total_tokens: u32,
71}
72
73#[derive(Deserialize, Debug)]
74pub struct ChatChoiceDelta {
75 pub index: u32,
76 pub delta: ResponseMessage,
77 pub finish_reason: Option<String>,
78}
79
80#[derive(Deserialize, Debug)]
81pub struct OpenAIResponseStreamEvent {
82 pub id: Option<String>,
83 pub object: String,
84 pub created: u32,
85 pub model: String,
86 pub choices: Vec<ChatChoiceDelta>,
87 pub usage: Option<OpenAIUsage>,
88}
89
90pub async fn stream_completion(
91 api_key: String,
92 executor: Arc<Background>,
93 mut request: OpenAIRequest,
94) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
95 request.stream = true;
96
97 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
98
99 let json_data = serde_json::to_string(&request)?;
100 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
101 .header("Content-Type", "application/json")
102 .header("Authorization", format!("Bearer {}", api_key))
103 .body(json_data)?
104 .send_async()
105 .await?;
106
107 let status = response.status();
108 if status == StatusCode::OK {
109 executor
110 .spawn(async move {
111 let mut lines = BufReader::new(response.body_mut()).lines();
112
113 fn parse_line(
114 line: Result<String, io::Error>,
115 ) -> Result<Option<OpenAIResponseStreamEvent>> {
116 if let Some(data) = line?.strip_prefix("data: ") {
117 let event = serde_json::from_str(&data)?;
118 Ok(Some(event))
119 } else {
120 Ok(None)
121 }
122 }
123
124 while let Some(line) = lines.next().await {
125 if let Some(event) = parse_line(line).transpose() {
126 let done = event.as_ref().map_or(false, |event| {
127 event
128 .choices
129 .last()
130 .map_or(false, |choice| choice.finish_reason.is_some())
131 });
132 if tx.unbounded_send(event).is_err() {
133 break;
134 }
135
136 if done {
137 break;
138 }
139 }
140 }
141
142 anyhow::Ok(())
143 })
144 .detach();
145
146 Ok(rx)
147 } else {
148 let mut body = String::new();
149 response.body_mut().read_to_string(&mut body).await?;
150
151 #[derive(Deserialize)]
152 struct OpenAIResponse {
153 error: OpenAIError,
154 }
155
156 #[derive(Deserialize)]
157 struct OpenAIError {
158 message: String,
159 }
160
161 match serde_json::from_str::<OpenAIResponse>(&body) {
162 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
163 "Failed to connect to OpenAI API: {}",
164 response.error.message,
165 )),
166
167 _ => Err(anyhow!(
168 "Failed to connect to OpenAI API: {} {}",
169 response.status(),
170 body,
171 )),
172 }
173 }
174}
175
176pub trait CompletionProvider {
177 fn complete(
178 &self,
179 prompt: OpenAIRequest,
180 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
181}
182
183pub struct OpenAICompletionProvider {
184 api_key: String,
185 executor: Arc<Background>,
186}
187
188impl OpenAICompletionProvider {
189 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
190 Self { api_key, executor }
191 }
192}
193
194impl CompletionProvider for OpenAICompletionProvider {
195 fn complete(
196 &self,
197 prompt: OpenAIRequest,
198 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
199 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
200 async move {
201 let response = request.await?;
202 let stream = response
203 .filter_map(|response| async move {
204 match response {
205 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
206 Err(error) => Some(Err(error)),
207 }
208 })
209 .boxed();
210 Ok(stream)
211 }
212 .boxed()
213 }
214}