// completion.rs

  1use anyhow::{anyhow, Result};
  2use futures::{
  3    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  4    Stream, StreamExt,
  5};
  6use gpui::{executor::Background, AppContext};
  7use isahc::{http::StatusCode, Request, RequestExt};
  8use parking_lot::RwLock;
  9use serde::{Deserialize, Serialize};
 10use std::{
 11    env,
 12    fmt::{self, Display},
 13    io,
 14    sync::Arc,
 15};
 16use util::ResultExt;
 17
 18use crate::{
 19    auth::{CredentialProvider, ProviderCredential},
 20    completion::{CompletionProvider, CompletionRequest},
 21    models::LanguageModel,
 22};
 23
 24use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
 25
/// The author of a chat message, mirroring the roles of the OpenAI chat API.
/// Serialized in lowercase ("user", "assistant", "system") to match the wire
/// format expected by the API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 33
 34impl Role {
 35    pub fn cycle(&mut self) {
 36        *self = match self {
 37            Role::User => Role::Assistant,
 38            Role::Assistant => Role::System,
 39            Role::System => Role::User,
 40        }
 41    }
 42}
 43
 44impl Display for Role {
 45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 46        match self {
 47            Role::User => write!(f, "User"),
 48            Role::Assistant => write!(f, "Assistant"),
 49            Role::System => write!(f, "System"),
 50        }
 51    }
 52}
 53
/// A single chat message as sent to the completions endpoint.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}
 59
/// Request body for the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true the API responds with server-sent events instead of a
    // single JSON body; `stream_completion` relies on this.
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}
 68
impl CompletionRequest for OpenAIRequest {
    /// Serialize the request to the JSON body sent over the wire.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
 74
/// A (possibly partial) message delta received from a streaming response;
/// both fields may be absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
 80
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
 87
/// One choice in a streaming chunk; `finish_reason` becomes `Some` on the
/// final chunk for that choice.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}
 94
/// A single server-sent event from the streaming completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAIUsage>,
}
104
105pub async fn stream_completion(
106    credential: ProviderCredential,
107    executor: Arc<Background>,
108    request: Box<dyn CompletionRequest>,
109) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
110    let api_key = match credential {
111        ProviderCredential::Credentials { api_key } => api_key,
112        _ => {
113            return Err(anyhow!("no credentials provider for completion"));
114        }
115    };
116
117    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
118
119    let json_data = request.data()?;
120    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
121        .header("Content-Type", "application/json")
122        .header("Authorization", format!("Bearer {}", api_key))
123        .body(json_data)?
124        .send_async()
125        .await?;
126
127    let status = response.status();
128    if status == StatusCode::OK {
129        executor
130            .spawn(async move {
131                let mut lines = BufReader::new(response.body_mut()).lines();
132
133                fn parse_line(
134                    line: Result<String, io::Error>,
135                ) -> Result<Option<OpenAIResponseStreamEvent>> {
136                    if let Some(data) = line?.strip_prefix("data: ") {
137                        let event = serde_json::from_str(&data)?;
138                        Ok(Some(event))
139                    } else {
140                        Ok(None)
141                    }
142                }
143
144                while let Some(line) = lines.next().await {
145                    if let Some(event) = parse_line(line).transpose() {
146                        let done = event.as_ref().map_or(false, |event| {
147                            event
148                                .choices
149                                .last()
150                                .map_or(false, |choice| choice.finish_reason.is_some())
151                        });
152                        if tx.unbounded_send(event).is_err() {
153                            break;
154                        }
155
156                        if done {
157                            break;
158                        }
159                    }
160                }
161
162                anyhow::Ok(())
163            })
164            .detach();
165
166        Ok(rx)
167    } else {
168        let mut body = String::new();
169        response.body_mut().read_to_string(&mut body).await?;
170
171        #[derive(Deserialize)]
172        struct OpenAIResponse {
173            error: OpenAIError,
174        }
175
176        #[derive(Deserialize)]
177        struct OpenAIError {
178            message: String,
179        }
180
181        match serde_json::from_str::<OpenAIResponse>(&body) {
182            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
183                "Failed to connect to OpenAI API: {}",
184                response.error.message,
185            )),
186
187            _ => Err(anyhow!(
188                "Failed to connect to OpenAI API: {} {}",
189                response.status(),
190                body,
191            )),
192        }
193    }
194}
195
/// Completion provider backed by the OpenAI chat API.
#[derive(Clone)]
pub struct OpenAICompletionProvider {
    model: OpenAILanguageModel,
    // Shared, lazily-populated credential; see `retrieve_credentials`.
    credential: Arc<RwLock<ProviderCredential>>,
    executor: Arc<Background>,
}
202
203impl OpenAICompletionProvider {
204    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
205        let model = OpenAILanguageModel::load(model_name);
206        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
207        Self {
208            model,
209            credential,
210            executor,
211        }
212    }
213}
214
215impl CredentialProvider for OpenAICompletionProvider {
216    fn has_credentials(&self) -> bool {
217        match *self.credential.read() {
218            ProviderCredential::Credentials { .. } => true,
219            _ => false,
220        }
221    }
222    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
223        let mut credential = self.credential.write();
224        match *credential {
225            ProviderCredential::Credentials { .. } => {
226                return credential.clone();
227            }
228            _ => {
229                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
230                    *credential = ProviderCredential::Credentials { api_key };
231                } else if let Some((_, api_key)) = cx
232                    .platform()
233                    .read_credentials(OPENAI_API_URL)
234                    .log_err()
235                    .flatten()
236                {
237                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
238                        *credential = ProviderCredential::Credentials { api_key };
239                    }
240                } else {
241                };
242            }
243        }
244
245        credential.clone()
246    }
247
248    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
249        match credential.clone() {
250            ProviderCredential::Credentials { api_key } => {
251                cx.platform()
252                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
253                    .log_err();
254            }
255            _ => {}
256        }
257
258        *self.credential.write() = credential;
259    }
260    fn delete_credentials(&self, cx: &AppContext) {
261        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
262        *self.credential.write() = ProviderCredential::NoCredentials;
263    }
264}
265
impl CompletionProvider for OpenAICompletionProvider {
    /// The language model this provider was constructed with, boxed for
    /// dynamic dispatch.
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }
    /// Start a streaming completion and map each event to the text content
    /// of its last choice's delta, dropping events that carry no content.
    /// Errors from the underlying stream are passed through.
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
        // which is currently model based, due to the language model.
        // At some point in the future we should rectify this.
        // Snapshot the credential now so the returned future is 'static and
        // does not borrow `self`.
        let credential = self.credential.read().clone();
        let request = stream_completion(credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}
298}