ai.rs

  1pub mod assistant;
  2mod assistant_settings;
  3
  4use anyhow::Result;
  5pub use assistant::AssistantPanel;
  6use chrono::{DateTime, Local};
  7use collections::HashMap;
  8use fs::Fs;
  9use futures::StreamExt;
 10use gpui::AppContext;
 11use regex::Regex;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    cmp::Reverse,
 15    ffi::OsStr,
 16    fmt::{self, Display},
 17    path::PathBuf,
 18    sync::Arc,
 19};
 20use util::paths::CONVERSATIONS_DIR;
 21
 22// Data types for chat completion requests
 23#[derive(Debug, Serialize)]
 24struct OpenAIRequest {
 25    model: String,
 26    messages: Vec<RequestMessage>,
 27    stream: bool,
 28}
 29
 30#[derive(
 31    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 32)]
 33struct MessageId(usize);
 34
 35#[derive(Clone, Debug, Serialize, Deserialize)]
 36struct MessageMetadata {
 37    role: Role,
 38    sent_at: DateTime<Local>,
 39    status: MessageStatus,
 40}
 41
 42#[derive(Clone, Debug, Serialize, Deserialize)]
 43enum MessageStatus {
 44    Pending,
 45    Done,
 46    Error(Arc<str>),
 47}
 48
 49#[derive(Serialize, Deserialize)]
 50struct SavedMessage {
 51    id: MessageId,
 52    start: usize,
 53}
 54
 55#[derive(Serialize, Deserialize)]
 56struct SavedConversation {
 57    zed: String,
 58    version: String,
 59    text: String,
 60    messages: Vec<SavedMessage>,
 61    message_metadata: HashMap<MessageId, MessageMetadata>,
 62    summary: String,
 63    model: String,
 64}
 65
 66impl SavedConversation {
 67    const VERSION: &'static str = "0.1.0";
 68}
 69
 70struct SavedConversationMetadata {
 71    title: String,
 72    path: PathBuf,
 73    mtime: chrono::DateTime<chrono::Local>,
 74}
 75
 76impl SavedConversationMetadata {
 77    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 78        fs.create_dir(&CONVERSATIONS_DIR).await?;
 79
 80        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 81        let mut conversations = Vec::<SavedConversationMetadata>::new();
 82        while let Some(path) = paths.next().await {
 83            let path = path?;
 84            if path.extension() != Some(OsStr::new("json")) {
 85                continue;
 86            }
 87
 88            let pattern = r" - \d+.zed.json$";
 89            let re = Regex::new(pattern).unwrap();
 90
 91            let metadata = fs.metadata(&path).await?;
 92            if let Some((file_name, metadata)) = path
 93                .file_name()
 94                .and_then(|name| name.to_str())
 95                .zip(metadata)
 96            {
 97                let title = re.replace(file_name, "");
 98                conversations.push(Self {
 99                    title: title.into_owned(),
100                    path,
101                    mtime: metadata.mtime.into(),
102                });
103            }
104        }
105        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
106
107        Ok(conversations)
108    }
109}
110
111#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
112struct RequestMessage {
113    role: Role,
114    content: String,
115}
116
117#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
118struct ResponseMessage {
119    role: Option<Role>,
120    content: Option<String>,
121}
122
123#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
124#[serde(rename_all = "lowercase")]
125enum Role {
126    User,
127    Assistant,
128    System,
129}
130
131impl Role {
132    pub fn cycle(&mut self) {
133        *self = match self {
134            Role::User => Role::Assistant,
135            Role::Assistant => Role::System,
136            Role::System => Role::User,
137        }
138    }
139}
140
141impl Display for Role {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            Role::User => write!(f, "User"),
145            Role::Assistant => write!(f, "Assistant"),
146            Role::System => write!(f, "System"),
147        }
148    }
149}
150
151#[derive(Deserialize, Debug)]
152struct OpenAIResponseStreamEvent {
153    pub id: Option<String>,
154    pub object: String,
155    pub created: u32,
156    pub model: String,
157    pub choices: Vec<ChatChoiceDelta>,
158    pub usage: Option<Usage>,
159}
160
161#[derive(Deserialize, Debug)]
162struct Usage {
163    pub prompt_tokens: u32,
164    pub completion_tokens: u32,
165    pub total_tokens: u32,
166}
167
168#[derive(Deserialize, Debug)]
169struct ChatChoiceDelta {
170    pub index: u32,
171    pub delta: ResponseMessage,
172    pub finish_reason: Option<String>,
173}
174
175#[derive(Deserialize, Debug)]
176struct OpenAIUsage {
177    prompt_tokens: u64,
178    completion_tokens: u64,
179    total_tokens: u64,
180}
181
182#[derive(Deserialize, Debug)]
183struct OpenAIChoice {
184    text: String,
185    index: u32,
186    logprobs: Option<serde_json::Value>,
187    finish_reason: Option<String>,
188}
189
190pub fn init(cx: &mut AppContext) {
191    assistant::init(cx);
192}