completion.rs

  1use anyhow::{anyhow, Result};
  2use futures::{
  3    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  4    Stream, StreamExt,
  5};
  6use gpui::{AppContext, BackgroundExecutor};
  7use isahc::{http::StatusCode, Request, RequestExt};
  8use parking_lot::RwLock;
  9use serde::{Deserialize, Serialize};
 10use std::{
 11    env,
 12    fmt::{self, Display},
 13    io,
 14    sync::Arc,
 15};
 16use util::ResultExt;
 17
 18use crate::{
 19    auth::{CredentialProvider, ProviderCredential},
 20    completion::{CompletionProvider, CompletionRequest},
 21    models::LanguageModel,
 22};
 23
 24use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
 25
/// The author of a chat message.
///
/// Serialized in lowercase ("user" / "assistant" / "system") to match the
/// OpenAI chat API's `role` field.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 33
 34impl Role {
 35    pub fn cycle(&mut self) {
 36        *self = match self {
 37            Role::User => Role::Assistant,
 38            Role::Assistant => Role::System,
 39            Role::System => Role::User,
 40        }
 41    }
 42}
 43
 44impl Display for Role {
 45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 46        match self {
 47            Role::User => write!(f, "User"),
 48            Role::Assistant => write!(f, "Assistant"),
 49            Role::System => write!(f, "System"),
 50        }
 51    }
 52}
 53
/// A single chat message included in a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
 59
/// Request body serialized as JSON for the `/chat/completions` endpoint.
#[derive(Debug, Default, Serialize)]
pub struct OpenAiRequest {
    /// Name of the model to query; forwarded verbatim to the API.
    pub model: String,
    /// Conversation history, oldest first.
    pub messages: Vec<RequestMessage>,
    /// Whether the server should stream the response as server-sent events.
    pub stream: bool,
    /// Sequences at which the API should stop generating.
    pub stop: Vec<String>,
    /// Sampling temperature forwarded to the API.
    pub temperature: f32,
}
 68
impl CompletionRequest for OpenAiRequest {
    /// Serializes this request into the JSON string sent as the HTTP body.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
 74
/// A (possibly partial) message received from the API.
///
/// Used as the streaming `delta` payload, where either field may be absent
/// on any given chunk — hence both are `Option`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
 80
/// Token accounting reported by the API for a completion.
#[derive(Deserialize, Debug)]
pub struct OpenAiUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
 87
/// One choice within a streamed completion chunk.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    /// Position of this choice among the requested completions.
    pub index: u32,
    /// The incremental content for this chunk.
    pub delta: ResponseMessage,
    /// When present, this choice is finished (used by `stream_completion`
    /// to stop reading the body).
    pub finish_reason: Option<String>,
}
 94
