// ai.rs

  1pub mod assistant;
  2mod assistant_settings;
  3mod codegen;
  4mod streaming_diff;
  5
  6use anyhow::{anyhow, Result};
  7pub use assistant::AssistantPanel;
  8use assistant_settings::OpenAIModel;
  9use chrono::{DateTime, Local};
 10use collections::HashMap;
 11use fs::Fs;
 12use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 13use gpui::{executor::Background, AppContext};
 14use isahc::{http::StatusCode, Request, RequestExt};
 15use regex::Regex;
 16use serde::{Deserialize, Serialize};
 17use std::{
 18    cmp::Reverse,
 19    ffi::OsStr,
 20    fmt::{self, Display},
 21    io,
 22    path::PathBuf,
 23    sync::Arc,
 24};
 25use util::paths::CONVERSATIONS_DIR;
 26
 27const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 28
 29// Data types for chat completion requests
/// Request body for the OpenAI chat completions endpoint
/// (serialized as JSON by `stream_completion`).
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    /// Name of the model to run the completion with.
    model: String,
    /// Conversation history sent with the request.
    messages: Vec<RequestMessage>,
    /// Whether to stream the response; forced to `true` by `stream_completion`.
    stream: bool,
}
 36
/// Newtype identifier for a message; used as the key into
/// `SavedConversation::message_metadata`.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 41
/// Per-message bookkeeping persisted alongside the conversation text.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// Who authored the message (user, assistant, or system).
    role: Role,
    /// Local wall-clock time the message was sent.
    sent_at: DateTime<Local>,
    /// Streaming/completion state of the message.
    status: MessageStatus,
}
 48
/// Lifecycle state of a message in a conversation.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Still being produced (e.g. mid-stream).
    Pending,
    /// Fully received.
    Done,
    /// Failed; carries the error text. `Arc<str>` keeps clones cheap.
    Error(Arc<str>),
}
 55
/// Serialized position of a message within a saved conversation.
#[derive(Serialize, Deserialize)]
struct SavedMessage {
    id: MessageId,
    /// Start position of the message — presumably an offset into
    /// `SavedConversation::text`; TODO(review) confirm against the writer.
    start: usize,
}
 61
/// On-disk JSON representation of a conversation.
#[derive(Serialize, Deserialize)]
struct SavedConversation {
    // NOTE(review): presumably identifies the Zed app/version that wrote the
    // file — confirm against the code that serializes this struct.
    zed: String,
    /// Format version of this file (see `SavedConversation::VERSION`).
    version: String,
    /// Full conversation text.
    text: String,
    /// Message boundaries within `text`.
    messages: Vec<SavedMessage>,
    /// Metadata for each message, keyed by id.
    message_metadata: HashMap<MessageId, MessageMetadata>,
    /// Human-readable summary/title of the conversation.
    summary: String,
    /// Model the conversation was using.
    model: OpenAIModel,
}
 72
impl SavedConversation {
    /// Current serialization format version written into the `version` field.
    const VERSION: &'static str = "0.1.0";
}
 76
