use anyhow::{anyhow, Result};
use futures::{
    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
    Stream, StreamExt,
};
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
    env,
    fmt::{self, Display},
    io,
    sync::Arc,
};
use util::ResultExt;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    models::LanguageModel,
};

use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};

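/// The role attached to a chat message, mirroring the `role` field of the
/// OpenAI chat completions API (serialized in lowercase).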
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
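    /// Advance to the next role in the fixed cycle User -> Assistant -> System.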
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "User"),
            Role::Assistant => write!(f, "Assistant"),
            Role::System => write!(f, "System"),
        }
    }
}

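/// A single message included in a chat completion request.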
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}

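/// The request body serialized and sent to the `/chat/completions` endpoint.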
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl CompletionRequest for OpenAIRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}

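/// A (possibly partial) message delta returned by the streaming API. Both fields
/// are optional because a single chunk may carry only a role or only a fragment
/// of content.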
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}

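/// One server-sent event parsed from the streaming chat completion response.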
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAIUsage>,
}

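/// Send `request` to the OpenAI chat completions endpoint and stream back the
/// parsed events.
///
/// On a `200 OK` response, a background task reads the body line by line,
/// parses every `data: ` payload, and forwards the resulting events over an
/// unbounded channel until a choice reports a `finish_reason` or the receiver
/// is dropped. Any other status is turned into an error that includes the
/// API's error message when one can be extracted from the body.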
pub async fn stream_completion(
    credential: ProviderCredential,
    executor: BackgroundExecutor,
    request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    let api_key = match credential {
        ProviderCredential::Credentials { api_key } => api_key,
        _ => {
            return Err(anyhow!("no credentials provided for completion"));
        }
    };

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();

    let json_data = request.data()?;
    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }

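                // Forward each parsed event to the receiver, stopping once a
                // choice reports a finish_reason or the receiver has been dropped.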
                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
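        // On any non-OK status, read the whole body and surface the API's error
        // message if the payload contains one.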
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAIResponse {
            error: OpenAIError,
        }

        #[derive(Deserialize)]
        struct OpenAIError {
            message: String,
        }

        match serde_json::from_str::<OpenAIResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

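/// A completion provider backed by the OpenAI chat completions API; it also
/// manages the OpenAI API key credential used for requests.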
#[derive(Clone)]
pub struct OpenAICompletionProvider {
    model: OpenAILanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    executor: BackgroundExecutor,
}

impl OpenAICompletionProvider {
    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
        let model = executor
            .spawn(async move { OpenAILanguageModel::load(&model_name) })
            .await;
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
        Self {
            model,
            credential,
            executor,
        }
    }
}

impl CredentialProvider for OpenAICompletionProvider {
    fn has_credentials(&self) -> bool {
        matches!(*self.credential.read(), ProviderCredential::Credentials { .. })
    }

    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        let existing_credential = self.credential.read().clone();
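        // Prefer a credential already held in memory, then the OPENAI_API_KEY
        // environment variable, and finally the credential store exposed by the
        // `AppContext`.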
        let retrieved_credential = match existing_credential {
            ProviderCredential::Credentials { .. } => {
                return async move { existing_credential }.boxed()
            }
            _ => {
                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                    async move { ProviderCredential::Credentials { api_key } }.boxed()
                } else {
                    let credentials = cx.read_credentials(OPENAI_API_URL);
                    async move {
                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
                                ProviderCredential::Credentials { api_key }
                            } else {
                                ProviderCredential::NoCredentials
                            }
                        } else {
                            ProviderCredential::NoCredentials
                        }
                    }
                    .boxed()
                }
            }
        };

        async move {
            let retrieved_credential = retrieved_credential.await;
            *self.credential.write() = retrieved_credential.clone();
            retrieved_credential
        }
        .boxed()
    }

    fn save_credentials(
        &self,
        cx: &mut AppContext,
        credential: ProviderCredential,
    ) -> BoxFuture<()> {
        *self.credential.write() = credential.clone();
        let write_credentials = match credential {
            ProviderCredential::Credentials { api_key } => {
                Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
            }
            _ => None,
        };

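        // Persist the key via the `AppContext` only when an actual API key was
        // supplied.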
        async move {
            if let Some(write_credentials) = write_credentials {
                write_credentials.await.log_err();
            }
        }
        .boxed()
    }

    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
        *self.credential.write() = ProviderCredential::NoCredentials;
        let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
        async move {
            delete_credentials.await.log_err();
        }
        .boxed()
    }
}

impl CompletionProvider for OpenAICompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // The OpenAI `CompletionRequest` currently carries its own `model` field,
        // so the model used for the completion is determined by the request rather
        // than by this provider, even though the provider is also constructed
        // around a specific language model. This mismatch should be rectified at
        // some point.
        let credential = self.credential.read().clone();
        let request = stream_completion(credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
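            // Reduce each event to the text of its last choice's delta; events
            // without any content are dropped, while errors are passed through.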
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}