//! ai.rs — OpenAI chat-completion client and saved-conversation types for the assistant.

  1pub mod assistant;
  2mod assistant_settings;
  3mod codegen;
  4mod streaming_diff;
  5
  6use anyhow::{anyhow, Result};
  7pub use assistant::AssistantPanel;
  8use assistant_settings::OpenAIModel;
  9use chrono::{DateTime, Local};
 10use collections::HashMap;
 11use fs::Fs;
 12use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 13use gpui::{executor::Background, AppContext};
 14use isahc::{http::StatusCode, Request, RequestExt};
 15use regex::Regex;
 16use serde::{Deserialize, Serialize};
 17use std::{
 18    cmp::Reverse,
 19    ffi::OsStr,
 20    fmt::{self, Display},
 21    io,
 22    path::PathBuf,
 23    sync::Arc,
 24};
 25use util::paths::CONVERSATIONS_DIR;
 26
 27const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 28
// Data types for chat completion requests
/// Body of a `POST {OPENAI_API_URL}/chat/completions` request.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    /// OpenAI model name (see `OpenAIModel` in `assistant_settings`).
    model: String,
    /// Conversation history sent to the model, oldest first.
    messages: Vec<RequestMessage>,
    /// Whether to stream the response; `stream_completion` forces this to `true`.
    stream: bool,
}
 36
/// Identifier for a message within a conversation; also used as the key
/// for `SavedConversation::message_metadata`.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 41
/// Per-message bookkeeping persisted alongside the conversation text.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// Author of the message (user, assistant, or system).
    role: Role,
    /// Local timestamp recorded for the message.
    sent_at: DateTime<Local>,
    /// Streaming/completion state of the message.
    status: MessageStatus,
}
 48
/// Lifecycle state of a single message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Still being produced (e.g. mid-stream).
    Pending,
    /// Fully received.
    Done,
    /// Failed; carries the error text. `Arc<str>` keeps clones cheap.
    Error(Arc<str>),
}
 55
/// Locates one message inside a saved conversation.
#[derive(Serialize, Deserialize)]
struct SavedMessage {
    id: MessageId,
    /// Offset where this message begins — presumably an index into
    /// `SavedConversation::text`; confirm against the serializer.
    start: usize,
}
 61
/// On-disk JSON representation of an assistant conversation.
#[derive(Serialize, Deserialize)]
struct SavedConversation {
    id: Option<String>,
    /// Marker field identifying the file as a Zed conversation.
    zed: String,
    /// Schema version; see `SavedConversation::VERSION`.
    version: String,
    /// Full conversation text.
    text: String,
    /// Message boundaries into `text`, via `SavedMessage::start`.
    messages: Vec<SavedMessage>,
    /// Per-message metadata keyed by message id.
    message_metadata: HashMap<MessageId, MessageMetadata>,
    /// Short human-readable summary of the conversation.
    summary: String,
    /// Model the conversation was using.
    model: OpenAIModel,
}
 73
impl SavedConversation {
    /// Schema version written into saved conversation files.
    const VERSION: &'static str = "0.1.0";
}
 77