/// Lightweight listing entry for a saved conversation on disk
/// (built by `SavedConversationMetadata::list` without parsing file contents).
struct SavedConversationMetadata {
    /// Title derived from the file name.
    title: String,
    /// Full path to the conversation's JSON file.
    path: PathBuf,
    /// File modification time, used for newest-first ordering.
    mtime: chrono::DateTime<chrono::Local>,
}
 82
 83impl SavedConversationMetadata {
 84    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 85        fs.create_dir(&CONVERSATIONS_DIR).await?;
 86
 87        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 88        let mut conversations = Vec::<SavedConversationMetadata>::new();
 89        while let Some(path) = paths.next().await {
 90            let path = path?;
 91            if path.extension() != Some(OsStr::new("json")) {
 92                continue;
 93            }
 94
 95            let pattern = r" - \d+.zed.json$";
 96            let re = Regex::new(pattern).unwrap();
 97
 98            let metadata = fs.metadata(&path).await?;
 99            if let Some((file_name, metadata)) = path
100                .file_name()
101                .and_then(|name| name.to_str())
102                .zip(metadata)
103            {
104                let title = re.replace(file_name, "");
105                conversations.push(Self {
106                    title: title.into_owned(),
107                    path,
108                    mtime: metadata.mtime.into(),
109                });
110            }
111        }
112        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
113
114        Ok(conversations)
115    }
116}
117
/// A single chat message sent to the API as part of `OpenAIRequest::messages`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
123
/// A message fragment received from the API. Both fields are optional
/// because streaming deltas (`ChatChoiceDelta::delta`) may carry only the
/// role, only new content, or neither.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
129
/// Author of a chat message; serialized lowercase ("user", "assistant",
/// "system") to match the OpenAI API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
137
138impl Role {
139    pub fn cycle(&mut self) {
140        *self = match self {
141            Role::User => Role::Assistant,
142            Role::Assistant => Role::System,
143            Role::System => Role::User,
144        }
145    }
146}
147
148impl Display for Role {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
150        match self {
151            Role::User => write!(f, "User"),
152            Role::Assistant => write!(f, "Assistant"),
153            Role::System => write!(f, "System"),
154        }
155    }
156}
157
/// One server-sent event from the streaming chat completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    /// Incremental deltas; `stream_completion` stops when the last choice
    /// reports a `finish_reason`.
    pub choices: Vec<ChatChoiceDelta>,
    /// Token accounting; may be absent on intermediate events.
    pub usage: Option<Usage>,
}
167
/// Token usage counts reported by the API for a completion.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
174
/// One streamed choice within an `OpenAIResponseStreamEvent`.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    /// Partial message content for this event.
    pub delta: ResponseMessage,
    /// Set on the final event for this choice (e.g. when generation stops).
    pub finish_reason: Option<String>,
}
181
// NOTE(review): not referenced anywhere in this file chunk — possibly used by
// a sibling module (e.g. `codegen`) for non-streaming completions; verify
// before removing.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
188
// NOTE(review): like `OpenAIUsage`, this has no uses in the visible chunk;
// its shape matches the legacy (text, not chat) completions response —
// confirm callers before removing.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
196
/// Initializes this crate's app state; currently just the assistant module.
pub fn init(cx: &mut AppContext) {
    assistant::init(cx);
}
200
201pub async fn stream_completion(
202    api_key: String,
203    executor: Arc<Background>,
204    mut request: OpenAIRequest,
205) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
206    request.stream = true;
207
208    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
209
210    let json_data = serde_json::to_string(&request)?;
211    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
212        .header("Content-Type", "application/json")
213        .header("Authorization", format!("Bearer {}", api_key))
214        .body(json_data)?
215        .send_async()
216        .await?;
217
218    let status = response.status();
219    if status == StatusCode::OK {
220        executor
221            .spawn(async move {
222                let mut lines = BufReader::new(response.body_mut()).lines();
223
224                fn parse_line(
225                    line: Result<String, io::Error>,
226                ) -> Result<Option<OpenAIResponseStreamEvent>> {
227                    if let Some(data) = line?.strip_prefix("data: ") {
228                        let event = serde_json::from_str(&data)?;
229                        Ok(Some(event))
230                    } else {
231                        Ok(None)
232                    }
233                }
234
235                while let Some(line) = lines.next().await {
236                    if let Some(event) = parse_line(line).transpose() {
237                        let done = event.as_ref().map_or(false, |event| {
238                            event
239                                .choices
240                                .last()
241                                .map_or(false, |choice| choice.finish_reason.is_some())
242                        });
243                        if tx.unbounded_send(event).is_err() {
244                            break;
245                        }
246
247                        if done {
248                            break;
249                        }
250                    }
251                }
252
253                anyhow::Ok(())
254            })
255            .detach();
256
257        Ok(rx)
258    } else {
259        let mut body = String::new();
260        response.body_mut().read_to_string(&mut body).await?;
261
262        #[derive(Deserialize)]
263        struct OpenAIResponse {
264            error: OpenAIError,
265        }
266
267        #[derive(Deserialize)]
268        struct OpenAIError {
269            message: String,
270        }
271
272        match serde_json::from_str::<OpenAIResponse>(&body) {
273            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
274                "Failed to connect to OpenAI API: {}",
275                response.error.message,
276            )),
277
278            _ => Err(anyhow!(
279                "Failed to connect to OpenAI API: {} {}",
280                response.status(),
281                body,
282            )),
283        }
284    }
285}
286
// Test-only hook: `ctor` runs this before the test harness starts, enabling
// env_logger output when RUST_LOG is set so tests can emit log lines.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}