/// A single decoded server-sent event from a streaming completion response.
#[derive(Deserialize, Debug)]
pub struct OpenAiResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    /// The choice deltas carried by this chunk.
    pub choices: Vec<ChatChoiceDelta>,
    /// Token usage; only present on some events.
    pub usage: Option<OpenAiUsage>,
}
104
105pub async fn stream_completion(
106    credential: ProviderCredential,
107    executor: BackgroundExecutor,
108    request: Box<dyn CompletionRequest>,
109) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
110    let api_key = match credential {
111        ProviderCredential::Credentials { api_key } => api_key,
112        _ => {
113            return Err(anyhow!("no credentials provider for completion"));
114        }
115    };
116
117    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
118
119    let json_data = request.data()?;
120    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
121        .header("Content-Type", "application/json")
122        .header("Authorization", format!("Bearer {}", api_key))
123        .body(json_data)?
124        .send_async()
125        .await?;
126
127    let status = response.status();
128    if status == StatusCode::OK {
129        executor
130            .spawn(async move {
131                let mut lines = BufReader::new(response.body_mut()).lines();
132
133                fn parse_line(
134                    line: Result<String, io::Error>,
135                ) -> Result<Option<OpenAiResponseStreamEvent>> {
136                    if let Some(data) = line?.strip_prefix("data: ") {
137                        let event = serde_json::from_str(data)?;
138                        Ok(Some(event))
139                    } else {
140                        Ok(None)
141                    }
142                }
143
144                while let Some(line) = lines.next().await {
145                    if let Some(event) = parse_line(line).transpose() {
146                        let done = event.as_ref().map_or(false, |event| {
147                            event
148                                .choices
149                                .last()
150                                .map_or(false, |choice| choice.finish_reason.is_some())
151                        });
152                        if tx.unbounded_send(event).is_err() {
153                            break;
154                        }
155
156                        if done {
157                            break;
158                        }
159                    }
160                }
161
162                anyhow::Ok(())
163            })
164            .detach();
165
166        Ok(rx)
167    } else {
168        let mut body = String::new();
169        response.body_mut().read_to_string(&mut body).await?;
170
171        #[derive(Deserialize)]
172        struct OpenAiResponse {
173            error: OpenAiError,
174        }
175
176        #[derive(Deserialize)]
177        struct OpenAiError {
178            message: String,
179        }
180
181        match serde_json::from_str::<OpenAiResponse>(&body) {
182            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
183                "Failed to connect to OpenAI API: {}",
184                response.error.message,
185            )),
186
187            _ => Err(anyhow!(
188                "Failed to connect to OpenAI API: {} {}",
189                response.status(),
190                body,
191            )),
192        }
193    }
194}
195
/// A `CompletionProvider` backed by the OpenAI chat completions API.
#[derive(Clone)]
pub struct OpenAiCompletionProvider {
    /// The language model this provider reports via `base_model`.
    model: OpenAiLanguageModel,
    /// Cached API credential; shared behind a lock so clones observe updates.
    credential: Arc<RwLock<ProviderCredential>>,
    /// Executor used to spawn background work (model loading, body streaming).
    executor: BackgroundExecutor,
}
202
203impl OpenAiCompletionProvider {
204    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
205        let model = executor
206            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
207            .await;
208        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
209        Self {
210            model,
211            credential,
212            executor,
213        }
214    }
215}
216
217impl CredentialProvider for OpenAiCompletionProvider {
218    fn has_credentials(&self) -> bool {
219        match *self.credential.read() {
220            ProviderCredential::Credentials { .. } => true,
221            _ => false,
222        }
223    }
224
225    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
226        let existing_credential = self.credential.read().clone();
227        let retrieved_credential = match existing_credential {
228            ProviderCredential::Credentials { .. } => {
229                return async move { existing_credential }.boxed()
230            }
231            _ => {
232                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
233                    async move { ProviderCredential::Credentials { api_key } }.boxed()
234                } else {
235                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
236                    async move {
237                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
238                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
239                                ProviderCredential::Credentials { api_key }
240                            } else {
241                                ProviderCredential::NoCredentials
242                            }
243                        } else {
244                            ProviderCredential::NoCredentials
245                        }
246                    }
247                    .boxed()
248                }
249            }
250        };
251
252        async move {
253            let retrieved_credential = retrieved_credential.await;
254            *self.credential.write() = retrieved_credential.clone();
255            retrieved_credential
256        }
257        .boxed()
258    }
259
260    fn save_credentials(
261        &self,
262        cx: &mut AppContext,
263        credential: ProviderCredential,
264    ) -> BoxFuture<()> {
265        *self.credential.write() = credential.clone();
266        let credential = credential.clone();
267        let write_credentials = match credential {
268            ProviderCredential::Credentials { api_key } => {
269                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
270            }
271            _ => None,
272        };
273
274        async move {
275            if let Some(write_credentials) = write_credentials {
276                write_credentials.await.log_err();
277            }
278        }
279        .boxed()
280    }
281
282    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
283        *self.credential.write() = ProviderCredential::NoCredentials;
284        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
285        async move {
286            delete_credentials.await.log_err();
287        }
288        .boxed()
289    }
290}
291
292impl CompletionProvider for OpenAiCompletionProvider {
293    fn base_model(&self) -> Box<dyn LanguageModel> {
294        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
295        model
296    }
297    fn complete(
298        &self,
299        prompt: Box<dyn CompletionRequest>,
300    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
301        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
302        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
303        // which is currently model based, due to the language model.
304        // At some point in the future we should rectify this.
305        let credential = self.credential.read().clone();
306        let request = stream_completion(credential, self.executor.clone(), prompt);
307        async move {
308            let response = request.await?;
309            let stream = response
310                .filter_map(|response| async move {
311                    match response {
312                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
313                        Err(error) => Some(Err(error)),
314                    }
315                })
316                .boxed();
317            Ok(stream)
318        }
319        .boxed()
320    }
321    fn box_clone(&self) -> Box<dyn CompletionProvider> {
322        Box::new((*self).clone())
323    }
324}