/// Lightweight directory-listing entry for a saved conversation
/// (title, file path, and modification time — no message content).
struct SavedConversationMetadata {
    title: String,
    path: PathBuf,
    /// File modification time, used to sort listings newest-first.
    mtime: chrono::DateTime<chrono::Local>,
}
 83
 84impl SavedConversationMetadata {
 85    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 86        fs.create_dir(&CONVERSATIONS_DIR).await?;
 87
 88        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 89        let mut conversations = Vec::<SavedConversationMetadata>::new();
 90        while let Some(path) = paths.next().await {
 91            let path = path?;
 92            if path.extension() != Some(OsStr::new("json")) {
 93                continue;
 94            }
 95
 96            let pattern = r" - \d+.zed.json$";
 97            let re = Regex::new(pattern).unwrap();
 98
 99            let metadata = fs.metadata(&path).await?;
100            if let Some((file_name, metadata)) = path
101                .file_name()
102                .and_then(|name| name.to_str())
103                .zip(metadata)
104            {
105                let title = re.replace(file_name, "");
106                conversations.push(Self {
107                    title: title.into_owned(),
108                    path,
109                    mtime: metadata.mtime.into(),
110                });
111            }
112        }
113        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
114
115        Ok(conversations)
116    }
117}
118
/// One message in the `messages` array of an outgoing chat request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
124
/// Message payload of a streamed response delta. Both fields are optional
/// because each chunk may carry only part of the message (see
/// `ChatChoiceDelta::delta`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
130
/// Author of a chat message; serialized in lowercase
/// (`"user"`, `"assistant"`, `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
138
139impl Role {
140    pub fn cycle(&mut self) {
141        *self = match self {
142            Role::User => Role::Assistant,
143            Role::Assistant => Role::System,
144            Role::System => Role::User,
145        }
146    }
147}
148
149impl Display for Role {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
151        match self {
152            Role::User => write!(f, "User"),
153            Role::Assistant => write!(f, "Assistant"),
154            Role::System => write!(f, "System"),
155        }
156    }
157}
158
/// One decoded server-sent event from the streaming chat-completions
/// endpoint; parsed from the `data: {...}` lines in `stream_completion`.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    /// Incremental choice deltas; the last one's `finish_reason` signals
    /// the end of the stream.
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}
168
/// Token accounting attached to a stream event, when present.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
175
/// One choice within a stream event.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    /// Partial message content for this chunk.
    pub delta: ResponseMessage,
    /// Set (e.g. on the final chunk) when the model stops generating.
    pub finish_reason: Option<String>,
}
182
183#[derive(Deserialize, Debug)]
184struct OpenAIUsage {
185    prompt_tokens: u64,
186    completion_tokens: u64,
187    total_tokens: u64,
188}
189
/// A choice from the legacy (text) completions response shape.
/// NOTE(review): not referenced anywhere in this file as shown — verify it
/// is used elsewhere (or by a submodule) before removing.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
197
/// Initializes the AI crate: delegates to `assistant::init` to register the
/// assistant panel with the app context.
pub fn init(cx: &mut AppContext) {
    assistant::init(cx);
}
201
202pub async fn stream_completion(
203    api_key: String,
204    executor: Arc<Background>,
205    mut request: OpenAIRequest,
206) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
207    request.stream = true;
208
209    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
210
211    let json_data = serde_json::to_string(&request)?;
212    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
213        .header("Content-Type", "application/json")
214        .header("Authorization", format!("Bearer {}", api_key))
215        .body(json_data)?
216        .send_async()
217        .await?;
218
219    let status = response.status();
220    if status == StatusCode::OK {
221        executor
222            .spawn(async move {
223                let mut lines = BufReader::new(response.body_mut()).lines();
224
225                fn parse_line(
226                    line: Result<String, io::Error>,
227                ) -> Result<Option<OpenAIResponseStreamEvent>> {
228                    if let Some(data) = line?.strip_prefix("data: ") {
229                        let event = serde_json::from_str(&data)?;
230                        Ok(Some(event))
231                    } else {
232                        Ok(None)
233                    }
234                }
235
236                while let Some(line) = lines.next().await {
237                    if let Some(event) = parse_line(line).transpose() {
238                        let done = event.as_ref().map_or(false, |event| {
239                            event
240                                .choices
241                                .last()
242                                .map_or(false, |choice| choice.finish_reason.is_some())
243                        });
244                        if tx.unbounded_send(event).is_err() {
245                            break;
246                        }
247
248                        if done {
249                            break;
250                        }
251                    }
252                }
253
254                anyhow::Ok(())
255            })
256            .detach();
257
258        Ok(rx)
259    } else {
260        let mut body = String::new();
261        response.body_mut().read_to_string(&mut body).await?;
262
263        #[derive(Deserialize)]
264        struct OpenAIResponse {
265            error: OpenAIError,
266        }
267
268        #[derive(Deserialize)]
269        struct OpenAIError {
270            message: String,
271        }
272
273        match serde_json::from_str::<OpenAIResponse>(&body) {
274            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
275                "Failed to connect to OpenAI API: {}",
276                response.error.message,
277            )),
278
279            _ => Err(anyhow!(
280                "Failed to connect to OpenAI API: {} {}",
281                response.status(),
282                body,
283            )),
284        }
285    }
286}
287
// Test-only constructor: enable env_logger when the standard `RUST_LOG`
// environment variable is set, so tests stay quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}