prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Metadata describing a stored prompt.
///
/// Persisted JSON-encoded in the `metadata.v2` database and mirrored in the
/// store's in-memory cache. Because serde derives the encoding from the field
/// names, renaming a field would break reading previously stored entries.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata belongs to.
    pub id: PromptId,
    /// Display title. `None` for untitled prompts, which fuzzy search skips
    /// (they are only listed when the search query is empty).
    pub title: Option<SharedString>,
    /// Whether the prompt is marked as a default prompt.
    pub default: bool,
    /// Time of the last save; `save` and `save_metadata` set it to `Utc::now()`.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt in the store.
///
/// Serialized with `#[serde(tag = "kind")]`, i.e. as a JSON object whose
/// `"kind"` field names the variant. That JSON encoding is the key of both v2
/// databases, so renaming variants or fields would stop existing entries from
/// being found.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// Legacy built-in edit-workflow prompt; `PromptStore::new` deletes any
    /// stored copy when the store is opened.
    EditWorkflow,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        PromptId::User {
 64            uuid: UserPromptId::new(),
 65        }
 66    }
 67
 68    pub fn is_built_in(&self) -> bool {
 69        !matches!(self, PromptId::User { .. })
 70    }
 71}
 72
/// UUID identifying a user-created prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the wrapped
/// [`Uuid`], which keeps the `PromptId` key encoding stable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
 76
 77impl UserPromptId {
 78    pub fn new() -> UserPromptId {
 79        UserPromptId(Uuid::new_v4())
 80    }
 81}
 82
 83impl From<Uuid> for UserPromptId {
 84    fn from(uuid: Uuid) -> Self {
 85        UserPromptId(uuid)
 86    }
 87}
 88
