// ai.rs — OpenAI-backed assistant: streaming chat completions and saved conversations.

  1pub mod assistant;
  2mod assistant_settings;
  3mod streaming_diff;
  4
  5use anyhow::{anyhow, Result};
  6pub use assistant::AssistantPanel;
  7use assistant_settings::OpenAIModel;
  8use chrono::{DateTime, Local};
  9use collections::HashMap;
 10use fs::Fs;
 11use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 12use gpui::{executor::Background, AppContext};
 13use isahc::{http::StatusCode, Request, RequestExt};
 14use regex::Regex;
 15use serde::{Deserialize, Serialize};
 16use std::{
 17    cmp::Reverse,
 18    ffi::OsStr,
 19    fmt::{self, Display},
 20    io,
 21    path::PathBuf,
 22    sync::Arc,
 23};
 24use util::paths::CONVERSATIONS_DIR;
 25
 26const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 27
 28// Data types for chat completion requests
/// Request body serialized to JSON for the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
pub struct OpenAIRequest {
    // Model identifier string (see `OpenAIModel` in assistant_settings).
    model: String,
    // Conversation history sent with the request, oldest first.
    messages: Vec<RequestMessage>,
    // Forced to `true` by `stream_completion` before sending.
    stream: bool,
}
 35
/// Stable identifier for a conversation message; used as the key of
/// `SavedConversation::message_metadata`.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 40
/// Per-message bookkeeping stored alongside the conversation text.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    // Who authored the message (user / assistant / system).
    role: Role,
    // Local wall-clock time the message was sent.
    sent_at: DateTime<Local>,
    // Completion state of the message (see `MessageStatus`).
    status: MessageStatus,
}
 47
/// Completion state of a single message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    // Not yet complete (e.g. a response still streaming in).
    Pending,
    // Fully received / finalized.
    Done,
    // Failed; carries the error text. `Arc<str>` keeps clones cheap.
    Error(Arc<str>),
}
 54
/// On-disk record locating one message inside a saved conversation.
#[derive(Serialize, Deserialize)]
struct SavedMessage {
    id: MessageId,
    // Offset of the message within `SavedConversation::text` —
    // presumably a byte offset; TODO confirm against the writer.
    start: usize,
}
 60
/// JSON document format for a conversation persisted under `CONVERSATIONS_DIR`.
#[derive(Serialize, Deserialize)]
struct SavedConversation {
    // NOTE(review): looks like an app identifier/marker field — confirm with the writer.
    zed: String,
    // Schema version; written as `SavedConversation::VERSION`.
    version: String,
    // Full conversation text; `messages` index into it via `start`.
    text: String,
    messages: Vec<SavedMessage>,
    message_metadata: HashMap<MessageId, MessageMetadata>,
    // Human-readable summary/title of the conversation.
    summary: String,
    // Model the conversation was using.
    model: OpenAIModel,
}
 71
impl SavedConversation {
    /// Current schema version stamped into the `version` field of saved files.
    const VERSION: &'static str = "0.1.0";
}
 75
