completion.rs

  1use anyhow::{anyhow, Result};
  2use futures::{
  3    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  4    Stream, StreamExt,
  5};
  6use gpui::executor::Background;
  7use isahc::{http::StatusCode, Request, RequestExt};
  8use serde::{Deserialize, Serialize};
  9use std::{
 10    fmt::{self, Display},
 11    io,
 12    sync::Arc,
 13};
 14
 15use crate::completion::{CompletionProvider, CompletionRequest};
 16
 17pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 18
 19#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 20#[serde(rename_all = "lowercase")]
 21pub enum Role {
 22    User,
 23    Assistant,
 24    System,
 25}
 26
 27impl Role {
 28    pub fn cycle(&mut self) {
 29        *self = match self {
 30            Role::User => Role::Assistant,
 31            Role::Assistant => Role::System,
 32            Role::System => Role::User,
 33        }
 34    }
 35}
 36
 37impl Display for Role {
 38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 39        match self {
 40            Role::User => write!(f, "User"),
 41            Role::Assistant => write!(f, "Assistant"),
 42            Role::System => write!(f, "System"),
 43        }
 44    }
 45}
 46
 47#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 48pub struct RequestMessage {
 49    pub role: Role,
 50    pub content: String,
 51}
 52
 53#[derive(Debug, Default, Serialize)]
 54pub struct OpenAIRequest {
 55    pub model: String,
 56    pub messages: Vec<RequestMessage>,
 57    pub stream: bool,
 58    pub stop: Vec<String>,
 59    pub temperature: f32,
 60}
 61
 62impl CompletionRequest for OpenAIRequest {
 63    fn data(&self) -> serde_json::Result<String> {
 64        serde_json::to_string(self)
 65    }
 66}
 67
 68#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 69pub struct ResponseMessage {
 70    pub role: Option<Role>,
 71    pub content: Option<String>,
 72}
 73
 74#[derive(Deserialize, Debug)]
 75pub struct OpenAIUsage {
 76    pub prompt_tokens: u32,
 77    pub completion_tokens: u32,
 78    pub total_tokens: u32,
 79}
 80
 81#[derive(Deserialize, Debug)]
 82pub struct ChatChoiceDelta {
 83    pub index: u32,
 84    pub delta: ResponseMessage,
 85    pub finish_reason: Option<String>,
 86}
 87
 88#[derive(Deserialize, Debug)]
 89pub struct OpenAIResponseStreamEvent {
 90    pub id: Option<String>,
 91    pub object: String,
 92    pub created: u32,
 93    pub model: String,
 94    pub choices: Vec<ChatChoiceDelta>,
 95    pub usage: Option<OpenAIUsage>,
 96}
 97
 98pub async fn stream_completion(
 99    api_key: String,
100    executor: Arc<Background>,
101    request: Box<dyn CompletionRequest>,
102) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
103    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
104
105    let json_data = request.data()?;
106    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
107        .header("Content-Type", "application/json")
108        .header("Authorization", format!("Bearer {}", api_key))
109        .body(json_data)?
110        .send_async()
111        .await?;
112
113    let status = response.status();
114    if status == StatusCode::OK {
115        executor
116            .spawn(async move {
117                let mut lines = BufReader::new(response.body_mut()).lines();
118
119                fn parse_line(
120                    line: Result<String, io::Error>,
121                ) -> Result<Option<OpenAIResponseStreamEvent>> {
122                    if let Some(data) = line?.strip_prefix("data: ") {
123                        let event = serde_json::from_str(&data)?;
124                        Ok(Some(event))
125                    } else {
126                        Ok(None)
127                    }
128                }
129
130                while let Some(line) = lines.next().await {
131                    if let Some(event) = parse_line(line).transpose() {
132                        let done = event.as_ref().map_or(false, |event| {
133                            event
134                                .choices
135                                .last()
136                                .map_or(false, |choice| choice.finish_reason.is_some())
137                        });
138                        if tx.unbounded_send(event).is_err() {
139                            break;
140                        }
141
142                        if done {
143                            break;
144                        }
145                    }
146                }
147
148                anyhow::Ok(())
149            })
150            .detach();
151
152        Ok(rx)
153    } else {
154        let mut body = String::new();
155        response.body_mut().read_to_string(&mut body).await?;
156
157        #[derive(Deserialize)]
158        struct OpenAIResponse {
159            error: OpenAIError,
160        }
161
162        #[derive(Deserialize)]
163        struct OpenAIError {
164            message: String,
165        }
166
167        match serde_json::from_str::<OpenAIResponse>(&body) {
168            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
169                "Failed to connect to OpenAI API: {}",
170                response.error.message,
171            )),
172
173            _ => Err(anyhow!(
174                "Failed to connect to OpenAI API: {} {}",
175                response.status(),
176                body,
177            )),
178        }
179    }
180}
181
182pub struct OpenAICompletionProvider {
183    api_key: String,
184    executor: Arc<Background>,
185}
186
187impl OpenAICompletionProvider {
188    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
189        Self { api_key, executor }
190    }
191}
192
193impl CompletionProvider for OpenAICompletionProvider {
194    fn complete(
195        &self,
196        prompt: Box<dyn CompletionRequest>,
197    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
198        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
199        async move {
200            let response = request.await?;
201            let stream = response
202                .filter_map(|response| async move {
203                    match response {
204                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
205                        Err(error) => Some(Err(error)),
206                    }
207                })
208                .boxed();
209            Ok(stream)
210        }
211        .boxed()
212    }
213}