prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Descriptive information stored for each prompt, kept separately from the
/// prompt's body.
///
/// `PromptStore` persists this as JSON in its `metadata.v2` database, so the
/// field names and types are part of the on-disk format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata describes.
    pub id: PromptId,
    /// Display title; `None` for an untitled prompt.
    pub title: Option<SharedString>,
    /// Whether this is a default prompt. `PromptStore::default_prompt_metadata`
    /// returns the entries with this set, and search results list them first.
    pub default: bool,
    /// UTC timestamp of the most recent save of this prompt or its metadata.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt in the store: either a user-authored prompt or a
/// built-in one.
///
/// Serialized as JSON with a `"kind"` tag and used as the key type of the v2
/// metadata and body databases. Renaming variants or fields therefore changes
/// how existing entries are looked up.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// The retired built-in "edit workflow" prompt. `PromptStore::new` deletes
    /// any stored entry for it when the store is opened.
    EditWorkflow,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        UserPromptId::new().into()
 64    }
 65
 66    pub fn is_built_in(&self) -> bool {
 67        !matches!(self, PromptId::User { .. })
 68    }
 69}
 70
 71impl From<UserPromptId> for PromptId {
 72    fn from(uuid: UserPromptId) -> Self {
 73        PromptId::User { uuid }
 74    }
 75}
 76
/// Identifier of a user-authored prompt: a newtype over [`Uuid`].
///
/// `#[serde(transparent)]` makes it serialize exactly like the inner UUID.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
 80
 81impl UserPromptId {
 82    pub fn new() -> UserPromptId {
 83        UserPromptId(Uuid::new_v4())
 84    }
 85}
 86
 87impl From<Uuid> for UserPromptId {
 88    fn from(uuid: Uuid) -> Self {
 89        UserPromptId(uuid)
 90    }
 91}
 92
 93impl std::fmt::Display for PromptId {
 94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 95        match self {
 96            PromptId::User { uuid } => write!(f, "{}", uuid.0),
 97            PromptId::EditWorkflow => write!(f, "Edit workflow"),
 98        }
 99    }
