completion.rs

  1use anyhow::{anyhow, Result};
  2use futures::{
  3    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  4    Stream, StreamExt,
  5};
  6use gpui::executor::Background;
  7use isahc::{http::StatusCode, Request, RequestExt};
  8use serde::{Deserialize, Serialize};
  9use std::{
 10    fmt::{self, Display},
 11    io,
 12    sync::Arc,
 13};
 14
 /// Base URL of the OpenAI REST API; endpoint paths such as `/chat/completions` are appended to it.
pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
 16
 17#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 18#[serde(rename_all = "lowercase")]
 19pub enum Role {
 20    User,
 21    Assistant,
 22    System,
 23}
 24
 25impl Role {
 26    pub fn cycle(&mut self) {
 27        *self = match self {
 28            Role::User => Role::Assistant,
 29            Role::Assistant => Role::System,
 30            Role::System => Role::User,
 31        }
 32    }
 33}
 34
 35impl Display for Role {
 36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 37        match self {
 38            Role::User => write!(f, "User"),
 39            Role::Assistant => write!(f, "Assistant"),
 40            Role::System => write!(f, "System"),
 41        }
 42    }
 43}
 44
 45#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 46pub struct RequestMessage {
 47    pub role: Role,
 48    pub content: String,
 49}
 50
 51#[derive(Debug, Default, Serialize)]
 52pub struct OpenAIRequest {
 53    pub model: String,
 54    pub messages: Vec<RequestMessage>,
 55    pub stream: bool,
 56}
 57
 58#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 59pub struct ResponseMessage {
 60    pub role: Option<Role>,
 61    pub content: Option<String>,
 62}
 63
 64#[derive(Deserialize, Debug)]
 65pub struct OpenAIUsage {
 66    pub prompt_tokens: u32,
 67    pub completion_tokens: u32,
 68    pub total_tokens: u32,
 69}
 70
 71#[derive(Deserialize, Debug)]
 72pub struct ChatChoiceDelta {
 73    pub index: u32,
 74    pub delta: ResponseMessage,
 75    pub finish_reason: Option<String>,
 76}
 77
 78#[derive(Deserialize, Debug)]
 79pub struct OpenAIResponseStreamEvent {
 80    pub id: Option<String>,
 81    pub object: String,
 82    pub created: u32,
 83    pub model: String,
 84    pub choices: Vec<ChatChoiceDelta>,
 85    pub usage: Option<OpenAIUsage>,
 86}
 87
 88pub async fn stream_completion(
 89    api_key: String,
 90    executor: Arc<Background>,
 91    mut request: OpenAIRequest,
 92) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
 93    request.stream = true;
 94
 95    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
 96
 97    let json_data = serde_json::to_string(&request)?;
 98    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
 99        .header("Content-Type", "application/json")
100        .header("Authorization", format!("Bearer {}", api_key))
101        .body(json_data)?
102        .send_async()
103        .await?;
104
105    let status = response.status();
106    if status == StatusCode::OK {
107        executor
108            .spawn(async move {
109                let mut lines = BufReader::new(response.body_mut()).lines();
110
111                fn parse_line(
112                    line: Result<String, io::Error>,
113                ) -> Result<Option<OpenAIResponseStreamEvent>> {
114                    if let Some(data) = line?.strip_prefix("data: ") {
115                        let event = serde_json::from_str(&data)?;
116                        Ok(Some(event))
117                    } else {
118                        Ok(None)
119                    }
120                }
121
122                while let Some(line) = lines.next().await {
123                    if let Some(event) = parse_line(line).transpose() {
124                        let done = event.as_ref().map_or(false, |event| {
125                            event
126                                .choices
127                                .last()
128                                .map_or(false, |choice| choice.finish_reason.is_some())
129                        });
130                        if tx.unbounded_send(event).is_err() {
131                            break;
132                        }
133
134                        if done {
135                            break;
136                        }
137                    }
138                }
139
140                anyhow::Ok(())
141            })
142            .detach();
143
144        Ok(rx)
145    } else {
146        let mut body = String::new();
147        response.body_mut().read_to_string(&mut body).await?;
148
149        #[derive(Deserialize)]
150        struct OpenAIResponse {
151            error: OpenAIError,
152        }
153
154        #[derive(Deserialize)]
155        struct OpenAIError {
156            message: String,
157        }
158
159        match serde_json::from_str::<OpenAIResponse>(&body) {
160            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
161                "Failed to connect to OpenAI API: {}",
162                response.error.message,
163            )),
164
165            _ => Err(anyhow!(
166                "Failed to connect to OpenAI API: {} {}",
167                response.status(),
168                body,
169            )),
170        }
171    }
172}
173
174pub trait CompletionProvider {
175    fn complete(
176        &self,
177        prompt: OpenAIRequest,
178    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
179}
180
181pub struct OpenAICompletionProvider {
182    api_key: String,
183    executor: Arc<Background>,
184}
185
186impl OpenAICompletionProvider {
187    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
188        Self { api_key, executor }
189    }
190}
191
192impl CompletionProvider for OpenAICompletionProvider {
193    fn complete(
194        &self,
195        prompt: OpenAIRequest,
196    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
197        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
198        async move {
199            let response = request.await?;
200            let stream = response
201                .filter_map(|response| async move {
202                    match response {
203                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
204                        Err(error) => Some(Err(error)),
205                    }
206                })
207                .boxed();
208            Ok(stream)
209        }
210        .boxed()
211    }
212}