completion.rs

use anyhow::{anyhow, Result};
use futures::{
    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
    Stream, StreamExt,
};
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
    env,
    fmt::{self, Display},
    io,
    sync::Arc,
};
use util::ResultExt;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    models::LanguageModel,
};

use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "User"),
            Role::Assistant => write!(f, "Assistant"),
            Role::System => write!(f, "System"),
        }
    }
}

/// A single chat message sent as part of a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}

/// Request body for the `/chat/completions` endpoint.
#[derive(Debug, Default, Serialize)]
pub struct OpenAiRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl CompletionRequest for OpenAiRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}

/// A message delta returned by the streaming API; either field may be absent
/// in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

/// Token usage reported by the API.
#[derive(Deserialize, Debug)]
pub struct OpenAiUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One streamed choice, carrying an incremental message delta.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}

/// A single event from the chat completions server-sent event stream.
#[derive(Deserialize, Debug)]
pub struct OpenAiResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAiUsage>,
}

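/// Sends a streaming request to the `{api_url}/chat/completions` endpoint and
/// returns a stream of parsed server-sent events. The response body is read on
/// the background executor, and reading stops once a choice reports a
/// `finish_reason` or the returned stream is dropped.
///
/// A minimal usage sketch (illustrative only; the model name, `executor`, and
/// `api_key` are assumed to come from the surrounding application):
///
/// ```ignore
/// let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
///     model: "gpt-4".into(),
///     messages: vec![RequestMessage {
///         role: Role::User,
///         content: "Hello!".into(),
///     }],
///     stream: true,
///     stop: Vec::new(),
///     temperature: 1.0,
/// });
/// let credential = ProviderCredential::Credentials { api_key };
/// let mut events =
///     stream_completion(OPEN_AI_API_URL.to_string(), credential, executor, request).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(content) = choice.delta.content {
///             print!("{content}");
///         }
///     }
/// }
/// ```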
pub async fn stream_completion(
    api_url: String,
    credential: ProviderCredential,
    executor: BackgroundExecutor,
    request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
    let api_key = match credential {
        ProviderCredential::Credentials { api_key } => api_key,
        _ => {
            return Err(anyhow!("no credentials provided for completion"));
        }
    };

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();

    let json_data = request.data()?;
    let mut response = Request::post(format!("{api_url}/chat/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

                // Each server-sent event arrives as a line of the form
                // `data: {...json...}`; lines without that prefix are ignored.
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAiResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }

                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        // Stop reading once the API reports a finish reason or
                        // the receiving end of the channel has been dropped.
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

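/// A `CompletionProvider` backed by the OpenAI chat completions API. The
/// credential is shared behind an `Arc<RwLock<_>>`, so clones of the provider
/// observe credential updates made through the `CredentialProvider` methods
/// below.
///
/// A rough usage sketch (illustrative only; it assumes a gpui
/// `BackgroundExecutor`, a boxed request like the one above, and that
/// credentials have already been retrieved or saved):
///
/// ```ignore
/// let provider =
///     OpenAiCompletionProvider::new(OPEN_AI_API_URL.to_string(), "gpt-4".to_string(), executor)
///         .await;
/// let mut chunks = provider.complete(request).await?;
/// while let Some(chunk) = chunks.next().await {
///     print!("{}", chunk?);
/// }
/// ```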
#[derive(Clone)]
pub struct OpenAiCompletionProvider {
    api_url: String,
    model: OpenAiLanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    executor: BackgroundExecutor,
}

impl OpenAiCompletionProvider {
    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
        let model = executor
            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
            .await;
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
        Self {
            api_url,
            model,
            credential,
            executor,
        }
    }
}

impl CredentialProvider for OpenAiCompletionProvider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read() {
            ProviderCredential::Credentials { .. } => true,
            _ => false,
        }
    }

    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        // Credentials are resolved in order: anything already cached in
        // memory, then the OPENAI_API_KEY environment variable, then the
        // system credential store.
        let existing_credential = self.credential.read().clone();
        let retrieved_credential = match existing_credential {
            ProviderCredential::Credentials { .. } => {
                return async move { existing_credential }.boxed()
            }
            _ => {
                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                    async move { ProviderCredential::Credentials { api_key } }.boxed()
                } else {
                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
                    async move {
                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
                                ProviderCredential::Credentials { api_key }
                            } else {
                                ProviderCredential::NoCredentials
                            }
                        } else {
                            ProviderCredential::NoCredentials
                        }
                    }
                    .boxed()
                }
            }
        };

        async move {
            let retrieved_credential = retrieved_credential.await;
            *self.credential.write() = retrieved_credential.clone();
            retrieved_credential
        }
        .boxed()
    }

    fn save_credentials(
        &self,
        cx: &mut AppContext,
        credential: ProviderCredential,
    ) -> BoxFuture<()> {
        *self.credential.write() = credential.clone();
        let write_credentials = match credential {
            ProviderCredential::Credentials { api_key } => {
                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
            }
            _ => None,
        };

        async move {
            if let Some(write_credentials) = write_credentials {
                write_credentials.await.log_err();
            }
        }
        .boxed()
    }

    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
        *self.credential.write() = ProviderCredential::NoCredentials;
        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
        async move {
            delete_credentials.await.log_err();
        }
        .boxed()
    }
}

impl CompletionProvider for OpenAiCompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // The OpenAI `CompletionRequest` currently carries its own `model`
        // parameter, so the model used for a completion is determined by the
        // request rather than by this provider, even though the provider is
        // itself built around a specific language model. At some point we
        // should rectify this.
        let credential = self.credential.read().clone();
        let api_url = self.api_url.clone();
        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // Surface only the incremental content of the last
                        // choice; events without content are filtered out.
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}