context_store.rs

  1use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use fs::Fs;
  5use futures::StreamExt;
  6use fuzzy::StringMatchCandidate;
  7use gpui::{AppContext, Model, ModelContext, Task};
  8use regex::Regex;
  9use serde::{Deserialize, Serialize};
 10use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration};
 11use ui::Context;
 12use util::{paths::CONTEXTS_DIR, ResultExt, TryFutureExt};
 13
 14#[derive(Serialize, Deserialize)]
 15pub struct SavedMessage {
 16    pub id: MessageId,
 17    pub start: usize,
 18}
 19
 20#[derive(Serialize, Deserialize)]
 21pub struct SavedContext {
 22    pub id: Option<String>,
 23    pub zed: String,
 24    pub version: String,
 25    pub text: String,
 26    pub messages: Vec<SavedMessage>,
 27    pub message_metadata: HashMap<MessageId, MessageMetadata>,
 28    pub summary: String,
 29}
 30
 31impl SavedContext {
 32    pub const VERSION: &'static str = "0.2.0";
 33}
 34
 35#[derive(Serialize, Deserialize)]
 36struct SavedContextV0_1_0 {
 37    id: Option<String>,
 38    zed: String,
 39    version: String,
 40    text: String,
 41    messages: Vec<SavedMessage>,
 42    message_metadata: HashMap<MessageId, MessageMetadata>,
 43    summary: String,
 44    api_url: Option<String>,
 45    model: OpenAiModel,
 46}
 47
 48#[derive(Clone)]
 49pub struct SavedContextMetadata {
 50    pub title: String,
 51    pub path: PathBuf,
 52    pub mtime: chrono::DateTime<chrono::Local>,
 53}
 54
 55pub struct ContextStore {
 56    contexts_metadata: Vec<SavedContextMetadata>,
 57    fs: Arc<dyn Fs>,
 58    _watch_updates: Task<Option<()>>,
 59}
 60
 61impl ContextStore {
 62    pub fn new(fs: Arc<dyn Fs>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
 63        cx.spawn(|mut cx| async move {
 64            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 65            let (mut events, _) = fs.watch(&CONTEXTS_DIR, CONTEXT_WATCH_DURATION).await;
 66
 67            let this = cx.new_model(|cx: &mut ModelContext<Self>| Self {
 68                contexts_metadata: Vec::new(),
 69                fs,
 70                _watch_updates: cx.spawn(|this, mut cx| {
 71                    async move {
 72                        while events.next().await.is_some() {
 73                            this.update(&mut cx, |this, cx| this.reload(cx))?
 74                                .await
 75                                .log_err();
 76                        }
 77                        anyhow::Ok(())
 78                    }
 79                    .log_err()
 80                }),
 81            })?;
 82            this.update(&mut cx, |this, cx| this.reload(cx))?
 83                .await
 84                .log_err();
 85            Ok(this)
 86        })
 87    }
 88
 89    pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task<Result<SavedContext>> {
 90        let fs = self.fs.clone();
 91        cx.background_executor().spawn(async move {
 92            let saved_context = fs.load(&path).await?;
 93            let saved_context_json = serde_json::from_str::<serde_json::Value>(&saved_context)?;
 94            match saved_context_json
 95                .get("version")
 96                .ok_or_else(|| anyhow!("version not found"))?
 97            {
 98                serde_json::Value::String(version) => match version.as_str() {
 99                    SavedContext::VERSION => {
100                        Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
101                    }
102                    "0.1.0" => {
103                        let saved_context =
104                            serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
105                        Ok(SavedContext {
106                            id: saved_context.id,
107                            zed: saved_context.zed,
108                            version: saved_context.version,
109                            text: saved_context.text,
110                            messages: saved_context.messages,
111                            message_metadata: saved_context.message_metadata,
112                            summary: saved_context.summary,
113                        })
114                    }
115                    _ => Err(anyhow!("unrecognized saved context version: {}", version)),
116                },
117                _ => Err(anyhow!("version not found on saved context")),
118            }
119        })
120    }
121
122    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
123        let metadata = self.contexts_metadata.clone();
124        let executor = cx.background_executor().clone();
125        cx.background_executor().spawn(async move {
126            if query.is_empty() {
127                metadata
128            } else {
129                let candidates = metadata
130                    .iter()
131                    .enumerate()
132                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
133                    .collect::<Vec<_>>();
134                let matches = fuzzy::match_strings(
135                    &candidates,
136                    &query,
137                    false,
138                    100,
139                    &Default::default(),
140                    executor,
141                )
142                .await;
143
144                matches
145                    .into_iter()
146                    .map(|mat| metadata[mat.candidate_id].clone())
147                    .collect()
148            }
149        })
150    }
151
152    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
153        let fs = self.fs.clone();
154        cx.spawn(|this, mut cx| async move {
155            fs.create_dir(&CONTEXTS_DIR).await?;
156
157            let mut paths = fs.read_dir(&CONTEXTS_DIR).await?;
158            let mut contexts = Vec::<SavedContextMetadata>::new();
159            while let Some(path) = paths.next().await {
160                let path = path?;
161                if path.extension() != Some(OsStr::new("json")) {
162                    continue;
163                }
164
165                let pattern = r" - \d+.zed.json$";
166                let re = Regex::new(pattern).unwrap();
167
168                let metadata = fs.metadata(&path).await?;
169                if let Some((file_name, metadata)) = path
170                    .file_name()
171                    .and_then(|name| name.to_str())
172                    .zip(metadata)
173                {
174                    // This is used to filter out contexts saved by the new assistant.
175                    if !re.is_match(file_name) {
176                        continue;
177                    }
178
179                    if let Some(title) = re.replace(file_name, "").lines().next() {
180                        contexts.push(SavedContextMetadata {
181                            title: title.to_string(),
182                            path,
183                            mtime: metadata.mtime.into(),
184                        });
185                    }
186                }
187            }
188            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
189
190            this.update(&mut cx, |this, cx| {
191                this.contexts_metadata = contexts;
192                cx.notify();
193            })
194        })
195    }
196}