1use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
2use anyhow::{anyhow, Result};
3use assistant_slash_command::SlashCommandOutputSection;
4use collections::HashMap;
5use fs::Fs;
6use futures::StreamExt;
7use fuzzy::StringMatchCandidate;
8use gpui::{AppContext, Model, ModelContext, Task};
9use paths::contexts_dir;
10use regex::Regex;
11use serde::{Deserialize, Serialize};
12use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration};
13use ui::Context;
14use util::{ResultExt, TryFutureExt};
15
16#[derive(Serialize, Deserialize)]
17pub struct SavedMessage {
18 pub id: MessageId,
19 pub start: usize,
20}
21
22#[derive(Serialize, Deserialize)]
23pub struct SavedContext {
24 pub id: Option<String>,
25 pub zed: String,
26 pub version: String,
27 pub text: String,
28 pub messages: Vec<SavedMessage>,
29 pub message_metadata: HashMap<MessageId, MessageMetadata>,
30 pub summary: String,
31 pub slash_command_output_sections: Vec<SlashCommandOutputSection<usize>>,
32}
33
34impl SavedContext {
35 pub const VERSION: &'static str = "0.3.0";
36}
37
38#[derive(Serialize, Deserialize)]
39pub struct SavedContextV0_2_0 {
40 pub id: Option<String>,
41 pub zed: String,
42 pub version: String,
43 pub text: String,
44 pub messages: Vec<SavedMessage>,
45 pub message_metadata: HashMap<MessageId, MessageMetadata>,
46 pub summary: String,
47}
48
49#[derive(Serialize, Deserialize)]
50struct SavedContextV0_1_0 {
51 id: Option<String>,
52 zed: String,
53 version: String,
54 text: String,
55 messages: Vec<SavedMessage>,
56 message_metadata: HashMap<MessageId, MessageMetadata>,
57 summary: String,
58 api_url: Option<String>,
59 model: OpenAiModel,
60}
61
62#[derive(Clone)]
63pub struct SavedContextMetadata {
64 pub title: String,
65 pub path: PathBuf,
66 pub mtime: chrono::DateTime<chrono::Local>,
67}
68
69pub struct ContextStore {
70 contexts_metadata: Vec<SavedContextMetadata>,
71 fs: Arc<dyn Fs>,
72 _watch_updates: Task<Option<()>>,
73}
74
75impl ContextStore {
76 pub fn new(fs: Arc<dyn Fs>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
77 cx.spawn(|mut cx| async move {
78 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
79 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
80
81 let this = cx.new_model(|cx: &mut ModelContext<Self>| Self {
82 contexts_metadata: Vec::new(),
83 fs,
84 _watch_updates: cx.spawn(|this, mut cx| {
85 async move {
86 while events.next().await.is_some() {
87 this.update(&mut cx, |this, cx| this.reload(cx))?
88 .await
89 .log_err();
90 }
91 anyhow::Ok(())
92 }
93 .log_err()
94 }),
95 })?;
96 this.update(&mut cx, |this, cx| this.reload(cx))?
97 .await
98 .log_err();
99 Ok(this)
100 })
101 }
102
103 pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task<Result<SavedContext>> {
104 let fs = self.fs.clone();
105 cx.background_executor().spawn(async move {
106 let saved_context = fs.load(&path).await?;
107 let saved_context_json = serde_json::from_str::<serde_json::Value>(&saved_context)?;
108 match saved_context_json
109 .get("version")
110 .ok_or_else(|| anyhow!("version not found"))?
111 {
112 serde_json::Value::String(version) => match version.as_str() {
113 SavedContext::VERSION => {
114 Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
115 }
116 "0.2.0" => {
117 let saved_context =
118 serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
119 Ok(SavedContext {
120 id: saved_context.id,
121 zed: saved_context.zed,
122 version: saved_context.version,
123 text: saved_context.text,
124 messages: saved_context.messages,
125 message_metadata: saved_context.message_metadata,
126 summary: saved_context.summary,
127 slash_command_output_sections: Vec::new(),
128 })
129 }
130 "0.1.0" => {
131 let saved_context =
132 serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
133 Ok(SavedContext {
134 id: saved_context.id,
135 zed: saved_context.zed,
136 version: saved_context.version,
137 text: saved_context.text,
138 messages: saved_context.messages,
139 message_metadata: saved_context.message_metadata,
140 summary: saved_context.summary,
141 slash_command_output_sections: Vec::new(),
142 })
143 }
144 _ => Err(anyhow!("unrecognized saved context version: {}", version)),
145 },
146 _ => Err(anyhow!("version not found on saved context")),
147 }
148 })
149 }
150
151 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
152 let metadata = self.contexts_metadata.clone();
153 let executor = cx.background_executor().clone();
154 cx.background_executor().spawn(async move {
155 if query.is_empty() {
156 metadata
157 } else {
158 let candidates = metadata
159 .iter()
160 .enumerate()
161 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
162 .collect::<Vec<_>>();
163 let matches = fuzzy::match_strings(
164 &candidates,
165 &query,
166 false,
167 100,
168 &Default::default(),
169 executor,
170 )
171 .await;
172
173 matches
174 .into_iter()
175 .map(|mat| metadata[mat.candidate_id].clone())
176 .collect()
177 }
178 })
179 }
180
181 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
182 let fs = self.fs.clone();
183 cx.spawn(|this, mut cx| async move {
184 fs.create_dir(contexts_dir()).await?;
185
186 let mut paths = fs.read_dir(contexts_dir()).await?;
187 let mut contexts = Vec::<SavedContextMetadata>::new();
188 while let Some(path) = paths.next().await {
189 let path = path?;
190 if path.extension() != Some(OsStr::new("json")) {
191 continue;
192 }
193
194 let pattern = r" - \d+.zed.json$";
195 let re = Regex::new(pattern).unwrap();
196
197 let metadata = fs.metadata(&path).await?;
198 if let Some((file_name, metadata)) = path
199 .file_name()
200 .and_then(|name| name.to_str())
201 .zip(metadata)
202 {
203 // This is used to filter out contexts saved by the new assistant.
204 if !re.is_match(file_name) {
205 continue;
206 }
207
208 if let Some(title) = re.replace(file_name, "").lines().next() {
209 contexts.push(SavedContextMetadata {
210 title: title.to_string(),
211 path,
212 mtime: metadata.mtime.into(),
213 });
214 }
215 }
216 }
217 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
218
219 this.update(&mut cx, |this, cx| {
220 this.contexts_metadata = contexts;
221 cx.notify();
222 })
223 })
224 }
225}