ai.rs

  1pub mod assistant;
  2mod assistant_settings;
  3
  4use anyhow::Result;
  5pub use assistant::AssistantPanel;
  6use chrono::{DateTime, Local};
  7use collections::HashMap;
  8use fs::Fs;
  9use futures::StreamExt;
 10use gpui::AppContext;
 11use regex::Regex;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    fmt::{self, Display},
 15    path::PathBuf,
 16    sync::Arc,
 17};
 18use util::paths::CONVERSATIONS_DIR;
 19
 20// Data types for chat completion requests
 21#[derive(Debug, Serialize)]
 22struct OpenAIRequest {
 23    model: String,
 24    messages: Vec<RequestMessage>,
 25    stream: bool,
 26}
 27
 28#[derive(
 29    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 30)]
 31struct MessageId(usize);
 32
 33#[derive(Clone, Debug, Serialize, Deserialize)]
 34struct MessageMetadata {
 35    role: Role,
 36    sent_at: DateTime<Local>,
 37    status: MessageStatus,
 38}
 39
 40#[derive(Clone, Debug, Serialize, Deserialize)]
 41enum MessageStatus {
 42    Pending,
 43    Done,
 44    Error(Arc<str>),
 45}
 46
 47#[derive(Serialize, Deserialize)]
 48struct SavedMessage {
 49    id: MessageId,
 50    start: usize,
 51}
 52
 53#[derive(Serialize, Deserialize)]
 54struct SavedConversation {
 55    zed: String,
 56    version: String,
 57    text: String,
 58    messages: Vec<SavedMessage>,
 59    message_metadata: HashMap<MessageId, MessageMetadata>,
 60    summary: String,
 61    model: String,
 62}
 63
 64impl SavedConversation {
 65    const VERSION: &'static str = "0.1.0";
 66}
 67
 68struct SavedConversationMetadata {
 69    title: String,
 70    path: PathBuf,
 71    mtime: chrono::DateTime<chrono::Local>,
 72}
 73
 74impl SavedConversationMetadata {
 75    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 76        fs.create_dir(&CONVERSATIONS_DIR).await?;
 77
 78        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 79        let mut conversations = Vec::<SavedConversationMetadata>::new();
 80        while let Some(path) = paths.next().await {
 81            let path = path?;
 82
 83            let pattern = r" - \d+.zed.json$";
 84            let re = Regex::new(pattern).unwrap();
 85
 86            let metadata = fs.metadata(&path).await?;
 87            if let Some((file_name, metadata)) = path
 88                .file_name()
 89                .and_then(|name| name.to_str())
 90                .zip(metadata)
 91            {
 92                let title = re.replace(file_name, "");
 93                conversations.push(Self {
 94                    title: title.into_owned(),
 95                    path,
 96                    mtime: metadata.mtime.into(),
 97                });
 98            }
 99        }
100
101        Ok(conversations)
102    }
103}
104
105#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
106struct RequestMessage {
107    role: Role,
108    content: String,
109}
110
111#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
112struct ResponseMessage {
113    role: Option<Role>,
114    content: Option<String>,
115}
116
117#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
118#[serde(rename_all = "lowercase")]
119enum Role {
120    User,
121    Assistant,
122    System,
123}
124
125impl Role {
126    pub fn cycle(&mut self) {
127        *self = match self {
128            Role::User => Role::Assistant,
129            Role::Assistant => Role::System,
130            Role::System => Role::User,
131        }
132    }
133}
134
135impl Display for Role {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
137        match self {
138            Role::User => write!(f, "User"),
139            Role::Assistant => write!(f, "Assistant"),
140            Role::System => write!(f, "System"),
141        }
142    }
143}
144
145#[derive(Deserialize, Debug)]
146struct OpenAIResponseStreamEvent {
147    pub id: Option<String>,
148    pub object: String,
149    pub created: u32,
150    pub model: String,
151    pub choices: Vec<ChatChoiceDelta>,
152    pub usage: Option<Usage>,
153}
154
155#[derive(Deserialize, Debug)]
156struct Usage {
157    pub prompt_tokens: u32,
158    pub completion_tokens: u32,
159    pub total_tokens: u32,
160}
161
162#[derive(Deserialize, Debug)]
163struct ChatChoiceDelta {
164    pub index: u32,
165    pub delta: ResponseMessage,
166    pub finish_reason: Option<String>,
167}
168
169#[derive(Deserialize, Debug)]
170struct OpenAIUsage {
171    prompt_tokens: u64,
172    completion_tokens: u64,
173    total_tokens: u64,
174}
175
176#[derive(Deserialize, Debug)]
177struct OpenAIChoice {
178    text: String,
179    index: u32,
180    logprobs: Option<serde_json::Value>,
181    finish_reason: Option<String>,
182}
183
184pub fn init(cx: &mut AppContext) {
185    assistant::init(cx);
186}