completion.rs

  1use anyhow::{anyhow, Result};
  2use futures::{
  3    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  4    Stream, StreamExt,
  5};
  6use gpui::executor::Background;
  7use isahc::{http::StatusCode, Request, RequestExt};
  8use serde::{Deserialize, Serialize};
  9use std::{
 10    fmt::{self, Display},
 11    io,
 12    sync::Arc,
 13};
 14
 15pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 16
/// The author of a chat message.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes the variants
/// as `"user"`, `"assistant"` and `"system"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
}
 24
 25impl Role {
 26    pub fn cycle(&mut self) {
 27        *self = match self {
 28            Role::User => Role::Assistant,
 29            Role::Assistant => Role::System,
 30            Role::System => Role::User,
 31        }
 32    }
 33}
 34
 35impl Display for Role {
 36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 37        match self {
 38            Role::User => write!(f, "User"),
 39            Role::Assistant => write!(f, "Assistant"),
 40            Role::System => write!(f, "System"),
 41        }
 42    }
 43}
 44
/// A single chat message carried in [`OpenAIRequest::messages`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The text of the message.
    pub content: String,
}
 50
/// The request body that [`stream_completion`] serializes to JSON and posts
/// to the `/chat/completions` endpoint.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    /// Name of the model to query.
    pub model: String,
    /// The chat messages sent with the request.
    pub messages: Vec<RequestMessage>,
    /// Streaming flag; [`stream_completion`] overwrites this with `true`
    /// before sending.
    pub stream: bool,
    /// Presumably the sequences at which the API stops generating — confirm
    /// against the API documentation.
    pub stop: Vec<String>,
    /// Presumably the sampling temperature — confirm against the API
    /// documentation.
    pub temperature: f32,
}
 59
/// The message fragment carried in [`ChatChoiceDelta::delta`].
///
/// Both fields are optional, so a fragment may carry either, both or neither.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    /// Author of the message, when present in the fragment.
    pub role: Option<Role>,
    /// Text carried by the fragment, when present.
    pub content: Option<String>,
}
 65
/// Token counts optionally attached to an [`OpenAIResponseStreamEvent`].
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Presumably `prompt_tokens + completion_tokens` — the value is taken
    /// from the response as-is and is not checked here.
    pub total_tokens: u32,
}
 72
/// One entry of [`OpenAIResponseStreamEvent::choices`].
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    /// Index of this choice as reported by the API.
    pub index: u32,
    /// The message fragment for this choice.
    pub delta: ResponseMessage,
    /// [`stream_completion`] stops reading the response once the last choice
    /// of an event has a value here.
    pub finish_reason: Option<String>,
}
 79
/// One event of a streamed chat completion.
///
/// [`stream_completion`] deserializes these from the JSON that follows the
/// `data: ` prefix on a line of the response body.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    /// Identifier of the completion, when the API supplies one.
    pub id: Option<String>,
    /// Object-type tag reported by the API.
    pub object: String,
    /// Presumably the creation time as a Unix timestamp in seconds — confirm
    /// against the API documentation.
    pub created: u32,
    /// Name of the model that produced the event.
    pub model: String,
    /// The choices carried by this event.
    pub choices: Vec<ChatChoiceDelta>,
    /// Token usage, when the API includes it.
    pub usage: Option<OpenAIUsage>,
}
 89
 90pub async fn stream_completion(
 91    api_key: String,
 92    executor: Arc<Background>,
 93    mut request: OpenAIRequest,
 94) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
 95    request.stream = true;
 96
 97    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
 98
 99    let json_data = serde_json::to_string(&request)?;
100    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
101        .header("Content-Type", "application/json")
102        .header("Authorization", format!("Bearer {}", api_key))
103        .body(json_data)?
104        .send_async()
105        .await?;
106
107    let status = response.status();
108    if status == StatusCode::OK {
109        executor
110            .spawn(async move {
111                let mut lines = BufReader::new(response.body_mut()).lines();
112
113                fn parse_line(
114                    line: Result<String, io::Error>,
115                ) -> Result<Option<OpenAIResponseStreamEvent>> {
116                    if let Some(data) = line?.strip_prefix("data: ") {
117                        let event = serde_json::from_str(&data)?;
118                        Ok(Some(event))
119                    } else {
120                        Ok(None)
121                    }
122                }
123
124                while let Some(line) = lines.next().await {
125                    if let Some(event) = parse_line(line).transpose() {
126                        let done = event.as_ref().map_or(false, |event| {
127                            event
128                                .choices
129                                .last()
130                                .map_or(false, |choice| choice.finish_reason.is_some())
131                        });
132                        if tx.unbounded_send(event).is_err() {
133                            break;
134                        }
135
136                        if done {
137                            break;
138                        }
139                    }
140                }
141
142                anyhow::Ok(())
143            })
144            .detach();
145
146        Ok(rx)
147    } else {
148        let mut body = String::new();
149        response.body_mut().read_to_string(&mut body).await?;
150
151        #[derive(Deserialize)]
152        struct OpenAIResponse {
153            error: OpenAIError,
154        }
155
156        #[derive(Deserialize)]
157        struct OpenAIError {
158            message: String,
159        }
160
161        match serde_json::from_str::<OpenAIResponse>(&body) {
162            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
163                "Failed to connect to OpenAI API: {}",
164                response.error.message,
165            )),
166
167            _ => Err(anyhow!(
168                "Failed to connect to OpenAI API: {} {}",
169                response.status(),
170                body,
171            )),
172        }
173    }
174}
175
/// A source of streamed text completions.
pub trait CompletionProvider {
    /// Starts a completion for `prompt`.
    ///
    /// The returned future resolves to a stream of text chunks; both the
    /// future and each stream item can fail independently.
    fn complete(
        &self,
        prompt: OpenAIRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
182
/// A [`CompletionProvider`] backed by [`stream_completion`].
pub struct OpenAICompletionProvider {
    /// API key passed to [`stream_completion`] on every request.
    api_key: String,
    /// Executor passed to [`stream_completion`] on every request.
    executor: Arc<Background>,
}
187
188impl OpenAICompletionProvider {
189    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
190        Self { api_key, executor }
191    }
192}
193
194impl CompletionProvider for OpenAICompletionProvider {
195    fn complete(
196        &self,
197        prompt: OpenAIRequest,
198    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
199        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
200        async move {
201            let response = request.await?;
202            let stream = response
203                .filter_map(|response| async move {
204                    match response {
205                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
206                        Err(error) => Some(Err(error)),
207                    }
208                })
209                .boxed();
210            Ok(stream)
211        }
212        .boxed()
213    }
214}