1use anyhow::{anyhow, Result};
2use chrono::{DateTime, Utc};
3use collections::HashMap;
4use futures::future::{self, BoxFuture, Shared};
5use futures::FutureExt as _;
6use fuzzy::StringMatchCandidate;
7use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task};
8use heed::{
9 types::{SerdeBincode, SerdeJson, Str},
10 Database, RoTxn,
11};
12use parking_lot::RwLock;
13use rope::Rope;
14use serde::{Deserialize, Serialize};
15use std::{
16 cmp::Reverse,
17 future::Future,
18 path::PathBuf,
19 sync::{atomic::AtomicBool, Arc},
20};
21use text::LineEnding;
22use util::ResultExt;
23use uuid::Uuid;
24
25/// Init starts loading the PromptStore in the background and assigns
26/// a shared future to a global.
27pub fn init(cx: &mut App) {
28 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
29 let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone())
30 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
31 .boxed()
32 .shared();
33 cx.set_global(GlobalPromptStore(prompt_store_future))
34}
35
/// Metadata describing a single stored prompt.
///
/// Stored as JSON in the `metadata.v2` database (opened in `PromptStore::new`),
/// so the field names below are part of the on-disk format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt; the store also uses it as the database key.
    pub id: PromptId,
    /// User-visible title; `None` for an untitled prompt.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as default. `PromptStore::default_prompt_metadata`
    /// returns the flagged prompts and `PromptStore::search` lists them first.
    pub default: bool,
    /// UTC timestamp of the most recent save (set to `Utc::now()` on save).
    pub saved_at: DateTime<Utc>,
}
43
/// Identifies a prompt in the store.
///
/// Serialized with an internal `kind` tag (e.g. `{"kind":"User","uuid":"..."}`)
/// and used as the JSON-encoded key of both the metadata and bodies databases,
/// so variant and field names are part of the on-disk format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt authored by the user, identified by a UUID.
    User { uuid: Uuid },
    /// Built-in prompt; `PromptStore::new` deletes any stored entry for it
    /// on startup.
    EditWorkflow,
}
50
51impl PromptId {
52 pub fn new() -> PromptId {
53 PromptId::User {
54 uuid: Uuid::new_v4(),
55 }
56 }
57
58 pub fn is_built_in(&self) -> bool {
59 !matches!(self, PromptId::User { .. })
60 }
61}
62
/// Persistent prompt store backed by a heed (LMDB) environment, plus an
/// in-memory cache of all prompt metadata.
pub struct PromptStore {
    /// Executor on which database reads/writes and searches are spawned.
    executor: BackgroundExecutor,
    /// The opened heed environment that contains the databases below.
    env: heed::Env,
    /// In-memory copy of the entries in `metadata`, kept in display order.
    metadata_cache: RwLock<MetadataCache>,
    /// Prompt metadata keyed by `PromptId` (both JSON-encoded).
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// Prompt body text keyed by `PromptId` (JSON-encoded key, string value).
    bodies: Database<SerdeJson<PromptId>, Str>,
}
70
/// In-memory mirror of the metadata database.
#[derive(Default)]
struct MetadataCache {
    /// All entries, sorted by title and then by most recent `saved_at`.
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
76
77impl MetadataCache {
78 fn from_db(
79 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
80 txn: &RoTxn,
81 ) -> Result<Self> {
82 let mut cache = MetadataCache::default();
83 for result in db.iter(txn)? {
84 let (prompt_id, metadata) = result?;
85 cache.metadata.push(metadata.clone());
86 cache.metadata_by_id.insert(prompt_id, metadata);
87 }
88 cache.sort();
89 Ok(cache)
90 }
91
92 fn insert(&mut self, metadata: PromptMetadata) {
93 self.metadata_by_id.insert(metadata.id, metadata.clone());
94 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
95 *old_metadata = metadata;
96 } else {
97 self.metadata.push(metadata);
98 }
99 self.sort();
100 }
101
102 fn remove(&mut self, id: PromptId) {
103 self.metadata.retain(|metadata| metadata.id != id);
104 self.metadata_by_id.remove(&id);
105 }
106
107 fn sort(&mut self) {
108 self.metadata.sort_unstable_by(|a, b| {
109 a.title
110 .cmp(&b.title)
111 .then_with(|| b.saved_at.cmp(&a.saved_at))
112 });
113 }
114}
115
116impl PromptStore {
117 pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> {
118 let store = GlobalPromptStore::global(cx).0.clone();
119 async move { store.await.map_err(|err| anyhow!(err)) }
120 }
121
122 pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> {
123 executor.spawn({
124 let executor = executor.clone();
125 async move {
126 std::fs::create_dir_all(&db_path)?;
127
128 let db_env = unsafe {
129 heed::EnvOpenOptions::new()
130 .map_size(1024 * 1024 * 1024) // 1GB
131 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
132 .open(db_path)?
133 };
134
135 let mut txn = db_env.write_txn()?;
136 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
137 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
138
139 // Remove edit workflow prompt, as we decided to opt into it using
140 // a slash command instead.
141 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
142 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
143
144 txn.commit()?;
145
146 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
147
148 let txn = db_env.read_txn()?;
149 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
150 txn.commit()?;
151
152 Ok(PromptStore {
153 executor,
154 env: db_env,
155 metadata_cache: RwLock::new(metadata_cache),
156 metadata,
157 bodies,
158 })
159 }
160 })
161 }
162
163 fn upgrade_dbs(
164 env: &heed::Env,
165 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
166 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
167 ) -> Result<()> {
168 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
169 pub struct PromptIdV1(Uuid);
170
171 #[derive(Clone, Debug, Serialize, Deserialize)]
172 pub struct PromptMetadataV1 {
173 pub id: PromptIdV1,
174 pub title: Option<SharedString>,
175 pub default: bool,
176 pub saved_at: DateTime<Utc>,
177 }
178
179 let mut txn = env.write_txn()?;
180 let Some(bodies_v1_db) = env
181 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
182 &txn,
183 Some("bodies"),
184 )?
185 else {
186 return Ok(());
187 };
188 let mut bodies_v1 = bodies_v1_db
189 .iter(&txn)?
190 .collect::<heed::Result<HashMap<_, _>>>()?;
191
192 let Some(metadata_v1_db) = env
193 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
194 &txn,
195 Some("metadata"),
196 )?
197 else {
198 return Ok(());
199 };
200 let metadata_v1 = metadata_v1_db
201 .iter(&txn)?
202 .collect::<heed::Result<HashMap<_, _>>>()?;
203
204 for (prompt_id_v1, metadata_v1) in metadata_v1 {
205 let prompt_id_v2 = PromptId::User {
206 uuid: prompt_id_v1.0,
207 };
208 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
209 continue;
210 };
211
212 if metadata_db
213 .get(&txn, &prompt_id_v2)?
214 .map_or(true, |metadata_v2| {
215 metadata_v1.saved_at > metadata_v2.saved_at
216 })
217 {
218 metadata_db.put(
219 &mut txn,
220 &prompt_id_v2,
221 &PromptMetadata {
222 id: prompt_id_v2,
223 title: metadata_v1.title.clone(),
224 default: metadata_v1.default,
225 saved_at: metadata_v1.saved_at,
226 },
227 )?;
228 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
229 }
230 }
231
232 txn.commit()?;
233
234 Ok(())
235 }
236
237 pub fn load(&self, id: PromptId) -> Task<Result<String>> {
238 let env = self.env.clone();
239 let bodies = self.bodies;
240 self.executor.spawn(async move {
241 let txn = env.read_txn()?;
242 let mut prompt = bodies
243 .get(&txn, &id)?
244 .ok_or_else(|| anyhow!("prompt not found"))?
245 .into();
246 LineEnding::normalize(&mut prompt);
247 Ok(prompt)
248 })
249 }
250
251 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
252 return self
253 .metadata_cache
254 .read()
255 .metadata
256 .iter()
257 .filter(|metadata| metadata.default)
258 .cloned()
259 .collect::<Vec<_>>();
260 }
261
262 pub fn delete(&self, id: PromptId) -> Task<Result<()>> {
263 self.metadata_cache.write().remove(id);
264
265 let db_connection = self.env.clone();
266 let bodies = self.bodies;
267 let metadata = self.metadata;
268
269 self.executor.spawn(async move {
270 let mut txn = db_connection.write_txn()?;
271
272 metadata.delete(&mut txn, &id)?;
273 bodies.delete(&mut txn, &id)?;
274
275 txn.commit()?;
276 Ok(())
277 })
278 }
279
280 /// Returns the number of prompts in the store.
281 pub fn prompt_count(&self) -> usize {
282 self.metadata_cache.read().metadata.len()
283 }
284
285 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
286 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
287 }
288
289 pub fn first(&self) -> Option<PromptMetadata> {
290 self.metadata_cache.read().metadata.first().cloned()
291 }
292
293 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
294 let metadata_cache = self.metadata_cache.read();
295 let metadata = metadata_cache
296 .metadata
297 .iter()
298 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
299 Some(metadata.id)
300 }
301
302 pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> {
303 let cached_metadata = self.metadata_cache.read().metadata.clone();
304 let executor = self.executor.clone();
305 self.executor.spawn(async move {
306 let mut matches = if query.is_empty() {
307 cached_metadata
308 } else {
309 let candidates = cached_metadata
310 .iter()
311 .enumerate()
312 .filter_map(|(ix, metadata)| {
313 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
314 })
315 .collect::<Vec<_>>();
316 let matches = fuzzy::match_strings(
317 &candidates,
318 &query,
319 false,
320 100,
321 &AtomicBool::default(),
322 executor,
323 )
324 .await;
325 matches
326 .into_iter()
327 .map(|mat| cached_metadata[mat.candidate_id].clone())
328 .collect()
329 };
330 matches.sort_by_key(|metadata| Reverse(metadata.default));
331 matches
332 })
333 }
334
335 pub fn save(
336 &self,
337 id: PromptId,
338 title: Option<SharedString>,
339 default: bool,
340 body: Rope,
341 ) -> Task<Result<()>> {
342 if id.is_built_in() {
343 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
344 }
345
346 let prompt_metadata = PromptMetadata {
347 id,
348 title,
349 default,
350 saved_at: Utc::now(),
351 };
352 self.metadata_cache.write().insert(prompt_metadata.clone());
353
354 let db_connection = self.env.clone();
355 let bodies = self.bodies;
356 let metadata = self.metadata;
357
358 self.executor.spawn(async move {
359 let mut txn = db_connection.write_txn()?;
360
361 metadata.put(&mut txn, &id, &prompt_metadata)?;
362 bodies.put(&mut txn, &id, &body.to_string())?;
363
364 txn.commit()?;
365
366 Ok(())
367 })
368 }
369
370 pub fn save_metadata(
371 &self,
372 id: PromptId,
373 mut title: Option<SharedString>,
374 default: bool,
375 ) -> Task<Result<()>> {
376 let mut cache = self.metadata_cache.write();
377
378 if id.is_built_in() {
379 title = cache
380 .metadata_by_id
381 .get(&id)
382 .and_then(|metadata| metadata.title.clone());
383 }
384
385 let prompt_metadata = PromptMetadata {
386 id,
387 title,
388 default,
389 saved_at: Utc::now(),
390 };
391
392 cache.insert(prompt_metadata.clone());
393
394 let db_connection = self.env.clone();
395 let metadata = self.metadata;
396
397 self.executor.spawn(async move {
398 let mut txn = db_connection.write_txn()?;
399 metadata.put(&mut txn, &id, &prompt_metadata)?;
400 txn.commit()?;
401
402 Ok(())
403 })
404 }
405}
406
/// Context global holding the shared future that resolves to the loaded
/// [`PromptStore`], or to the error raised while loading it.
///
/// The error is wrapped in an `Arc` because `anyhow::Error` is not `Clone`, and
/// `Shared` needs a cloneable output to hand to every awaiter.
pub struct GlobalPromptStore(
    Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
);

// Marks the type as a gpui global, so `init` can register it with
// `cx.set_global` and `PromptStore::global` can read it back.
impl Global for GlobalPromptStore {}