ai.rs

  1pub mod assistant;
  2mod assistant_settings;
  3mod streaming_diff;
  4
  5use anyhow::{anyhow, Result};
  6pub use assistant::AssistantPanel;
  7use chrono::{DateTime, Local};
  8use collections::HashMap;
  9use fs::Fs;
 10use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 11use gpui::{executor::Background, AppContext};
 12use isahc::{http::StatusCode, Request, RequestExt};
 13use regex::Regex;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    cmp::Reverse,
 17    ffi::OsStr,
 18    fmt::{self, Display},
 19    io,
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use util::paths::CONVERSATIONS_DIR;
 24
 25const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 26
 27// Data types for chat completion requests
 28#[derive(Debug, Serialize)]
 29pub struct OpenAIRequest {
 30    model: String,
 31    messages: Vec<RequestMessage>,
 32    stream: bool,
 33}
 34
 35#[derive(
 36    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 37)]
 38struct MessageId(usize);
 39
 40#[derive(Clone, Debug, Serialize, Deserialize)]
 41struct MessageMetadata {
 42    role: Role,
 43    sent_at: DateTime<Local>,
 44    status: MessageStatus,
 45}
 46
 47#[derive(Clone, Debug, Serialize, Deserialize)]
 48enum MessageStatus {
 49    Pending,
 50    Done,
 51    Error(Arc<str>),
 52}
 53
 54#[derive(Serialize, Deserialize)]
 55struct SavedMessage {
 56    id: MessageId,
 57    start: usize,
 58}
 59
 60#[derive(Serialize, Deserialize)]
 61struct SavedConversation {
 62    zed: String,
 63    version: String,
 64    text: String,
 65    messages: Vec<SavedMessage>,
 66    message_metadata: HashMap<MessageId, MessageMetadata>,
 67    summary: String,
 68    model: String,
 69}
 70
 71impl SavedConversation {
 72    const VERSION: &'static str = "0.1.0";
 73}
 74
 75struct SavedConversationMetadata {
 76    title: String,
 77    path: PathBuf,
 78    mtime: chrono::DateTime<chrono::Local>,
 79}
 80
 81impl SavedConversationMetadata {
 82    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 83        fs.create_dir(&CONVERSATIONS_DIR).await?;
 84
 85        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 86        let mut conversations = Vec::<SavedConversationMetadata>::new();
 87        while let Some(path) = paths.next().await {
 88            let path = path?;
 89            if path.extension() != Some(OsStr::new("json")) {
 90                continue;
 91            }
 92
 93            let pattern = r" - \d+.zed.json$";
 94            let re = Regex::new(pattern).unwrap();
 95
 96            let metadata = fs.metadata(&path).await?;
 97            if let Some((file_name, metadata)) = path
 98                .file_name()
 99                .and_then(|name| name.to_str())
100                .zip(metadata)
101            {
102                let title = re.replace(file_name, "");
103                conversations.push(Self {
104                    title: title.into_owned(),
105                    path,
106                    mtime: metadata.mtime.into(),
107                });
108            }
109        }
110        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
111
112        Ok(conversations)
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
117struct RequestMessage {
118    role: Role,
119    content: String,
120}
121
122#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
123pub struct ResponseMessage {
124    role: Option<Role>,
125    content: Option<String>,
126}
127
128#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
129#[serde(rename_all = "lowercase")]
130enum Role {
131    User,
132    Assistant,
133    System,
134}
135
136impl Role {
137    pub fn cycle(&mut self) {
138        *self = match self {
139            Role::User => Role::Assistant,
140            Role::Assistant => Role::System,
141            Role::System => Role::User,
142        }
143    }
144}
145
146impl Display for Role {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
148        match self {
149            Role::User => write!(f, "User"),
150            Role::Assistant => write!(f, "Assistant"),
151            Role::System => write!(f, "System"),
152        }
153    }
154}
155
156#[derive(Deserialize, Debug)]
157pub struct OpenAIResponseStreamEvent {
158    pub id: Option<String>,
159    pub object: String,
160    pub created: u32,
161    pub model: String,
162    pub choices: Vec<ChatChoiceDelta>,
163    pub usage: Option<Usage>,
164}
165
166#[derive(Deserialize, Debug)]
167pub struct Usage {
168    pub prompt_tokens: u32,
169    pub completion_tokens: u32,
170    pub total_tokens: u32,
171}
172
173#[derive(Deserialize, Debug)]
174pub struct ChatChoiceDelta {
175    pub index: u32,
176    pub delta: ResponseMessage,
177    pub finish_reason: Option<String>,
178}
179
180#[derive(Deserialize, Debug)]
181struct OpenAIUsage {
182    prompt_tokens: u64,
183    completion_tokens: u64,
184    total_tokens: u64,
185}
186
187#[derive(Deserialize, Debug)]
188struct OpenAIChoice {
189    text: String,
190    index: u32,
191    logprobs: Option<serde_json::Value>,
192    finish_reason: Option<String>,
193}
194
195pub fn init(cx: &mut AppContext) {
196    assistant::init(cx);
197}
198
199pub async fn stream_completion(
200    api_key: String,
201    executor: Arc<Background>,
202    mut request: OpenAIRequest,
203) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
204    request.stream = true;
205
206    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
207
208    let json_data = serde_json::to_string(&request)?;
209    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
210        .header("Content-Type", "application/json")
211        .header("Authorization", format!("Bearer {}", api_key))
212        .body(json_data)?
213        .send_async()
214        .await?;
215
216    let status = response.status();
217    if status == StatusCode::OK {
218        executor
219            .spawn(async move {
220                let mut lines = BufReader::new(response.body_mut()).lines();
221
222                fn parse_line(
223                    line: Result<String, io::Error>,
224                ) -> Result<Option<OpenAIResponseStreamEvent>> {
225                    if let Some(data) = line?.strip_prefix("data: ") {
226                        let event = serde_json::from_str(&data)?;
227                        Ok(Some(event))
228                    } else {
229                        Ok(None)
230                    }
231                }
232
233                while let Some(line) = lines.next().await {
234                    if let Some(event) = parse_line(line).transpose() {
235                        let done = event.as_ref().map_or(false, |event| {
236                            event
237                                .choices
238                                .last()
239                                .map_or(false, |choice| choice.finish_reason.is_some())
240                        });
241                        if tx.unbounded_send(event).is_err() {
242                            break;
243                        }
244
245                        if done {
246                            break;
247                        }
248                    }
249                }
250
251                anyhow::Ok(())
252            })
253            .detach();
254
255        Ok(rx)
256    } else {
257        let mut body = String::new();
258        response.body_mut().read_to_string(&mut body).await?;
259
260        #[derive(Deserialize)]
261        struct OpenAIResponse {
262            error: OpenAIError,
263        }
264
265        #[derive(Deserialize)]
266        struct OpenAIError {
267            message: String,
268        }
269
270        match serde_json::from_str::<OpenAIResponse>(&body) {
271            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
272                "Failed to connect to OpenAI API: {}",
273                response.error.message,
274            )),
275
276            _ => Err(anyhow!(
277                "Failed to connect to OpenAI API: {} {}",
278                response.status(),
279                body,
280            )),
281        }
282    }
283}
284
/// Test-only constructor (registered through `ctor`) that initializes
/// `env_logger`, but only when the `RUST_LOG` environment variable can be
/// read (i.e. it is set and valid Unicode).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
291}