completion.rs

  1use std::{
  2    env,
  3    fmt::{self, Display},
  4    io,
  5    sync::Arc,
  6};
  7
  8use anyhow::{anyhow, Result};
  9use futures::{
 10    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
 11    Stream, StreamExt,
 12};
 13use gpui::{AppContext, BackgroundExecutor};
 14use isahc::{http::StatusCode, Request, RequestExt};
 15use parking_lot::RwLock;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use util::ResultExt;
 19
 20use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
 21use crate::{
 22    auth::{CredentialProvider, ProviderCredential},
 23    completion::{CompletionProvider, CompletionRequest},
 24    models::LanguageModel,
 25};
 26
/// The author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`) because of
/// `#[serde(rename_all = "lowercase")]`; the `Display` impl instead renders a
/// capitalized label.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 34
 35impl Role {
 36    pub fn cycle(&mut self) {
 37        *self = match self {
 38            Role::User => Role::Assistant,
 39            Role::Assistant => Role::System,
 40            Role::System => Role::User,
 41        }
 42    }
 43}
 44
 45impl Display for Role {
 46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 47        match self {
 48            Role::User => write!(f, "User"),
 49            Role::Assistant => write!(f, "Assistant"),
 50            Role::System => write!(f, "System"),
 51        }
 52    }
 53}
 54
/// A single message in the conversation sent to the chat-completions endpoint.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}

