context_store.rs

  1use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
  2use anyhow::{anyhow, Result};
  3use assistant_slash_command::SlashCommandOutputSection;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt;
  7use fuzzy::StringMatchCandidate;
  8use gpui::{AppContext, Model, ModelContext, Task};
  9use paths::contexts_dir;
 10use regex::Regex;
 11use serde::{Deserialize, Serialize};
 12use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration};
 13use ui::Context;
 14use util::{ResultExt, TryFutureExt};
 15
 16#[derive(Serialize, Deserialize)]
 17pub struct SavedMessage {
 18    pub id: MessageId,
 19    pub start: usize,
 20}
 21
 22#[derive(Serialize, Deserialize)]
 23pub struct SavedContext {
 24    pub id: Option<String>,
 25    pub zed: String,
 26    pub version: String,
 27    pub text: String,
 28    pub messages: Vec<SavedMessage>,
 29    pub message_metadata: HashMap<MessageId, MessageMetadata>,
 30    pub summary: String,
 31    pub slash_command_output_sections: Vec<SlashCommandOutputSection<usize>>,
 32}
 33
 34impl SavedContext {
 35    pub const VERSION: &'static str = "0.3.0";
 36}
 37
 38#[derive(Serialize, Deserialize)]
 39pub struct SavedContextV0_2_0 {
 40    pub id: Option<String>,
 41    pub zed: String,
 42    pub version: String,
 43    pub text: String,
 44    pub messages: Vec<SavedMessage>,
 45    pub message_metadata: HashMap<MessageId, MessageMetadata>,
 46    pub summary: String,
 47}
 48
 49#[derive(Serialize, Deserialize)]
 50struct SavedContextV0_1_0 {
 51    id: Option<String>,
 52    zed: String,
 53    version: String,
 54    text: String,
 55    messages: Vec<SavedMessage>,
 56    message_metadata: HashMap<MessageId, MessageMetadata>,
 57    summary: String,
 58    api_url: Option<String>,
 59    model: OpenAiModel,
 60}
 61
 62#[derive(Clone)]
 63pub struct SavedContextMetadata {
 64    pub title: String,
 65    pub path: PathBuf,
 66    pub mtime: chrono::DateTime<chrono::Local>,
 67}
 68
 69pub struct ContextStore {
 70    contexts_metadata: Vec<SavedContextMetadata>,
 71    fs: Arc<dyn Fs>,
 72    _watch_updates: Task<Option<()>>,
 73}
 74
 75impl ContextStore {
 76    pub fn new(fs: Arc<dyn Fs>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
 77        cx.spawn(|mut cx| async move {
 78            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 79            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
 80
 81            let this = cx.new_model(|cx: &mut ModelContext<Self>| Self {
 82                contexts_metadata: Vec::new(),
 83                fs,
 84                _watch_updates: cx.spawn(|this, mut cx| {
 85                    async move {
 86                        while events.next().await.is_some() {
 87                            this.update(&mut cx, |this, cx| this.reload(cx))?
 88                                .await
 89                                .log_err();
 90                        }
 91                        anyhow::Ok(())
 92                    }
 93                    .log_err()
 94                }),
 95            })?;
 96            this.update(&mut cx, |this, cx| this.reload(cx))?
 97                .await
 98                .log_err();
 99            Ok(this)
100        })
101    }
102
103    pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task<Result<SavedContext>> {
104        let fs = self.fs.clone();
105        cx.background_executor().spawn(async move {
106            let saved_context = fs.load(&path).await?;
107            let saved_context_json = serde_json::from_str::<serde_json::Value>(&saved_context)?;
108            match saved_context_json
109                .get("version")
110                .ok_or_else(|| anyhow!("version not found"))?
111            {
112                serde_json::Value::String(version) => match version.as_str() {
113                    SavedContext::VERSION => {
114                        Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
115                    }
116                    "0.2.0" => {
117                        let saved_context =
118                            serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
119                        Ok(SavedContext {
120                            id: saved_context.id,
121                            zed: saved_context.zed,
122                            version: saved_context.version,
123                            text: saved_context.text,
124                            messages: saved_context.messages,
125                            message_metadata: saved_context.message_metadata,
126                            summary: saved_context.summary,
127                            slash_command_output_sections: Vec::new(),
128                        })
129                    }
130                    "0.1.0" => {
131                        let saved_context =
132                            serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
133                        Ok(SavedContext {
134                            id: saved_context.id,
135                            zed: saved_context.zed,
136                            version: saved_context.version,
137                            text: saved_context.text,
138                            messages: saved_context.messages,
139                            message_metadata: saved_context.message_metadata,
140                            summary: saved_context.summary,
141                            slash_command_output_sections: Vec::new(),
142                        })
143                    }
144                    _ => Err(anyhow!("unrecognized saved context version: {}", version)),
145                },
146                _ => Err(anyhow!("version not found on saved context")),
147            }
148        })
149    }
150
151    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
152        let metadata = self.contexts_metadata.clone();
153        let executor = cx.background_executor().clone();
154        cx.background_executor().spawn(async move {
155            if query.is_empty() {
156                metadata
157            } else {
158                let candidates = metadata
159                    .iter()
160                    .enumerate()
161                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
162                    .collect::<Vec<_>>();
163                let matches = fuzzy::match_strings(
164                    &candidates,
165                    &query,
166                    false,
167                    100,
168                    &Default::default(),
169                    executor,
170                )
171                .await;
172
173                matches
174                    .into_iter()
175                    .map(|mat| metadata[mat.candidate_id].clone())
176                    .collect()
177            }
178        })
179    }
180
181    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
182        let fs = self.fs.clone();
183        cx.spawn(|this, mut cx| async move {
184            fs.create_dir(contexts_dir()).await?;
185
186            let mut paths = fs.read_dir(contexts_dir()).await?;
187            let mut contexts = Vec::<SavedContextMetadata>::new();
188            while let Some(path) = paths.next().await {
189                let path = path?;
190                if path.extension() != Some(OsStr::new("json")) {
191                    continue;
192                }
193
194                let pattern = r" - \d+.zed.json$";
195                let re = Regex::new(pattern).unwrap();
196
197                let metadata = fs.metadata(&path).await?;
198                if let Some((file_name, metadata)) = path
199                    .file_name()
200                    .and_then(|name| name.to_str())
201                    .zip(metadata)
202                {
203                    // This is used to filter out contexts saved by the new assistant.
204                    if !re.is_match(file_name) {
205                        continue;
206                    }
207
208                    if let Some(title) = re.replace(file_name, "").lines().next() {
209                        contexts.push(SavedContextMetadata {
210                            title: title.to_string(),
211                            path,
212                            mtime: metadata.mtime.into(),
213                        });
214                    }
215                }
216            }
217            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
218
219            this.update(&mut cx, |this, cx| {
220                this.contexts_metadata = contexts;
221                cx.notify();
222            })
223        })
224    }
225}