ai.rs

  1pub mod assistant;
  2mod assistant_settings;
  3
  4use anyhow::Result;
  5pub use assistant::AssistantPanel;
  6use chrono::{DateTime, Local};
  7use collections::HashMap;
  8use fs::Fs;
  9use futures::StreamExt;
 10use gpui::AppContext;
 11use regex::Regex;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    cmp::Reverse,
 15    fmt::{self, Display},
 16    path::PathBuf,
 17    sync::Arc,
 18};
 19use util::paths::CONVERSATIONS_DIR;
 20
 21// Data types for chat completion requests
 22#[derive(Debug, Serialize)]
 23struct OpenAIRequest {
 24    model: String,
 25    messages: Vec<RequestMessage>,
 26    stream: bool,
 27}
 28
 29#[derive(
 30    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 31)]
 32struct MessageId(usize);
 33
 34#[derive(Clone, Debug, Serialize, Deserialize)]
 35struct MessageMetadata {
 36    role: Role,
 37    sent_at: DateTime<Local>,
 38    status: MessageStatus,
 39}
 40
 41#[derive(Clone, Debug, Serialize, Deserialize)]
 42enum MessageStatus {
 43    Pending,
 44    Done,
 45    Error(Arc<str>),
 46}
 47
 48#[derive(Serialize, Deserialize)]
 49struct SavedMessage {
 50    id: MessageId,
 51    start: usize,
 52}
 53
 54#[derive(Serialize, Deserialize)]
 55struct SavedConversation {
 56    zed: String,
 57    version: String,
 58    text: String,
 59    messages: Vec<SavedMessage>,
 60    message_metadata: HashMap<MessageId, MessageMetadata>,
 61    summary: String,
 62    model: String,
 63}
 64
 65impl SavedConversation {
 66    const VERSION: &'static str = "0.1.0";
 67}
 68
 69struct SavedConversationMetadata {
 70    title: String,
 71    path: PathBuf,
 72    mtime: chrono::DateTime<chrono::Local>,
 73}
 74
 75impl SavedConversationMetadata {
 76    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 77        fs.create_dir(&CONVERSATIONS_DIR).await?;
 78
 79        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 80        let mut conversations = Vec::<SavedConversationMetadata>::new();
 81        while let Some(path) = paths.next().await {
 82            let path = path?;
 83
 84            let pattern = r" - \d+.zed.json$";
 85            let re = Regex::new(pattern).unwrap();
 86
 87            let metadata = fs.metadata(&path).await?;
 88            if let Some((file_name, metadata)) = path
 89                .file_name()
 90                .and_then(|name| name.to_str())
 91                .zip(metadata)
 92            {
 93                let title = re.replace(file_name, "");
 94                conversations.push(Self {
 95                    title: title.into_owned(),
 96                    path,
 97                    mtime: metadata.mtime.into(),
 98                });
 99            }
100        }
101        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
102
103        Ok(conversations)
104    }
105}
106
107#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
108struct RequestMessage {
109    role: Role,
110    content: String,
111}
112
113#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
114struct ResponseMessage {
115    role: Option<Role>,
116    content: Option<String>,
117}
118
119#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
120#[serde(rename_all = "lowercase")]
121enum Role {
122    User,
123    Assistant,
124    System,
125}
126
127impl Role {
128    pub fn cycle(&mut self) {
129        *self = match self {
130            Role::User => Role::Assistant,
131            Role::Assistant => Role::System,
132            Role::System => Role::User,
133        }
134    }
135}
136
137impl Display for Role {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
139        match self {
140            Role::User => write!(f, "User"),
141            Role::Assistant => write!(f, "Assistant"),
142            Role::System => write!(f, "System"),
143        }
144    }
145}
146
147#[derive(Deserialize, Debug)]
148struct OpenAIResponseStreamEvent {
149    pub id: Option<String>,
150    pub object: String,
151    pub created: u32,
152    pub model: String,
153    pub choices: Vec<ChatChoiceDelta>,
154    pub usage: Option<Usage>,
155}
156
157#[derive(Deserialize, Debug)]
158struct Usage {
159    pub prompt_tokens: u32,
160    pub completion_tokens: u32,
161    pub total_tokens: u32,
162}
163
164#[derive(Deserialize, Debug)]
165struct ChatChoiceDelta {
166    pub index: u32,
167    pub delta: ResponseMessage,
168    pub finish_reason: Option<String>,
169}
170
171#[derive(Deserialize, Debug)]
172struct OpenAIUsage {
173    prompt_tokens: u64,
174    completion_tokens: u64,
175    total_tokens: u64,
176}
177
178#[derive(Deserialize, Debug)]
179struct OpenAIChoice {
180    text: String,
181    index: u32,
182    logprobs: Option<serde_json::Value>,
183    finish_reason: Option<String>,
184}
185
186pub fn init(cx: &mut AppContext) {
187    assistant::init(cx);
188}