/// Request body for the chat-completions endpoint.
///
/// Serialized with `serde_json` by [`CompletionRequest::data`]; field names are
/// emitted unchanged (no serde renames), so they double as the JSON parameter names.
#[derive(Debug, Default, Serialize)]
pub struct OpenAiRequest {
    /// Model identifier sent in the request body.
    pub model: String,
    /// Conversation history, in order.
    pub messages: Vec<RequestMessage>,
    /// NOTE(review): `stream_completion` only decodes `data: ` event lines, so this
    /// presumably needs to be `true` for responses to be parsed — confirm with callers.
    pub stream: bool,
    /// Stop sequences; always serialized, even when empty.
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl CompletionRequest for OpenAiRequest {
    /// Serializes this request to the JSON string that `stream_completion` sends
    /// as the HTTP request body.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
 75
/// The `delta` payload of one streamed choice; either field may be absent in a
/// given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

/// Token counts attached to a streamed event, when the server includes them.
#[derive(Deserialize, Debug)]
pub struct OpenAiUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice inside a streamed event.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    /// Set once the choice is complete; `stream_completion` stops reading the
    /// body when the last choice of an event has this set.
    pub finish_reason: Option<String>,
}

/// One event decoded from a `data: ` line of the streaming chat-completions response.
#[derive(Deserialize, Debug)]
pub struct OpenAiResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAiUsage>,
}
105
106async fn stream_completion(
107    api_url: String,
108    kind: OpenAiCompletionProviderKind,
109    credential: ProviderCredential,
110    executor: BackgroundExecutor,
111    request: Box<dyn CompletionRequest>,
112) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
113    let api_key = match credential {
114        ProviderCredential::Credentials { api_key } => api_key,
115        _ => {
116            return Err(anyhow!("no credentials provider for completion"));
117        }
118    };
119
120    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
121
122    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
123    let json_data = request.data()?;
124    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
125        .header("Content-Type", "application/json")
126        .header(auth_header_name, auth_header_value)
127        .body(json_data)?
128        .send_async()
129        .await?;
130
131    let status = response.status();
132    if status == StatusCode::OK {
133        executor
134            .spawn(async move {
135                let mut lines = BufReader::new(response.body_mut()).lines();
136
137                fn parse_line(
138                    line: Result<String, io::Error>,
139                ) -> Result<Option<OpenAiResponseStreamEvent>> {
140                    if let Some(data) = line?.strip_prefix("data: ") {
141                        let event = serde_json::from_str(data)?;
142                        Ok(Some(event))
143                    } else {
144                        Ok(None)
145                    }
146                }
147
148                while let Some(line) = lines.next().await {
149                    if let Some(event) = parse_line(line).transpose() {
150                        let done = event.as_ref().map_or(false, |event| {
151                            event
152                                .choices
153                                .last()
154                                .map_or(false, |choice| choice.finish_reason.is_some())
155                        });
156                        if tx.unbounded_send(event).is_err() {
157                            break;
158                        }
159
160                        if done {
161                            break;
162                        }
163                    }
164                }
165
166                anyhow::Ok(())
167            })
168            .detach();
169
170        Ok(rx)
171    } else {
172        let mut body = String::new();
173        response.body_mut().read_to_string(&mut body).await?;
174
175        #[derive(Deserialize)]
176        struct OpenAiResponse {
177            error: OpenAiError,
178        }
179
180        #[derive(Deserialize)]
181        struct OpenAiError {
182            message: String,
183        }
184
185        match serde_json::from_str::<OpenAiResponse>(&body) {
186            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
187                "Failed to connect to OpenAI API: {}",
188                response.error.message,
189            )),
190
191            _ => Err(anyhow!(
192                "Failed to connect to OpenAI API: {} {}",
193                response.status(),
194                body,
195            )),
196        }
197    }
198}
199
/// Azure OpenAI REST API versions.
///
/// Each variant (de)serializes as its version string (see the `serde(rename)`
/// attributes). The `Display` impl renders the same string, which
/// `OpenAiCompletionProviderKind::completions_endpoint_url` places in the
/// `api-version` query parameter.
///
/// NOTE(review): several variants are documented as retiring on April 2, 2024;
/// confirm which versions the service still accepts.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
pub enum AzureOpenAiApiVersion {
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-03-15-preview")]
    V2023_03_15Preview,
    #[serde(rename = "2023-05-15")]
    V2023_05_15,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-06-01-preview")]
    V2023_06_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-07-01-preview")]
    V2023_07_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-08-01-preview")]
    V2023_08_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-09-01-preview")]
    V2023_09_01Preview,
    #[serde(rename = "2023-12-01-preview")]
    V2023_12_01Preview,
    #[serde(rename = "2024-02-15-preview")]
    V2024_02_15Preview,
}
224
225impl fmt::Display for AzureOpenAiApiVersion {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        write!(
228            f,
229            "{}",
230            match self {
231                Self::V2023_03_15Preview => "2023-03-15-preview",
232                Self::V2023_05_15 => "2023-05-15",
233                Self::V2023_06_01Preview => "2023-06-01-preview",
234                Self::V2023_07_01Preview => "2023-07-01-preview",
235                Self::V2023_08_01Preview => "2023-08-01-preview",
236                Self::V2023_09_01Preview => "2023-09-01-preview",
237                Self::V2023_12_01Preview => "2023-12-01-preview",
238                Self::V2024_02_15Preview => "2024-02-15-preview",
239            }
240        )
241    }
242}
243
/// Which flavor of the chat-completions API a provider talks to; this selects the
/// endpoint URL layout and the authentication header.
#[derive(Clone)]
pub enum OpenAiCompletionProviderKind {
    /// The OpenAI API (bearer-token `Authorization` header).
    OpenAi,
    /// An Azure OpenAI deployment (`Api-Key` header).
    AzureOpenAi {
        /// Deployment name inserted into the request path.
        deployment_id: String,
        /// Value of the `api-version` query parameter.
        api_version: AzureOpenAiApiVersion,
    },
}
252
253impl OpenAiCompletionProviderKind {
254    /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
255    fn completions_endpoint_url(&self, api_url: &str) -> String {
256        match self {
257            Self::OpenAi => {
258                // https://platform.openai.com/docs/api-reference/chat/create
259                format!("{api_url}/chat/completions")
260            }
261            Self::AzureOpenAi {
262                deployment_id,
263                api_version,
264            } => {
265                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
266                format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
267            }
268        }
269    }
270
271    /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
272    fn auth_header(&self, api_key: String) -> (&'static str, String) {
273        match self {
274            Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
275            Self::AzureOpenAi { .. } => ("Api-Key", api_key),
276        }
277    }
278}
279
/// A [`CompletionProvider`] backed by the OpenAI or Azure OpenAI chat-completions API.
#[derive(Clone)]
pub struct OpenAiCompletionProvider {
    /// Base URL that the endpoint path is appended to.
    api_url: String,
    kind: OpenAiCompletionProviderKind,
    /// Model reported by `base_model`. The model actually used for a request is taken
    /// from the request body (see the note in `complete`).
    model: OpenAiLanguageModel,
    /// Cached credential; the `Arc` is shared by every clone of this provider.
    credential: Arc<RwLock<ProviderCredential>>,
    /// Executor used to load the model and to read streamed responses.
    executor: BackgroundExecutor,
}
288
289impl OpenAiCompletionProvider {
290    pub async fn new(
291        api_url: String,
292        kind: OpenAiCompletionProviderKind,
293        model_name: String,
294        executor: BackgroundExecutor,
295    ) -> Self {
296        let model = executor
297            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
298            .await;
299        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
300        Self {
301            api_url,
302            kind,
303            model,
304            credential,
305            executor,
306        }
307    }
308}
309
310impl CredentialProvider for OpenAiCompletionProvider {
311    fn has_credentials(&self) -> bool {
312        match *self.credential.read() {
313            ProviderCredential::Credentials { .. } => true,
314            _ => false,
315        }
316    }
317
318    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
319        let existing_credential = self.credential.read().clone();
320        let retrieved_credential = match existing_credential {
321            ProviderCredential::Credentials { .. } => {
322                return async move { existing_credential }.boxed()
323            }
324            _ => {
325                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
326                    async move { ProviderCredential::Credentials { api_key } }.boxed()
327                } else {
328                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
329                    async move {
330                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
331                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
332                                ProviderCredential::Credentials { api_key }
333                            } else {
334                                ProviderCredential::NoCredentials
335                            }
336                        } else {
337                            ProviderCredential::NoCredentials
338                        }
339                    }
340                    .boxed()
341                }
342            }
343        };
344
345        async move {
346            let retrieved_credential = retrieved_credential.await;
347            *self.credential.write() = retrieved_credential.clone();
348            retrieved_credential
349        }
350        .boxed()
351    }
352
353    fn save_credentials(
354        &self,
355        cx: &mut AppContext,
356        credential: ProviderCredential,
357    ) -> BoxFuture<()> {
358        *self.credential.write() = credential.clone();
359        let credential = credential.clone();
360        let write_credentials = match credential {
361            ProviderCredential::Credentials { api_key } => {
362                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
363            }
364            _ => None,
365        };
366
367        async move {
368            if let Some(write_credentials) = write_credentials {
369                write_credentials.await.log_err();
370            }
371        }
372        .boxed()
373    }
374
375    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
376        *self.credential.write() = ProviderCredential::NoCredentials;
377        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
378        async move {
379            delete_credentials.await.log_err();
380        }
381        .boxed()
382    }
383}
384
385impl CompletionProvider for OpenAiCompletionProvider {
386    fn base_model(&self) -> Box<dyn LanguageModel> {
387        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
388        model
389    }
390
391    fn complete(
392        &self,
393        prompt: Box<dyn CompletionRequest>,
394    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
395        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
396        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
397        // which is currently model based, due to the language model.
398        // At some point in the future we should rectify this.
399        let credential = self.credential.read().clone();
400        let api_url = self.api_url.clone();
401        let kind = self.kind.clone();
402        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
403        async move {
404            let response = request.await?;
405            let stream = response
406                .filter_map(|response| async move {
407                    match response {
408                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
409                        Err(error) => Some(Err(error)),
410                    }
411                })
412                .boxed();
413            Ok(stream)
414        }
415        .boxed()
416    }
417
418    fn box_clone(&self) -> Box<dyn CompletionProvider> {
419        Box::new((*self).clone())
420    }
421}