1mod prompts;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::{self, BoxFuture, Shared};
8use fuzzy::StringMatchCandidate;
9use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task};
10use heed::{
11 Database, RoTxn,
12 types::{SerdeBincode, SerdeJson, Str},
13};
14use parking_lot::RwLock;
15pub use prompts::*;
16use rope::Rope;
17use serde::{Deserialize, Serialize};
18use std::{
19 cmp::Reverse,
20 future::Future,
21 path::PathBuf,
22 sync::{Arc, atomic::AtomicBool},
23};
24use text::LineEnding;
25use util::ResultExt;
26use uuid::Uuid;
27
28/// Init starts loading the PromptStore in the background and assigns
29/// a shared future to a global.
30pub fn init(cx: &mut App) {
31 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
32 let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone())
33 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
34 .boxed()
35 .shared();
36 cx.set_global(GlobalPromptStore(prompt_store_future))
37}
38
/// Metadata describing a single stored prompt.
///
/// Values of this type are stored JSON-encoded (via `SerdeJson`) in the `metadata.v2`
/// database, so renaming or retyping fields affects whether existing databases can still be
/// read back.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata belongs to.
    pub id: PromptId,
    /// Display title. `None` means the prompt is untitled; untitled prompts are skipped by
    /// fuzzy title search in `PromptStore::search`.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt (see
    /// `PromptStore::default_prompt_metadata`).
    pub default: bool,
    /// Time of the most recent save; `PromptStore` sets it to `Utc::now()` on every save.
    pub saved_at: DateTime<Utc>,
}
46
/// Identifies a prompt in the store.
///
/// Serialized with serde's internally tagged representation (the variant name goes in a
/// `"kind"` field) and used, JSON-encoded, as the key of both the `metadata.v2` and
/// `bodies.v2` databases. Changing variant names or fields changes the on-disk key format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A user-authored prompt, identified by a UUID.
    User { uuid: Uuid },
    /// Legacy built-in edit-workflow prompt; `PromptStore::new` deletes any stored entry for
    /// it on startup.
    EditWorkflow,
}
53
54impl PromptId {
55 pub fn new() -> PromptId {
56 PromptId::User {
57 uuid: Uuid::new_v4(),
58 }
59 }
60
61 pub fn is_built_in(&self) -> bool {
62 !matches!(self, PromptId::User { .. })
63 }
64}
65
/// Persistent prompt library backed by a heed environment, with an in-memory cache of all
/// prompt metadata for synchronous lookups.
pub struct PromptStore {
    /// Executor on which database reads and writes are spawned.
    executor: BackgroundExecutor,
    /// Handle to the opened heed environment; cloned into spawned database tasks.
    env: heed::Env,
    /// In-memory copy of the `metadata` database. Mutating methods update it synchronously,
    /// before the corresponding database write task is spawned.
    metadata_cache: RwLock<MetadataCache>,
    /// Prompt metadata keyed by prompt id (the `metadata.v2` database).
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// Prompt bodies keyed by prompt id (the `bodies.v2` database).
    bodies: Database<SerdeJson<PromptId>, Str>,
}
73
/// In-memory mirror of the prompt metadata database.
#[derive(Default)]
struct MetadataCache {
    /// All metadata, kept sorted by title ascending (untitled `None` entries first, since
    /// `None < Some`) and then by `saved_at`, newest first.
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id for O(1) lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
79
80impl MetadataCache {
81 fn from_db(
82 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
83 txn: &RoTxn,
84 ) -> Result<Self> {
85 let mut cache = MetadataCache::default();
86 for result in db.iter(txn)? {
87 let (prompt_id, metadata) = result?;
88 cache.metadata.push(metadata.clone());
89 cache.metadata_by_id.insert(prompt_id, metadata);
90 }
91 cache.sort();
92 Ok(cache)
93 }
94
95 fn insert(&mut self, metadata: PromptMetadata) {
96 self.metadata_by_id.insert(metadata.id, metadata.clone());
97 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
98 *old_metadata = metadata;
99 } else {
100 self.metadata.push(metadata);
101 }
102 self.sort();
103 }
104
105 fn remove(&mut self, id: PromptId) {
106 self.metadata.retain(|metadata| metadata.id != id);
107 self.metadata_by_id.remove(&id);
108 }
109
110 fn sort(&mut self) {
111 self.metadata.sort_unstable_by(|a, b| {
112 a.title
113 .cmp(&b.title)
114 .then_with(|| b.saved_at.cmp(&a.saved_at))
115 });
116 }
117}
118
119impl PromptStore {
120 pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> + use<> {
121 let store = GlobalPromptStore::global(cx).0.clone();
122 async move { store.await.map_err(|err| anyhow!(err)) }
123 }
124
125 pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> {
126 executor.spawn({
127 let executor = executor.clone();
128 async move {
129 std::fs::create_dir_all(&db_path)?;
130
131 let db_env = unsafe {
132 heed::EnvOpenOptions::new()
133 .map_size(1024 * 1024 * 1024) // 1GB
134 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
135 .open(db_path)?
136 };
137
138 let mut txn = db_env.write_txn()?;
139 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
140 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
141
142 // Remove edit workflow prompt, as we decided to opt into it using
143 // a slash command instead.
144 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
145 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
146
147 txn.commit()?;
148
149 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
150
151 let txn = db_env.read_txn()?;
152 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
153 txn.commit()?;
154
155 Ok(PromptStore {
156 executor,
157 env: db_env,
158 metadata_cache: RwLock::new(metadata_cache),
159 metadata,
160 bodies,
161 })
162 }
163 })
164 }
165
166 fn upgrade_dbs(
167 env: &heed::Env,
168 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
169 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
170 ) -> Result<()> {
171 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
172 pub struct PromptIdV1(Uuid);
173
174 #[derive(Clone, Debug, Serialize, Deserialize)]
175 pub struct PromptMetadataV1 {
176 pub id: PromptIdV1,
177 pub title: Option<SharedString>,
178 pub default: bool,
179 pub saved_at: DateTime<Utc>,
180 }
181
182 let mut txn = env.write_txn()?;
183 let Some(bodies_v1_db) = env
184 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
185 &txn,
186 Some("bodies"),
187 )?
188 else {
189 return Ok(());
190 };
191 let mut bodies_v1 = bodies_v1_db
192 .iter(&txn)?
193 .collect::<heed::Result<HashMap<_, _>>>()?;
194
195 let Some(metadata_v1_db) = env
196 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
197 &txn,
198 Some("metadata"),
199 )?
200 else {
201 return Ok(());
202 };
203 let metadata_v1 = metadata_v1_db
204 .iter(&txn)?
205 .collect::<heed::Result<HashMap<_, _>>>()?;
206
207 for (prompt_id_v1, metadata_v1) in metadata_v1 {
208 let prompt_id_v2 = PromptId::User {
209 uuid: prompt_id_v1.0,
210 };
211 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
212 continue;
213 };
214
215 if metadata_db
216 .get(&txn, &prompt_id_v2)?
217 .map_or(true, |metadata_v2| {
218 metadata_v1.saved_at > metadata_v2.saved_at
219 })
220 {
221 metadata_db.put(
222 &mut txn,
223 &prompt_id_v2,
224 &PromptMetadata {
225 id: prompt_id_v2,
226 title: metadata_v1.title.clone(),
227 default: metadata_v1.default,
228 saved_at: metadata_v1.saved_at,
229 },
230 )?;
231 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
232 }
233 }
234
235 txn.commit()?;
236
237 Ok(())
238 }
239
240 pub fn load(&self, id: PromptId) -> Task<Result<String>> {
241 let env = self.env.clone();
242 let bodies = self.bodies;
243 self.executor.spawn(async move {
244 let txn = env.read_txn()?;
245 let mut prompt = bodies
246 .get(&txn, &id)?
247 .ok_or_else(|| anyhow!("prompt not found"))?
248 .into();
249 LineEnding::normalize(&mut prompt);
250 Ok(prompt)
251 })
252 }
253
254 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
255 return self
256 .metadata_cache
257 .read()
258 .metadata
259 .iter()
260 .filter(|metadata| metadata.default)
261 .cloned()
262 .collect::<Vec<_>>();
263 }
264
265 pub fn delete(&self, id: PromptId) -> Task<Result<()>> {
266 self.metadata_cache.write().remove(id);
267
268 let db_connection = self.env.clone();
269 let bodies = self.bodies;
270 let metadata = self.metadata;
271
272 self.executor.spawn(async move {
273 let mut txn = db_connection.write_txn()?;
274
275 metadata.delete(&mut txn, &id)?;
276 bodies.delete(&mut txn, &id)?;
277
278 txn.commit()?;
279 Ok(())
280 })
281 }
282
283 /// Returns the number of prompts in the store.
284 pub fn prompt_count(&self) -> usize {
285 self.metadata_cache.read().metadata.len()
286 }
287
288 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
289 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
290 }
291
292 pub fn first(&self) -> Option<PromptMetadata> {
293 self.metadata_cache.read().metadata.first().cloned()
294 }
295
296 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
297 let metadata_cache = self.metadata_cache.read();
298 let metadata = metadata_cache
299 .metadata
300 .iter()
301 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
302 Some(metadata.id)
303 }
304
305 pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> {
306 let cached_metadata = self.metadata_cache.read().metadata.clone();
307 let executor = self.executor.clone();
308 self.executor.spawn(async move {
309 let mut matches = if query.is_empty() {
310 cached_metadata
311 } else {
312 let candidates = cached_metadata
313 .iter()
314 .enumerate()
315 .filter_map(|(ix, metadata)| {
316 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
317 })
318 .collect::<Vec<_>>();
319 let matches = fuzzy::match_strings(
320 &candidates,
321 &query,
322 false,
323 100,
324 &AtomicBool::default(),
325 executor,
326 )
327 .await;
328 matches
329 .into_iter()
330 .map(|mat| cached_metadata[mat.candidate_id].clone())
331 .collect()
332 };
333 matches.sort_by_key(|metadata| Reverse(metadata.default));
334 matches
335 })
336 }
337
338 pub fn save(
339 &self,
340 id: PromptId,
341 title: Option<SharedString>,
342 default: bool,
343 body: Rope,
344 ) -> Task<Result<()>> {
345 if id.is_built_in() {
346 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
347 }
348
349 let prompt_metadata = PromptMetadata {
350 id,
351 title,
352 default,
353 saved_at: Utc::now(),
354 };
355 self.metadata_cache.write().insert(prompt_metadata.clone());
356
357 let db_connection = self.env.clone();
358 let bodies = self.bodies;
359 let metadata = self.metadata;
360
361 self.executor.spawn(async move {
362 let mut txn = db_connection.write_txn()?;
363
364 metadata.put(&mut txn, &id, &prompt_metadata)?;
365 bodies.put(&mut txn, &id, &body.to_string())?;
366
367 txn.commit()?;
368
369 Ok(())
370 })
371 }
372
373 pub fn save_metadata(
374 &self,
375 id: PromptId,
376 mut title: Option<SharedString>,
377 default: bool,
378 ) -> Task<Result<()>> {
379 let mut cache = self.metadata_cache.write();
380
381 if id.is_built_in() {
382 title = cache
383 .metadata_by_id
384 .get(&id)
385 .and_then(|metadata| metadata.title.clone());
386 }
387
388 let prompt_metadata = PromptMetadata {
389 id,
390 title,
391 default,
392 saved_at: Utc::now(),
393 };
394
395 cache.insert(prompt_metadata.clone());
396
397 let db_connection = self.env.clone();
398 let metadata = self.metadata;
399
400 self.executor.spawn(async move {
401 let mut txn = db_connection.write_txn()?;
402 metadata.put(&mut txn, &id, &prompt_metadata)?;
403 txn.commit()?;
404
405 Ok(())
406 })
407 }
408}
409
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// The future is created by [`init`]. Cloning the inner [`Shared`] lets several callers
/// await the same load. The error is wrapped in `Arc` because `Shared` requires a `Clone`
/// output.
pub struct GlobalPromptStore(
    Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalPromptStore {}