completion.rs

  1use anyhow::{anyhow, Result};
  2use futures::{
  3    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
  4    Stream, StreamExt,
  5};
  6use gpui::executor::Background;
  7use isahc::{http::StatusCode, Request, RequestExt};
  8use serde::{Deserialize, Serialize};
  9use std::{
 10    fmt::{self, Display},
 11    io,
 12    sync::Arc,
 13};
 14
 15use crate::{
 16    completion::{CompletionProvider, CompletionRequest},
 17    models::LanguageModel,
 18};
 19
 20use super::OpenAILanguageModel;
 21
 22pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 23
/// Speaker of a chat message, serialized in lowercase ("user", "assistant",
/// "system") to match the OpenAI chat API's wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 31
 32impl Role {
 33    pub fn cycle(&mut self) {
 34        *self = match self {
 35            Role::User => Role::Assistant,
 36            Role::Assistant => Role::System,
 37            Role::System => Role::User,
 38        }
 39    }
 40}
 41
 42impl Display for Role {
 43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 44        match self {
 45            Role::User => write!(f, "User"),
 46            Role::Assistant => write!(f, "Assistant"),
 47            Role::System => write!(f, "System"),
 48        }
 49    }
 50}
 51
/// One chat message in the request payload sent to the completions endpoint.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    /// Who is speaking this message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
 57
/// Request body for OpenAI's `/chat/completions` endpoint.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    /// Name of the model to query.
    pub model: String,
    /// Conversation history forming the prompt.
    pub messages: Vec<RequestMessage>,
    /// When true, the server streams the response as server-sent events
    /// (this is what `stream_completion` expects).
    pub stream: bool,
    /// Sequences at which the model should stop generating.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
}
 66
impl CompletionRequest for OpenAIRequest {
    /// Serializes this request into the JSON string used as the HTTP body.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
 72
/// Message fragment inside a streamed response delta. Both fields are
/// optional because each streamed chunk may carry only the role, only new
/// content, or neither.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
 78
/// Token-usage accounting reported by the API for a completion.
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
 85
/// Incremental update for a single completion choice in a streamed response.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    /// Newly generated content/role since the previous event.
    pub delta: ResponseMessage,
    /// Set (e.g. "stop") once this choice has finished generating.
    pub finish_reason: Option<String>,
}
 92
/// One decoded server-sent event from the streaming completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAIUsage>,
}
102
103pub async fn stream_completion(
104    api_key: String,
105    executor: Arc<Background>,
106    request: Box<dyn CompletionRequest>,
107) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
108    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
109
110    let json_data = request.data()?;
111    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
112        .header("Content-Type", "application/json")
113        .header("Authorization", format!("Bearer {}", api_key))
114        .body(json_data)?
115        .send_async()
116        .await?;
117
118    let status = response.status();
119    if status == StatusCode::OK {
120        executor
121            .spawn(async move {
122                let mut lines = BufReader::new(response.body_mut()).lines();
123
124                fn parse_line(
125                    line: Result<String, io::Error>,
126                ) -> Result<Option<OpenAIResponseStreamEvent>> {
127                    if let Some(data) = line?.strip_prefix("data: ") {
128                        let event = serde_json::from_str(&data)?;
129                        Ok(Some(event))
130                    } else {
131                        Ok(None)
132                    }
133                }
134
135                while let Some(line) = lines.next().await {
136                    if let Some(event) = parse_line(line).transpose() {
137                        let done = event.as_ref().map_or(false, |event| {
138                            event
139                                .choices
140                                .last()
141                                .map_or(false, |choice| choice.finish_reason.is_some())
142                        });
143                        if tx.unbounded_send(event).is_err() {
144                            break;
145                        }
146
147                        if done {
148                            break;
149                        }
150                    }
151                }
152
153                anyhow::Ok(())
154            })
155            .detach();
156
157        Ok(rx)
158    } else {
159        let mut body = String::new();
160        response.body_mut().read_to_string(&mut body).await?;
161
162        #[derive(Deserialize)]
163        struct OpenAIResponse {
164            error: OpenAIError,
165        }
166
167        #[derive(Deserialize)]
168        struct OpenAIError {
169            message: String,
170        }
171
172        match serde_json::from_str::<OpenAIResponse>(&body) {
173            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
174                "Failed to connect to OpenAI API: {}",
175                response.error.message,
176            )),
177
178            _ => Err(anyhow!(
179                "Failed to connect to OpenAI API: {} {}",
180                response.status(),
181                body,
182            )),
183        }
184    }
185}
186
/// [`CompletionProvider`] backed by the OpenAI chat completions API.
pub struct OpenAICompletionProvider {
    /// The selected language model (loaded by name in `new`).
    model: OpenAILanguageModel,
    /// API key sent as a bearer token with each request.
    api_key: String,
    /// Executor used by `stream_completion` to spawn the reader task.
    executor: Arc<Background>,
}
192
193impl OpenAICompletionProvider {
194    pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
195        let model = OpenAILanguageModel::load(model_name);
196        Self {
197            model,
198            api_key,
199            executor,
200        }
201    }
202}
203
204impl CompletionProvider for OpenAICompletionProvider {
205    fn base_model(&self) -> Box<dyn LanguageModel> {
206        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
207        model
208    }
209    fn complete(
210        &self,
211        prompt: Box<dyn CompletionRequest>,
212    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
213        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
214        async move {
215            let response = request.await?;
216            let stream = response
217                .filter_map(|response| async move {
218                    match response {
219                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
220                        Err(error) => Some(Err(error)),
221                    }
222                })
223                .boxed();
224            Ok(stream)
225        }
226        .boxed()
227    }
228}