ai.rs

  1pub mod assistant;
  2mod assistant_settings;
  3
  4use anyhow::Result;
  5pub use assistant::AssistantPanel;
  6use assistant_settings::OpenAIModel;
  7use chrono::{DateTime, Local};
  8use collections::HashMap;
  9use fs::Fs;
 10use futures::StreamExt;
 11use gpui::AppContext;
 12use regex::Regex;
 13use serde::{Deserialize, Serialize};
 14use std::{
 15    cmp::Reverse,
 16    ffi::OsStr,
 17    fmt::{self, Display},
 18    path::PathBuf,
 19    sync::Arc,
 20};
 21use util::paths::CONVERSATIONS_DIR;
 22
 23// Data types for chat completion requests
 24#[derive(Debug, Serialize)]
 25struct OpenAIRequest {
 26    model: String,
 27    messages: Vec<RequestMessage>,
 28    stream: bool,
 29}
 30
 31#[derive(
 32    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 33)]
 34struct MessageId(usize);
 35
 36#[derive(Clone, Debug, Serialize, Deserialize)]
 37struct MessageMetadata {
 38    role: Role,
 39    sent_at: DateTime<Local>,
 40    status: MessageStatus,
 41}
 42
 43#[derive(Clone, Debug, Serialize, Deserialize)]
 44enum MessageStatus {
 45    Pending,
 46    Done,
 47    Error(Arc<str>),
 48}
 49
 50#[derive(Serialize, Deserialize)]
 51struct SavedMessage {
 52    id: MessageId,
 53    start: usize,
 54}
 55
 56#[derive(Serialize, Deserialize)]
 57struct SavedConversation {
 58    zed: String,
 59    version: String,
 60    text: String,
 61    messages: Vec<SavedMessage>,
 62    message_metadata: HashMap<MessageId, MessageMetadata>,
 63    summary: String,
 64    model: OpenAIModel,
 65}
 66
 67impl SavedConversation {
 68    const VERSION: &'static str = "0.1.0";
 69}
 70
 71struct SavedConversationMetadata {
 72    title: String,
 73    path: PathBuf,
 74    mtime: chrono::DateTime<chrono::Local>,
 75}
 76
 77impl SavedConversationMetadata {
 78    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
 79        fs.create_dir(&CONVERSATIONS_DIR).await?;
 80
 81        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
 82        let mut conversations = Vec::<SavedConversationMetadata>::new();
 83        while let Some(path) = paths.next().await {
 84            let path = path?;
 85            if path.extension() != Some(OsStr::new("json")) {
 86                continue;
 87            }
 88
 89            let pattern = r" - \d+.zed.json$";
 90            let re = Regex::new(pattern).unwrap();
 91
 92            let metadata = fs.metadata(&path).await?;
 93            if let Some((file_name, metadata)) = path
 94                .file_name()
 95                .and_then(|name| name.to_str())
 96                .zip(metadata)
 97            {
 98                let title = re.replace(file_name, "");
 99                conversations.push(Self {
100                    title: title.into_owned(),
101                    path,
102                    mtime: metadata.mtime.into(),
103                });
104            }
105        }
106        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
107
108        Ok(conversations)
109    }
110}
111
112#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
113struct RequestMessage {
114    role: Role,
115    content: String,
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119struct ResponseMessage {
120    role: Option<Role>,
121    content: Option<String>,
122}
123
124#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
125#[serde(rename_all = "lowercase")]
126enum Role {
127    User,
128    Assistant,
129    System,
130}
131
132impl Role {
133    pub fn cycle(&mut self) {
134        *self = match self {
135            Role::User => Role::Assistant,
136            Role::Assistant => Role::System,
137            Role::System => Role::User,
138        }
139    }
140}
141
142impl Display for Role {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
144        match self {
145            Role::User => write!(f, "User"),
146            Role::Assistant => write!(f, "Assistant"),
147            Role::System => write!(f, "System"),
148        }
149    }
150}
151
152#[derive(Deserialize, Debug)]
153struct OpenAIResponseStreamEvent {
154    pub id: Option<String>,
155    pub object: String,
156    pub created: u32,
157    pub model: String,
158    pub choices: Vec<ChatChoiceDelta>,
159    pub usage: Option<Usage>,
160}
161
162#[derive(Deserialize, Debug)]
163struct Usage {
164    pub prompt_tokens: u32,
165    pub completion_tokens: u32,
166    pub total_tokens: u32,
167}
168
169#[derive(Deserialize, Debug)]
170struct ChatChoiceDelta {
171    pub index: u32,
172    pub delta: ResponseMessage,
173    pub finish_reason: Option<String>,
174}
175
176#[derive(Deserialize, Debug)]
177struct OpenAIUsage {
178    prompt_tokens: u64,
179    completion_tokens: u64,
180    total_tokens: u64,
181}
182
183#[derive(Deserialize, Debug)]
184struct OpenAIChoice {
185    text: String,
186    index: u32,
187    logprobs: Option<serde_json::Value>,
188    finish_reason: Option<String>,
189}
190
191pub fn init(cx: &mut AppContext) {
192    assistant::init(cx);
193}