/// Persistent prompt library backed by a heed (LMDB) environment, plus an
/// in-memory cache of all prompt metadata so reads don't need a transaction.
pub struct PromptStore {
    /// The open heed environment; cloned into background tasks that open
    /// read or write transactions.
    env: heed::Env,
    /// Mirror of every entry in `metadata`, kept sorted by title.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: JSON-encoded prompt id -> JSON-encoded metadata.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: JSON-encoded prompt id -> prompt body text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
 95
/// Emitted by [`PromptStore`] after a save, metadata update, or deletion has
/// been committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
 99
/// In-memory mirror of the prompt metadata database.
#[derive(Default)]
struct MetadataCache {
    /// Every entry, kept sorted by title and then most recently saved first.
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id for lookup by id.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
105
106impl MetadataCache {
107    fn from_db(
108        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
109        txn: &RoTxn,
110    ) -> Result<Self> {
111        let mut cache = MetadataCache::default();
112        for result in db.iter(txn)? {
113            let (prompt_id, metadata) = result?;
114            cache.metadata.push(metadata.clone());
115            cache.metadata_by_id.insert(prompt_id, metadata);
116        }
117        cache.sort();
118        Ok(cache)
119    }
120
121    fn insert(&mut self, metadata: PromptMetadata) {
122        self.metadata_by_id.insert(metadata.id, metadata.clone());
123        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
124            *old_metadata = metadata;
125        } else {
126            self.metadata.push(metadata);
127        }
128        self.sort();
129    }
130
131    fn remove(&mut self, id: PromptId) {
132        self.metadata.retain(|metadata| metadata.id != id);
133        self.metadata_by_id.remove(&id);
134    }
135
136    fn sort(&mut self) {
137        self.metadata.sort_unstable_by(|a, b| {
138            a.title
139                .cmp(&b.title)
140                .then_with(|| b.saved_at.cmp(&a.saved_at))
141        });
142    }
143}
144
145impl PromptStore {
146    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
147        let store = GlobalPromptStore::global(cx).0.clone();
148        async move { store.await.map_err(|err| anyhow!(err)) }
149    }
150
151    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
152        cx.background_spawn(async move {
153            std::fs::create_dir_all(&db_path)?;
154
155            let db_env = unsafe {
156                heed::EnvOpenOptions::new()
157                    .map_size(1024 * 1024 * 1024) // 1GB
158                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
159                    .open(db_path)?
160            };
161
162            let mut txn = db_env.write_txn()?;
163            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
164            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
165
166            // Remove edit workflow prompt, as we decided to opt into it using
167            // a slash command instead.
168            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
169            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
170
171            txn.commit()?;
172
173            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
174
175            let txn = db_env.read_txn()?;
176            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
177            txn.commit()?;
178
179            Ok(PromptStore {
180                env: db_env,
181                metadata_cache: RwLock::new(metadata_cache),
182                metadata,
183                bodies,
184            })
185        })
186    }
187
188    fn upgrade_dbs(
189        env: &heed::Env,
190        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
191        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
192    ) -> Result<()> {
193        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
194        pub struct PromptIdV1(Uuid);
195
196        #[derive(Clone, Debug, Serialize, Deserialize)]
197        pub struct PromptMetadataV1 {
198            pub id: PromptIdV1,
199            pub title: Option<SharedString>,
200            pub default: bool,
201            pub saved_at: DateTime<Utc>,
202        }
203
204        let mut txn = env.write_txn()?;
205        let Some(bodies_v1_db) = env
206            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
207                &txn,
208                Some("bodies"),
209            )?
210        else {
211            return Ok(());
212        };
213        let mut bodies_v1 = bodies_v1_db
214            .iter(&txn)?
215            .collect::<heed::Result<HashMap<_, _>>>()?;
216
217        let Some(metadata_v1_db) = env
218            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
219                &txn,
220                Some("metadata"),
221            )?
222        else {
223            return Ok(());
224        };
225        let metadata_v1 = metadata_v1_db
226            .iter(&txn)?
227            .collect::<heed::Result<HashMap<_, _>>>()?;
228
229        for (prompt_id_v1, metadata_v1) in metadata_v1 {
230            let prompt_id_v2 = PromptId::User {
231                uuid: UserPromptId(prompt_id_v1.0),
232            };
233            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
234                continue;
235            };
236
237            if metadata_db
238                .get(&txn, &prompt_id_v2)?
239                .map_or(true, |metadata_v2| {
240                    metadata_v1.saved_at > metadata_v2.saved_at
241                })
242            {
243                metadata_db.put(
244                    &mut txn,
245                    &prompt_id_v2,
246                    &PromptMetadata {
247                        id: prompt_id_v2,
248                        title: metadata_v1.title.clone(),
249                        default: metadata_v1.default,
250                        saved_at: metadata_v1.saved_at,
251                    },
252                )?;
253                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
254            }
255        }
256
257        txn.commit()?;
258
259        Ok(())
260    }
261
262    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
263        let env = self.env.clone();
264        let bodies = self.bodies;
265        cx.background_spawn(async move {
266            let txn = env.read_txn()?;
267            let mut prompt = bodies
268                .get(&txn, &id)?
269                .ok_or_else(|| anyhow!("prompt not found"))?
270                .into();
271            LineEnding::normalize(&mut prompt);
272            Ok(prompt)
273        })
274    }
275
276    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
277        self.metadata_cache.read().metadata.clone()
278    }
279
280    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
281        return self
282            .metadata_cache
283            .read()
284            .metadata
285            .iter()
286            .filter(|metadata| metadata.default)
287            .cloned()
288            .collect::<Vec<_>>();
289    }
290
291    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
292        self.metadata_cache.write().remove(id);
293
294        let db_connection = self.env.clone();
295        let bodies = self.bodies;
296        let metadata = self.metadata;
297
298        let task = cx.background_spawn(async move {
299            let mut txn = db_connection.write_txn()?;
300
301            metadata.delete(&mut txn, &id)?;
302            bodies.delete(&mut txn, &id)?;
303
304            txn.commit()?;
305            anyhow::Ok(())
306        });
307
308        cx.spawn(async move |this, cx| {
309            task.await?;
310            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
311            anyhow::Ok(())
312        })
313    }
314
315    /// Returns the number of prompts in the store.
316    pub fn prompt_count(&self) -> usize {
317        self.metadata_cache.read().metadata.len()
318    }
319
320    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
321        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
322    }
323
324    pub fn first(&self) -> Option<PromptMetadata> {
325        self.metadata_cache.read().metadata.first().cloned()
326    }
327
328    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
329        let metadata_cache = self.metadata_cache.read();
330        let metadata = metadata_cache
331            .metadata
332            .iter()
333            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
334        Some(metadata.id)
335    }
336
337    pub fn search(
338        &self,
339        query: String,
340        cancellation_flag: Arc<AtomicBool>,
341        cx: &App,
342    ) -> Task<Vec<PromptMetadata>> {
343        let cached_metadata = self.metadata_cache.read().metadata.clone();
344        let executor = cx.background_executor().clone();
345        cx.background_spawn(async move {
346            let mut matches = if query.is_empty() {
347                cached_metadata
348            } else {
349                let candidates = cached_metadata
350                    .iter()
351                    .enumerate()
352                    .filter_map(|(ix, metadata)| {
353                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
354                    })
355                    .collect::<Vec<_>>();
356                let matches = fuzzy::match_strings(
357                    &candidates,
358                    &query,
359                    false,
360                    100,
361                    &cancellation_flag,
362                    executor,
363                )
364                .await;
365                matches
366                    .into_iter()
367                    .map(|mat| cached_metadata[mat.candidate_id].clone())
368                    .collect()
369            };
370            matches.sort_by_key(|metadata| Reverse(metadata.default));
371            matches
372        })
373    }
374
375    pub fn save(
376        &self,
377        id: PromptId,
378        title: Option<SharedString>,
379        default: bool,
380        body: Rope,
381        cx: &Context<Self>,
382    ) -> Task<Result<()>> {
383        if id.is_built_in() {
384            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
385        }
386
387        let prompt_metadata = PromptMetadata {
388            id,
389            title,
390            default,
391            saved_at: Utc::now(),
392        };
393        self.metadata_cache.write().insert(prompt_metadata.clone());
394
395        let db_connection = self.env.clone();
396        let bodies = self.bodies;
397        let metadata = self.metadata;
398
399        let task = cx.background_spawn(async move {
400            let mut txn = db_connection.write_txn()?;
401
402            metadata.put(&mut txn, &id, &prompt_metadata)?;
403            bodies.put(&mut txn, &id, &body.to_string())?;
404
405            txn.commit()?;
406
407            anyhow::Ok(())
408        });
409
410        cx.spawn(async move |this, cx| {
411            task.await?;
412            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
413            anyhow::Ok(())
414        })
415    }
416
417    pub fn save_metadata(
418        &self,
419        id: PromptId,
420        mut title: Option<SharedString>,
421        default: bool,
422        cx: &Context<Self>,
423    ) -> Task<Result<()>> {
424        let mut cache = self.metadata_cache.write();
425
426        if id.is_built_in() {
427            title = cache
428                .metadata_by_id
429                .get(&id)
430                .and_then(|metadata| metadata.title.clone());
431        }
432
433        let prompt_metadata = PromptMetadata {
434            id,
435            title,
436            default,
437            saved_at: Utc::now(),
438        };
439
440        cache.insert(prompt_metadata.clone());
441
442        let db_connection = self.env.clone();
443        let metadata = self.metadata;
444
445        let task = cx.background_spawn(async move {
446            let mut txn = db_connection.write_txn()?;
447            metadata.put(&mut txn, &id, &prompt_metadata)?;
448            txn.commit()?;
449
450            anyhow::Ok(())
451        });
452
453        cx.spawn(async move |this, cx| {
454            task.await?;
455            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
456            anyhow::Ok(())
457        })
458    }
459}
460
/// App-wide global holding the shared task that loads the [`PromptStore`]
/// entity.
///
/// The task is `Shared` so every caller of `PromptStore::global` awaits the
/// same load. The error is wrapped in `Arc` because a shared future's output
/// must be `Clone`, which `anyhow::Error` is not.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}