/// Lightweight listing entry for a saved conversation (no file contents loaded).
struct SavedConversationMetadata {
    // Title derived from the file name with its timestamp suffix stripped.
    title: String,
    // Full path to the `.json` conversation file.
    path: PathBuf,
    // File modification time; used to sort listings newest-first.
    mtime: chrono::DateTime<chrono::Local>,
}
 81
 82impl SavedConversationMetadata {
 83    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 84        fs.create_dir(&CONVERSATIONS_DIR).await?;
 85
 86        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 87        let mut conversations = Vec::<SavedConversationMetadata>::new();
 88        while let Some(path) = paths.next().await {
 89            let path = path?;
 90            if path.extension() != Some(OsStr::new("json")) {
 91                continue;
 92            }
 93
 94            let pattern = r" - \d+.zed.json$";
 95            let re = Regex::new(pattern).unwrap();
 96
 97            let metadata = fs.metadata(&path).await?;
 98            if let Some((file_name, metadata)) = path
 99                .file_name()
100                .and_then(|name| name.to_str())
101                .zip(metadata)
102            {
103                let title = re.replace(file_name, "");
104                conversations.push(Self {
105                    title: title.into_owned(),
106                    path,
107                    mtime: metadata.mtime.into(),
108                });
109            }
110        }
111        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
112
113        Ok(conversations)
114    }
115}
116
/// One message in the outgoing `OpenAIRequest::messages` array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
122
/// Message fragment received in a streaming response delta
/// (`ChatChoiceDelta::delta`). Both fields are `Option` because each
/// streamed chunk carries only part of the message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
128
/// Author of a chat message; serialized lowercase (`user` / `assistant` /
/// `system`) to match the OpenAI API wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
136
137impl Role {
138    pub fn cycle(&mut self) {
139        *self = match self {
140            Role::User => Role::Assistant,
141            Role::Assistant => Role::System,
142            Role::System => Role::User,
143        }
144    }
145}
146
147impl Display for Role {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            Role::User => write!(f, "User"),
151            Role::Assistant => write!(f, "Assistant"),
152            Role::System => write!(f, "System"),
153        }
154    }
155}
156
/// One server-sent event decoded from the streaming chat-completions
/// response (the JSON payload after a `data: ` prefix).
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    // Incremental choice updates; `stream_completion` stops once the last
    // choice carries a `finish_reason`.
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}
166
/// Token accounting attached to a stream event, when the API provides it.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
173
/// One incremental choice inside a stream event.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    // Partial message content for this chunk.
    pub delta: ResponseMessage,
    // Set on the final chunk of a choice (e.g. when generation stops);
    // `stream_completion` uses this to detect end of stream.
    pub finish_reason: Option<String>,
}
180
/// Token accounting shaped like the non-streaming completions response.
/// NOTE(review): not referenced anywhere in this chunk — possibly used
/// elsewhere in the crate, or dead; verify before removing.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
187
/// Choice shaped like the non-streaming completions response.
/// NOTE(review): not referenced anywhere in this chunk — possibly used
/// elsewhere in the crate, or dead; verify before removing.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
195
/// Initializes the `assistant` module against the given app context.
pub fn init(cx: &mut AppContext) {
    assistant::init(cx);
}
199
200pub async fn stream_completion(
201    api_key: String,
202    executor: Arc<Background>,
203    mut request: OpenAIRequest,
204) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
205    request.stream = true;
206
207    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
208
209    let json_data = serde_json::to_string(&request)?;
210    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
211        .header("Content-Type", "application/json")
212        .header("Authorization", format!("Bearer {}", api_key))
213        .body(json_data)?
214        .send_async()
215        .await?;
216
217    let status = response.status();
218    if status == StatusCode::OK {
219        executor
220            .spawn(async move {
221                let mut lines = BufReader::new(response.body_mut()).lines();
222
223                fn parse_line(
224                    line: Result<String, io::Error>,
225                ) -> Result<Option<OpenAIResponseStreamEvent>> {
226                    if let Some(data) = line?.strip_prefix("data: ") {
227                        let event = serde_json::from_str(&data)?;
228                        Ok(Some(event))
229                    } else {
230                        Ok(None)
231                    }
232                }
233
234                while let Some(line) = lines.next().await {
235                    if let Some(event) = parse_line(line).transpose() {
236                        let done = event.as_ref().map_or(false, |event| {
237                            event
238                                .choices
239                                .last()
240                                .map_or(false, |choice| choice.finish_reason.is_some())
241                        });
242                        if tx.unbounded_send(event).is_err() {
243                            break;
244                        }
245
246                        if done {
247                            break;
248                        }
249                    }
250                }
251
252                anyhow::Ok(())
253            })
254            .detach();
255
256        Ok(rx)
257    } else {
258        let mut body = String::new();
259        response.body_mut().read_to_string(&mut body).await?;
260
261        #[derive(Deserialize)]
262        struct OpenAIResponse {
263            error: OpenAIError,
264        }
265
266        #[derive(Deserialize)]
267        struct OpenAIError {
268            message: String,
269        }
270
271        match serde_json::from_str::<OpenAIResponse>(&body) {
272            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
273                "Failed to connect to OpenAI API: {}",
274                response.error.message,
275            )),
276
277            _ => Err(anyhow!(
278                "Failed to connect to OpenAI API: {} {}",
279                response.status(),
280                body,
281            )),
282        }
283    }
284}
285
/// Test-only constructor (runs before the test harness via `ctor`): enables
/// `env_logger` so tests emit log output when `RUST_LOG` is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}