100}
101
/// Persistent store of prompts, backed by a heed database environment, with an
/// in-memory cache of every prompt's metadata.
///
/// Metadata and bodies live in two separate databases keyed by [`PromptId`].
/// Metadata queries are answered from the cache. Bodies are read from the
/// database on demand by `PromptStore::load`.
pub struct PromptStore {
    /// Handle to the heed environment that owns both databases; cloned into
    /// background tasks that read or write them.
    env: heed::Env,
    /// Cached metadata. The save/delete methods update it before their database
    /// writes are spawned, so it can briefly be ahead of what is on disk.
    metadata_cache: RwLock<MetadataCache>,
    /// The `metadata.v2` database: JSON-encoded [`PromptMetadata`] per prompt.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// The `bodies.v2` database: the prompt text per prompt.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
108
/// Emitted by [`PromptStore`] after a save, metadata update, or delete has been
/// committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
112
/// In-memory copy of all prompt metadata, held in two shapes for different
/// access patterns.
#[derive(Default)]
struct MetadataCache {
    /// Every entry, kept sorted by title (ascending; untitled entries first)
    /// and then by `saved_at` (newest first).
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
118
119impl MetadataCache {
120    fn from_db(
121        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
122        txn: &RoTxn,
123    ) -> Result<Self> {
124        let mut cache = MetadataCache::default();
125        for result in db.iter(txn)? {
126            let (prompt_id, metadata) = result?;
127            cache.metadata.push(metadata.clone());
128            cache.metadata_by_id.insert(prompt_id, metadata);
129        }
130        cache.sort();
131        Ok(cache)
132    }
133
134    fn insert(&mut self, metadata: PromptMetadata) {
135        self.metadata_by_id.insert(metadata.id, metadata.clone());
136        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
137            *old_metadata = metadata;
138        } else {
139            self.metadata.push(metadata);
140        }
141        self.sort();
142    }
143
144    fn remove(&mut self, id: PromptId) {
145        self.metadata.retain(|metadata| metadata.id != id);
146        self.metadata_by_id.remove(&id);
147    }
148
149    fn sort(&mut self) {
150        self.metadata.sort_unstable_by(|a, b| {
151            a.title
152                .cmp(&b.title)
153                .then_with(|| b.saved_at.cmp(&a.saved_at))
154        });
155    }
156}
157
158impl PromptStore {
159    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
160        let store = GlobalPromptStore::global(cx).0.clone();
161        async move { store.await.map_err(|err| anyhow!(err)) }
162    }
163
164    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
165        cx.background_spawn(async move {
166            std::fs::create_dir_all(&db_path)?;
167
168            let db_env = unsafe {
169                heed::EnvOpenOptions::new()
170                    .map_size(1024 * 1024 * 1024) // 1GB
171                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
172                    .open(db_path)?
173            };
174
175            let mut txn = db_env.write_txn()?;
176            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
177            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
178
179            // Remove edit workflow prompt, as we decided to opt into it using
180            // a slash command instead.
181            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
182            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
183
184            txn.commit()?;
185
186            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
187
188            let txn = db_env.read_txn()?;
189            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
190            txn.commit()?;
191
192            Ok(PromptStore {
193                env: db_env,
194                metadata_cache: RwLock::new(metadata_cache),
195                metadata,
196                bodies,
197            })
198        })
199    }
200
201    fn upgrade_dbs(
202        env: &heed::Env,
203        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
204        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
205    ) -> Result<()> {
206        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
207        pub struct PromptIdV1(Uuid);
208
209        #[derive(Clone, Debug, Serialize, Deserialize)]
210        pub struct PromptMetadataV1 {
211            pub id: PromptIdV1,
212            pub title: Option<SharedString>,
213            pub default: bool,
214            pub saved_at: DateTime<Utc>,
215        }
216
217        let mut txn = env.write_txn()?;
218        let Some(bodies_v1_db) = env
219            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
220                &txn,
221                Some("bodies"),
222            )?
223        else {
224            return Ok(());
225        };
226        let mut bodies_v1 = bodies_v1_db
227            .iter(&txn)?
228            .collect::<heed::Result<HashMap<_, _>>>()?;
229
230        let Some(metadata_v1_db) = env
231            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
232                &txn,
233                Some("metadata"),
234            )?
235        else {
236            return Ok(());
237        };
238        let metadata_v1 = metadata_v1_db
239            .iter(&txn)?
240            .collect::<heed::Result<HashMap<_, _>>>()?;
241
242        for (prompt_id_v1, metadata_v1) in metadata_v1 {
243            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
244            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
245                continue;
246            };
247
248            if metadata_db
249                .get(&txn, &prompt_id_v2)?
250                .map_or(true, |metadata_v2| {
251                    metadata_v1.saved_at > metadata_v2.saved_at
252                })
253            {
254                metadata_db.put(
255                    &mut txn,
256                    &prompt_id_v2,
257                    &PromptMetadata {
258                        id: prompt_id_v2,
259                        title: metadata_v1.title.clone(),
260                        default: metadata_v1.default,
261                        saved_at: metadata_v1.saved_at,
262                    },
263                )?;
264                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
265            }
266        }
267
268        txn.commit()?;
269
270        Ok(())
271    }
272
273    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
274        let env = self.env.clone();
275        let bodies = self.bodies;
276        cx.background_spawn(async move {
277            let txn = env.read_txn()?;
278            let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
279            LineEnding::normalize(&mut prompt);
280            Ok(prompt)
281        })
282    }
283
284    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
285        self.metadata_cache.read().metadata.clone()
286    }
287
288    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
289        return self
290            .metadata_cache
291            .read()
292            .metadata
293            .iter()
294            .filter(|metadata| metadata.default)
295            .cloned()
296            .collect::<Vec<_>>();
297    }
298
299    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
300        self.metadata_cache.write().remove(id);
301
302        let db_connection = self.env.clone();
303        let bodies = self.bodies;
304        let metadata = self.metadata;
305
306        let task = cx.background_spawn(async move {
307            let mut txn = db_connection.write_txn()?;
308
309            metadata.delete(&mut txn, &id)?;
310            bodies.delete(&mut txn, &id)?;
311
312            txn.commit()?;
313            anyhow::Ok(())
314        });
315
316        cx.spawn(async move |this, cx| {
317            task.await?;
318            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
319            anyhow::Ok(())
320        })
321    }
322
323    /// Returns the number of prompts in the store.
324    pub fn prompt_count(&self) -> usize {
325        self.metadata_cache.read().metadata.len()
326    }
327
328    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
329        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
330    }
331
332    pub fn first(&self) -> Option<PromptMetadata> {
333        self.metadata_cache.read().metadata.first().cloned()
334    }
335
336    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
337        let metadata_cache = self.metadata_cache.read();
338        let metadata = metadata_cache
339            .metadata
340            .iter()
341            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
342        Some(metadata.id)
343    }
344
345    pub fn search(
346        &self,
347        query: String,
348        cancellation_flag: Arc<AtomicBool>,
349        cx: &App,
350    ) -> Task<Vec<PromptMetadata>> {
351        let cached_metadata = self.metadata_cache.read().metadata.clone();
352        let executor = cx.background_executor().clone();
353        cx.background_spawn(async move {
354            let mut matches = if query.is_empty() {
355                cached_metadata
356            } else {
357                let candidates = cached_metadata
358                    .iter()
359                    .enumerate()
360                    .filter_map(|(ix, metadata)| {
361                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
362                    })
363                    .collect::<Vec<_>>();
364                let matches = fuzzy::match_strings(
365                    &candidates,
366                    &query,
367                    false,
368                    true,
369                    100,
370                    &cancellation_flag,
371                    executor,
372                )
373                .await;
374                matches
375                    .into_iter()
376                    .map(|mat| cached_metadata[mat.candidate_id].clone())
377                    .collect()
378            };
379            matches.sort_by_key(|metadata| Reverse(metadata.default));
380            matches
381        })
382    }
383
384    pub fn save(
385        &self,
386        id: PromptId,
387        title: Option<SharedString>,
388        default: bool,
389        body: Rope,
390        cx: &Context<Self>,
391    ) -> Task<Result<()>> {
392        if id.is_built_in() {
393            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
394        }
395
396        let prompt_metadata = PromptMetadata {
397            id,
398            title,
399            default,
400            saved_at: Utc::now(),
401        };
402        self.metadata_cache.write().insert(prompt_metadata.clone());
403
404        let db_connection = self.env.clone();
405        let bodies = self.bodies;
406        let metadata = self.metadata;
407
408        let task = cx.background_spawn(async move {
409            let mut txn = db_connection.write_txn()?;
410
411            metadata.put(&mut txn, &id, &prompt_metadata)?;
412            bodies.put(&mut txn, &id, &body.to_string())?;
413
414            txn.commit()?;
415
416            anyhow::Ok(())
417        });
418
419        cx.spawn(async move |this, cx| {
420            task.await?;
421            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
422            anyhow::Ok(())
423        })
424    }
425
426    pub fn save_metadata(
427        &self,
428        id: PromptId,
429        mut title: Option<SharedString>,
430        default: bool,
431        cx: &Context<Self>,
432    ) -> Task<Result<()>> {
433        let mut cache = self.metadata_cache.write();
434
435        if id.is_built_in() {
436            title = cache
437                .metadata_by_id
438                .get(&id)
439                .and_then(|metadata| metadata.title.clone());
440        }
441
442        let prompt_metadata = PromptMetadata {
443            id,
444            title,
445            default,
446            saved_at: Utc::now(),
447        };
448
449        cache.insert(prompt_metadata.clone());
450
451        let db_connection = self.env.clone();
452        let metadata = self.metadata;
453
454        let task = cx.background_spawn(async move {
455            let mut txn = db_connection.write_txn()?;
456            metadata.put(&mut txn, &id, &prompt_metadata)?;
457            txn.commit()?;
458
459            anyhow::Ok(())
460        });
461
462        cx.spawn(async move |this, cx| {
463            task.await?;
464            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
465            anyhow::Ok(())
466        })
467    }
468}
469
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// The future is the one created by [`init`]. The error is held in an `Arc`
/// so the shared output can be cloned for every awaiting caller.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}