completion.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::{
  4    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  5    Stream, StreamExt,
  6};
  7use gpui2::{AppContext, Executor};
  8use isahc::{http::StatusCode, Request, RequestExt};
  9use parking_lot::RwLock;
 10use serde::{Deserialize, Serialize};
 11use std::{
 12    env,
 13    fmt::{self, Display},
 14    io,
 15    sync::Arc,
 16};
 17use util::ResultExt;
 18
 19use crate::{
 20    auth::{CredentialProvider, ProviderCredential},
 21    completion::{CompletionProvider, CompletionRequest},
 22    models::LanguageModel,
 23};
 24
 25use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
 26
 27#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 28#[serde(rename_all = "lowercase")]
 29pub enum Role {
 30    User,
 31    Assistant,
 32    System,
 33}
 34
 35impl Role {
 36    pub fn cycle(&mut self) {
 37        *self = match self {
 38            Role::User => Role::Assistant,
 39            Role::Assistant => Role::System,
 40            Role::System => Role::User,
 41        }
 42    }
 43}
 44
 45impl Display for Role {
 46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 47        match self {
 48            Role::User => write!(f, "User"),
 49            Role::Assistant => write!(f, "Assistant"),
 50            Role::System => write!(f, "System"),
 51        }
 52    }
 53}
 54
 55#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 56pub struct RequestMessage {
 57    pub role: Role,
 58    pub content: String,
 59}
 60
 61#[derive(Debug, Default, Serialize)]
 62pub struct OpenAIRequest {
 63    pub model: String,
 64    pub messages: Vec<RequestMessage>,
 65    pub stream: bool,
 66    pub stop: Vec<String>,
 67    pub temperature: f32,
 68}
 69
 70impl CompletionRequest for OpenAIRequest {
 71    fn data(&self) -> serde_json::Result<String> {
 72        serde_json::to_string(self)
 73    }
 74}
 75
 76#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 77pub struct ResponseMessage {
 78    pub role: Option<Role>,
 79    pub content: Option<String>,
 80}
 81
 82#[derive(Deserialize, Debug)]
 83pub struct OpenAIUsage {
 84    pub prompt_tokens: u32,
 85    pub completion_tokens: u32,
 86    pub total_tokens: u32,
 87}
 88
 89#[derive(Deserialize, Debug)]
 90pub struct ChatChoiceDelta {
 91    pub index: u32,
 92    pub delta: ResponseMessage,
 93    pub finish_reason: Option<String>,
 94}
 95
 96#[derive(Deserialize, Debug)]
 97pub struct OpenAIResponseStreamEvent {
 98    pub id: Option<String>,
 99    pub object: String,
100    pub created: u32,
101    pub model: String,
102    pub choices: Vec<ChatChoiceDelta>,
103    pub usage: Option<OpenAIUsage>,
104}
105
106pub async fn stream_completion(
107    credential: ProviderCredential,
108    executor: Arc<Executor>,
109    request: Box<dyn CompletionRequest>,
110) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
111    let api_key = match credential {
112        ProviderCredential::Credentials { api_key } => api_key,
113        _ => {
114            return Err(anyhow!("no credentials provider for completion"));
115        }
116    };
117
118    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
119
120    let json_data = request.data()?;
121    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
122        .header("Content-Type", "application/json")
123        .header("Authorization", format!("Bearer {}", api_key))
124        .body(json_data)?
125        .send_async()
126        .await?;
127
128    let status = response.status();
129    if status == StatusCode::OK {
130        executor
131            .spawn(async move {
132                let mut lines = BufReader::new(response.body_mut()).lines();
133
134                fn parse_line(
135                    line: Result<String, io::Error>,
136                ) -> Result<Option<OpenAIResponseStreamEvent>> {
137                    if let Some(data) = line?.strip_prefix("data: ") {
138                        let event = serde_json::from_str(&data)?;
139                        Ok(Some(event))
140                    } else {
141                        Ok(None)
142                    }
143                }
144
145                while let Some(line) = lines.next().await {
146                    if let Some(event) = parse_line(line).transpose() {
147                        let done = event.as_ref().map_or(false, |event| {
148                            event
149                                .choices
150                                .last()
151                                .map_or(false, |choice| choice.finish_reason.is_some())
152                        });
153                        if tx.unbounded_send(event).is_err() {
154                            break;
155                        }
156
157                        if done {
158                            break;
159                        }
160                    }
161                }
162
163                anyhow::Ok(())
164            })
165            .detach();
166
167        Ok(rx)
168    } else {
169        let mut body = String::new();
170        response.body_mut().read_to_string(&mut body).await?;
171
172        #[derive(Deserialize)]
173        struct OpenAIResponse {
174            error: OpenAIError,
175        }
176
177        #[derive(Deserialize)]
178        struct OpenAIError {
179            message: String,
180        }
181
182        match serde_json::from_str::<OpenAIResponse>(&body) {
183            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
184                "Failed to connect to OpenAI API: {}",
185                response.error.message,
186            )),
187
188            _ => Err(anyhow!(
189                "Failed to connect to OpenAI API: {} {}",
190                response.status(),
191                body,
192            )),
193        }
194    }
195}
196
197#[derive(Clone)]
198pub struct OpenAICompletionProvider {
199    model: OpenAILanguageModel,
200    credential: Arc<RwLock<ProviderCredential>>,
201    executor: Arc<Executor>,
202}
203
204impl OpenAICompletionProvider {
205    pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
206        let model = OpenAILanguageModel::load(model_name);
207        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
208        Self {
209            model,
210            credential,
211            executor,
212        }
213    }
214}
215
216#[async_trait]
217impl CredentialProvider for OpenAICompletionProvider {
218    fn has_credentials(&self) -> bool {
219        match *self.credential.read() {
220            ProviderCredential::Credentials { .. } => true,
221            _ => false,
222        }
223    }
224    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
225        let existing_credential = self.credential.read().clone();
226
227        let retrieved_credential = cx
228            .run_on_main(move |cx| match existing_credential {
229                ProviderCredential::Credentials { .. } => {
230                    return existing_credential.clone();
231                }
232                _ => {
233                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
234                        return ProviderCredential::Credentials { api_key };
235                    }
236
237                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
238                    {
239                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
240                            return ProviderCredential::Credentials { api_key };
241                        } else {
242                            return ProviderCredential::NoCredentials;
243                        }
244                    } else {
245                        return ProviderCredential::NoCredentials;
246                    }
247                }
248            })
249            .await;
250
251        *self.credential.write() = retrieved_credential.clone();
252        retrieved_credential
253    }
254
255    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
256        *self.credential.write() = credential.clone();
257        let credential = credential.clone();
258        cx.run_on_main(move |cx| match credential {
259            ProviderCredential::Credentials { api_key } => {
260                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
261                    .log_err();
262            }
263            _ => {}
264        })
265        .await;
266    }
267    async fn delete_credentials(&self, cx: &mut AppContext) {
268        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
269            .await;
270        *self.credential.write() = ProviderCredential::NoCredentials;
271    }
272}
273
274impl CompletionProvider for OpenAICompletionProvider {
275    fn base_model(&self) -> Box<dyn LanguageModel> {
276        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
277        model
278    }
279    fn complete(
280        &self,
281        prompt: Box<dyn CompletionRequest>,
282    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
283        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
284        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
285        // which is currently model based, due to the langauge model.
286        // At some point in the future we should rectify this.
287        let credential = self.credential.read().clone();
288        let request = stream_completion(credential, self.executor.clone(), prompt);
289        async move {
290            let response = request.await?;
291            let stream = response
292                .filter_map(|response| async move {
293                    match response {
294                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
295                        Err(error) => Some(Err(error)),
296                    }
297                })
298                .boxed();
299            Ok(stream)
300        }
301        .boxed()
302    }
303    fn box_clone(&self) -> Box<dyn CompletionProvider> {
304        Box::new((*self).clone())
305    }
306}