context_store.rs

  1use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use fs::Fs;
  5use futures::StreamExt;
  6use fuzzy::StringMatchCandidate;
  7use gpui::{AppContext, Model, ModelContext, Task};
  8use paths::CONTEXTS_DIR;
  9use regex::Regex;
 10use serde::{Deserialize, Serialize};
 11use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration};
 12use ui::Context;
 13use util::{ResultExt, TryFutureExt};
 14
 15#[derive(Serialize, Deserialize)]
 16pub struct SavedMessage {
 17    pub id: MessageId,
 18    pub start: usize,
 19}
 20
 21#[derive(Serialize, Deserialize)]
 22pub struct SavedContext {
 23    pub id: Option<String>,
 24    pub zed: String,
 25    pub version: String,
 26    pub text: String,
 27    pub messages: Vec<SavedMessage>,
 28    pub message_metadata: HashMap<MessageId, MessageMetadata>,
 29    pub summary: String,
 30}
 31
 32impl SavedContext {
 33    pub const VERSION: &'static str = "0.2.0";
 34}
 35
 36#[derive(Serialize, Deserialize)]
 37struct SavedContextV0_1_0 {
 38    id: Option<String>,
 39    zed: String,
 40    version: String,
 41    text: String,
 42    messages: Vec<SavedMessage>,
 43    message_metadata: HashMap<MessageId, MessageMetadata>,
 44    summary: String,
 45    api_url: Option<String>,
 46    model: OpenAiModel,
 47}
 48
 49#[derive(Clone)]
 50pub struct SavedContextMetadata {
 51    pub title: String,
 52    pub path: PathBuf,
 53    pub mtime: chrono::DateTime<chrono::Local>,
 54}
 55
 56pub struct ContextStore {
 57    contexts_metadata: Vec<SavedContextMetadata>,
 58    fs: Arc<dyn Fs>,
 59    _watch_updates: Task<Option<()>>,
 60}
 61
 62impl ContextStore {
 63    pub fn new(fs: Arc<dyn Fs>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
 64        cx.spawn(|mut cx| async move {
 65            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 66            let (mut events, _) = fs.watch(&CONTEXTS_DIR, CONTEXT_WATCH_DURATION).await;
 67
 68            let this = cx.new_model(|cx: &mut ModelContext<Self>| Self {
 69                contexts_metadata: Vec::new(),
 70                fs,
 71                _watch_updates: cx.spawn(|this, mut cx| {
 72                    async move {
 73                        while events.next().await.is_some() {
 74                            this.update(&mut cx, |this, cx| this.reload(cx))?
 75                                .await
 76                                .log_err();
 77                        }
 78                        anyhow::Ok(())
 79                    }
 80                    .log_err()
 81                }),
 82            })?;
 83            this.update(&mut cx, |this, cx| this.reload(cx))?
 84                .await
 85                .log_err();
 86            Ok(this)
 87        })
 88    }
 89
 90    pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task<Result<SavedContext>> {
 91        let fs = self.fs.clone();
 92        cx.background_executor().spawn(async move {
 93            let saved_context = fs.load(&path).await?;
 94            let saved_context_json = serde_json::from_str::<serde_json::Value>(&saved_context)?;
 95            match saved_context_json
 96                .get("version")
 97                .ok_or_else(|| anyhow!("version not found"))?
 98            {
 99                serde_json::Value::String(version) => match version.as_str() {
100                    SavedContext::VERSION => {
101                        Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
102                    }
103                    "0.1.0" => {
104                        let saved_context =
105                            serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
106                        Ok(SavedContext {
107                            id: saved_context.id,
108                            zed: saved_context.zed,
109                            version: saved_context.version,
110                            text: saved_context.text,
111                            messages: saved_context.messages,
112                            message_metadata: saved_context.message_metadata,
113                            summary: saved_context.summary,
114                        })
115                    }
116                    _ => Err(anyhow!("unrecognized saved context version: {}", version)),
117                },
118                _ => Err(anyhow!("version not found on saved context")),
119            }
120        })
121    }
122
123    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
124        let metadata = self.contexts_metadata.clone();
125        let executor = cx.background_executor().clone();
126        cx.background_executor().spawn(async move {
127            if query.is_empty() {
128                metadata
129            } else {
130                let candidates = metadata
131                    .iter()
132                    .enumerate()
133                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
134                    .collect::<Vec<_>>();
135                let matches = fuzzy::match_strings(
136                    &candidates,
137                    &query,
138                    false,
139                    100,
140                    &Default::default(),
141                    executor,
142                )
143                .await;
144
145                matches
146                    .into_iter()
147                    .map(|mat| metadata[mat.candidate_id].clone())
148                    .collect()
149            }
150        })
151    }
152
153    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
154        let fs = self.fs.clone();
155        cx.spawn(|this, mut cx| async move {
156            fs.create_dir(&CONTEXTS_DIR).await?;
157
158            let mut paths = fs.read_dir(&CONTEXTS_DIR).await?;
159            let mut contexts = Vec::<SavedContextMetadata>::new();
160            while let Some(path) = paths.next().await {
161                let path = path?;
162                if path.extension() != Some(OsStr::new("json")) {
163                    continue;
164                }
165
166                let pattern = r" - \d+.zed.json$";
167                let re = Regex::new(pattern).unwrap();
168
169                let metadata = fs.metadata(&path).await?;
170                if let Some((file_name, metadata)) = path
171                    .file_name()
172                    .and_then(|name| name.to_str())
173                    .zip(metadata)
174                {
175                    // This is used to filter out contexts saved by the new assistant.
176                    if !re.is_match(file_name) {
177                        continue;
178                    }
179
180                    if let Some(title) = re.replace(file_name, "").lines().next() {
181                        contexts.push(SavedContextMetadata {
182                            title: title.to_string(),
183                            path,
184                            mtime: metadata.mtime.into(),
185                        });
186                    }
187                }
188            }
189            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
190
191            this.update(&mut cx, |this, cx| {
192                this.contexts_metadata = contexts;
193                cx.notify();
194            })
195        })
196    }
197}