prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use strum::{EnumIter, IntoEnumIterator as _};
 27use text::LineEnding;
 28use util::ResultExt;
 29use uuid::Uuid;
 30
 31/// Init starts loading the PromptStore in the background and assigns
 32/// a shared future to a global.
 33pub fn init(cx: &mut App) {
 34    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 35    let prompt_store_task = PromptStore::new(db_path, cx);
 36    let prompt_store_entity_task = cx
 37        .spawn(async move |cx| {
 38            prompt_store_task
 39                .await
 40                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 41                .map_err(Arc::new)
 42        })
 43        .shared();
 44    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 45}
 46
 47#[derive(Clone, Debug, Serialize, Deserialize)]
 48pub struct PromptMetadata {
 49    pub id: PromptId,
 50    pub title: Option<SharedString>,
 51    pub default: bool,
 52    pub saved_at: DateTime<Utc>,
 53}
 54
 55impl PromptMetadata {
 56    fn builtin(builtin: BuiltInPrompt) -> Self {
 57        Self {
 58            id: PromptId::BuiltIn(builtin),
 59            title: Some(builtin.title().into()),
 60            default: false,
 61            saved_at: DateTime::default(),
 62        }
 63    }
 64}
 65
 66/// Built-in prompts that have default content and can be customized by users.
 67#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
 68pub enum BuiltInPrompt {
 69    CommitMessage,
 70}
 71
 72impl BuiltInPrompt {
 73    pub fn title(&self) -> &'static str {
 74        match self {
 75            Self::CommitMessage => "Commit message",
 76        }
 77    }
 78
 79    /// Returns the default content for this built-in prompt.
 80    pub fn default_content(&self) -> &'static str {
 81        match self {
 82            Self::CommitMessage => include_str!("../../git_ui/src/commit_message_prompt.txt"),
 83        }
 84    }
 85}
 86
 87impl std::fmt::Display for BuiltInPrompt {
 88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 89        match self {
 90            Self::CommitMessage => write!(f, "Commit message"),
 91        }
 92    }
 93}
 94
 95#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 96#[serde(tag = "kind")]
 97pub enum PromptId {
 98    User { uuid: UserPromptId },
 99    BuiltIn(BuiltInPrompt),
100}
101
102impl PromptId {
103    pub fn new() -> PromptId {
104        UserPromptId::new().into()
105    }
106
107    pub fn as_user(&self) -> Option<UserPromptId> {
108        match self {
109            Self::User { uuid } => Some(*uuid),
110            Self::BuiltIn { .. } => None,
111        }
112    }
113
114    pub fn as_built_in(&self) -> Option<BuiltInPrompt> {
115        match self {
116            Self::User { .. } => None,
117            Self::BuiltIn(builtin) => Some(*builtin),
118        }
119    }
120
121    pub fn is_built_in(&self) -> bool {
122        matches!(self, Self::BuiltIn { .. })
123    }
124
125    pub fn can_edit(&self) -> bool {
126        match self {
127            Self::User { .. } => true,
128            Self::BuiltIn(builtin) => match builtin {
129                BuiltInPrompt::CommitMessage => true,
130            },
131        }
132    }
133}
134
135impl From<BuiltInPrompt> for PromptId {
136    fn from(builtin: BuiltInPrompt) -> Self {
137        PromptId::BuiltIn(builtin)
138    }
139}
140
141impl From<UserPromptId> for PromptId {
142    fn from(uuid: UserPromptId) -> Self {
143        PromptId::User { uuid }
144    }
145}
146
147#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
148#[serde(transparent)]
149pub struct UserPromptId(pub Uuid);
150
151impl UserPromptId {
152    pub fn new() -> UserPromptId {
153        UserPromptId(Uuid::new_v4())
154    }
155}
156
157impl From<Uuid> for UserPromptId {
158    fn from(uuid: Uuid) -> Self {
159        UserPromptId(uuid)
160    }
161}
162
163impl std::fmt::Display for PromptId {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        match self {
166            PromptId::User { uuid } => write!(f, "{}", uuid.0),
167            PromptId::BuiltIn(builtin) => write!(f, "{}", builtin),
168        }
169    }
170}
171
172pub struct PromptStore {
173    env: heed::Env,
174    metadata_cache: RwLock<MetadataCache>,
175    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
176    bodies: Database<SerdeJson<PromptId>, Str>,
177}
178
179pub struct PromptsUpdatedEvent;
180
181impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
182
183#[derive(Default)]
184struct MetadataCache {
185    metadata: Vec<PromptMetadata>,
186    metadata_by_id: HashMap<PromptId, PromptMetadata>,
187}
188
189impl MetadataCache {
190    fn from_db(
191        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
192        txn: &RoTxn,
193    ) -> Result<Self> {
194        let mut cache = MetadataCache::default();
195        for result in db.iter(txn)? {
196            let (prompt_id, metadata) = result?;
197            cache.metadata.push(metadata.clone());
198            cache.metadata_by_id.insert(prompt_id, metadata);
199        }
200
201        // Insert all the built-in prompts that were not customized by the user
202        for builtin in BuiltInPrompt::iter() {
203            let builtin_id = PromptId::BuiltIn(builtin);
204            if !cache.metadata_by_id.contains_key(&builtin_id) {
205                let metadata = PromptMetadata::builtin(builtin);
206                cache.metadata.push(metadata.clone());
207                cache.metadata_by_id.insert(builtin_id, metadata);
208            }
209        }
210        cache.sort();
211        Ok(cache)
212    }
213
214    fn insert(&mut self, metadata: PromptMetadata) {
215        self.metadata_by_id.insert(metadata.id, metadata.clone());
216        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
217            *old_metadata = metadata;
218        } else {
219            self.metadata.push(metadata);
220        }
221        self.sort();
222    }
223
224    fn remove(&mut self, id: PromptId) {
225        self.metadata.retain(|metadata| metadata.id != id);
226        self.metadata_by_id.remove(&id);
227    }
228
229    fn sort(&mut self) {
230        self.metadata.sort_unstable_by(|a, b| {
231            a.title
232                .cmp(&b.title)
233                .then_with(|| b.saved_at.cmp(&a.saved_at))
234        });
235    }
236}
237
238impl PromptStore {
239    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
240        let store = GlobalPromptStore::global(cx).0.clone();
241        async move { store.await.map_err(|err| anyhow!(err)) }
242    }
243
244    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
245        cx.background_spawn(async move {
246            std::fs::create_dir_all(&db_path)?;
247
248            let db_env = unsafe {
249                heed::EnvOpenOptions::new()
250                    .map_size(1024 * 1024 * 1024) // 1GB
251                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
252                    .open(db_path)?
253            };
254
255            let mut txn = db_env.write_txn()?;
256            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
257            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
258            txn.commit()?;
259
260            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
261
262            let txn = db_env.read_txn()?;
263            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
264            txn.commit()?;
265
266            Ok(PromptStore {
267                env: db_env,
268                metadata_cache: RwLock::new(metadata_cache),
269                metadata,
270                bodies,
271            })
272        })
273    }
274
275    fn upgrade_dbs(
276        env: &heed::Env,
277        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
278        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
279    ) -> Result<()> {
280        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
281        pub struct PromptIdV1(Uuid);
282
283        #[derive(Clone, Debug, Serialize, Deserialize)]
284        pub struct PromptMetadataV1 {
285            pub id: PromptIdV1,
286            pub title: Option<SharedString>,
287            pub default: bool,
288            pub saved_at: DateTime<Utc>,
289        }
290
291        let mut txn = env.write_txn()?;
292        let Some(bodies_v1_db) = env
293            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
294                &txn,
295                Some("bodies"),
296            )?
297        else {
298            return Ok(());
299        };
300        let mut bodies_v1 = bodies_v1_db
301            .iter(&txn)?
302            .collect::<heed::Result<HashMap<_, _>>>()?;
303
304        let Some(metadata_v1_db) = env
305            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
306                &txn,
307                Some("metadata"),
308            )?
309        else {
310            return Ok(());
311        };
312        let metadata_v1 = metadata_v1_db
313            .iter(&txn)?
314            .collect::<heed::Result<HashMap<_, _>>>()?;
315
316        for (prompt_id_v1, metadata_v1) in metadata_v1 {
317            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
318            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
319                continue;
320            };
321
322            if metadata_db
323                .get(&txn, &prompt_id_v2)?
324                .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
325            {
326                metadata_db.put(
327                    &mut txn,
328                    &prompt_id_v2,
329                    &PromptMetadata {
330                        id: prompt_id_v2,
331                        title: metadata_v1.title.clone(),
332                        default: metadata_v1.default,
333                        saved_at: metadata_v1.saved_at,
334                    },
335                )?;
336                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
337            }
338        }
339
340        txn.commit()?;
341
342        Ok(())
343    }
344
345    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
346        let env = self.env.clone();
347        let bodies = self.bodies;
348        cx.background_spawn(async move {
349            let txn = env.read_txn()?;
350            let mut prompt: String = match bodies.get(&txn, &id)? {
351                Some(body) => body.into(),
352                None => {
353                    if let Some(built_in) = id.as_built_in() {
354                        built_in.default_content().into()
355                    } else {
356                        anyhow::bail!("prompt not found")
357                    }
358                }
359            };
360            LineEnding::normalize(&mut prompt);
361            Ok(prompt)
362        })
363    }
364
365    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
366        self.metadata_cache.read().metadata.clone()
367    }
368
369    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
370        return self
371            .metadata_cache
372            .read()
373            .metadata
374            .iter()
375            .filter(|metadata| metadata.default)
376            .cloned()
377            .collect::<Vec<_>>();
378    }
379
380    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
381        self.metadata_cache.write().remove(id);
382
383        let db_connection = self.env.clone();
384        let bodies = self.bodies;
385        let metadata = self.metadata;
386
387        let task = cx.background_spawn(async move {
388            let mut txn = db_connection.write_txn()?;
389
390            metadata.delete(&mut txn, &id)?;
391            bodies.delete(&mut txn, &id)?;
392
393            txn.commit()?;
394            anyhow::Ok(())
395        });
396
397        cx.spawn(async move |this, cx| {
398            task.await?;
399            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
400            anyhow::Ok(())
401        })
402    }
403
404    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
405        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
406    }
407
408    pub fn first(&self) -> Option<PromptMetadata> {
409        self.metadata_cache.read().metadata.first().cloned()
410    }
411
412    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
413        let metadata_cache = self.metadata_cache.read();
414        let metadata = metadata_cache
415            .metadata
416            .iter()
417            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
418        Some(metadata.id)
419    }
420
421    pub fn search(
422        &self,
423        query: String,
424        cancellation_flag: Arc<AtomicBool>,
425        cx: &App,
426    ) -> Task<Vec<PromptMetadata>> {
427        let cached_metadata = self.metadata_cache.read().metadata.clone();
428        let executor = cx.background_executor().clone();
429        cx.background_spawn(async move {
430            let mut matches = if query.is_empty() {
431                cached_metadata
432            } else {
433                let candidates = cached_metadata
434                    .iter()
435                    .enumerate()
436                    .filter_map(|(ix, metadata)| {
437                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
438                    })
439                    .collect::<Vec<_>>();
440                let matches = fuzzy::match_strings(
441                    &candidates,
442                    &query,
443                    false,
444                    true,
445                    100,
446                    &cancellation_flag,
447                    executor,
448                )
449                .await;
450                matches
451                    .into_iter()
452                    .map(|mat| cached_metadata[mat.candidate_id].clone())
453                    .collect()
454            };
455            matches.sort_by_key(|metadata| Reverse(metadata.default));
456            matches
457        })
458    }
459
460    pub fn save(
461        &self,
462        id: PromptId,
463        title: Option<SharedString>,
464        default: bool,
465        body: Rope,
466        cx: &Context<Self>,
467    ) -> Task<Result<()>> {
468        if !id.can_edit() {
469            return Task::ready(Err(anyhow!("this prompt cannot be edited")));
470        }
471
472        let body = body.to_string();
473        let is_default_content = id
474            .as_built_in()
475            .is_some_and(|builtin| body.trim() == builtin.default_content().trim());
476
477        let metadata = if let Some(builtin) = id.as_built_in() {
478            PromptMetadata::builtin(builtin)
479        } else {
480            PromptMetadata {
481                id,
482                title,
483                default,
484                saved_at: Utc::now(),
485            }
486        };
487
488        self.metadata_cache.write().insert(metadata.clone());
489
490        let db_connection = self.env.clone();
491        let bodies = self.bodies;
492        let metadata_db = self.metadata;
493
494        let task = cx.background_spawn(async move {
495            let mut txn = db_connection.write_txn()?;
496
497            if is_default_content {
498                metadata_db.delete(&mut txn, &id)?;
499                bodies.delete(&mut txn, &id)?;
500            } else {
501                metadata_db.put(&mut txn, &id, &metadata)?;
502                bodies.put(&mut txn, &id, &body)?;
503            }
504
505            txn.commit()?;
506
507            anyhow::Ok(())
508        });
509
510        cx.spawn(async move |this, cx| {
511            task.await?;
512            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
513            anyhow::Ok(())
514        })
515    }
516
517    pub fn save_metadata(
518        &self,
519        id: PromptId,
520        mut title: Option<SharedString>,
521        default: bool,
522        cx: &Context<Self>,
523    ) -> Task<Result<()>> {
524        let mut cache = self.metadata_cache.write();
525
526        if !id.can_edit() {
527            title = cache
528                .metadata_by_id
529                .get(&id)
530                .and_then(|metadata| metadata.title.clone());
531        }
532
533        let prompt_metadata = PromptMetadata {
534            id,
535            title,
536            default,
537            saved_at: Utc::now(),
538        };
539
540        cache.insert(prompt_metadata.clone());
541
542        let db_connection = self.env.clone();
543        let metadata = self.metadata;
544
545        let task = cx.background_spawn(async move {
546            let mut txn = db_connection.write_txn()?;
547            metadata.put(&mut txn, &id, &prompt_metadata)?;
548            txn.commit()?;
549
550            anyhow::Ok(())
551        });
552
553        cx.spawn(async move |this, cx| {
554            task.await?;
555            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
556            anyhow::Ok(())
557        })
558    }
559}
560
561/// Wraps a shared future to a prompt store so it can be assigned as a context global.
562pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
563
564impl Global for GlobalPromptStore {}
565
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Title passed to every `save` call and expected back after a reset.
    const COMMIT_TITLE: &str = "Commit message";

    /// The commit-message prompt's default body, line-ending-normalized like `load` output.
    fn expected_default_content() -> String {
        let mut content = BuiltInPrompt::CommitMessage.default_content().to_string();
        LineEnding::normalize(&mut content);
        content
    }

    /// Loads the body of `id`, panicking on failure.
    async fn load_body(
        store: &Entity<PromptStore>,
        id: PromptId,
        cx: &mut TestAppContext,
    ) -> String {
        store
            .update(cx, |store, cx| store.load(id, cx))
            .await
            .unwrap()
    }

    /// Saves `body` for `id` with title `COMMIT_TITLE` and `default = false`,
    /// panicking on failure.
    async fn save_body(
        store: &Entity<PromptStore>,
        id: PromptId,
        body: &str,
        cx: &mut TestAppContext,
    ) {
        store
            .update(cx, |store, cx| {
                store.save(id, Some(COMMIT_TITLE.into()), false, Rope::from(body), cx)
            })
            .await
            .unwrap();
    }

    /// Built-in prompt lifecycle: default content before any save, custom content after
    /// saving, and back to the defaults once the default body is saved again.
    #[gpui::test]
    async fn test_built_in_prompt_load_save(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("prompts-db");
        let loaded_store = cx.update(|cx| PromptStore::new(db_path, cx)).await.unwrap();
        let store = cx.new(|_cx| loaded_store);

        let id: PromptId = BuiltInPrompt::CommitMessage.into();

        // Nothing stored yet: load falls back to the built-in default.
        let initial = load_body(&store, id, cx).await;
        assert_eq!(
            initial.trim(),
            expected_default_content().trim(),
            "Loading a built-in prompt not in DB should return default content"
        );
        assert!(
            store.read_with(cx, |store, _| store.metadata(id)).is_some(),
            "Built-in prompt should always have metadata"
        );
        let in_cache = store.read_with(cx, |store, _| {
            store.metadata_cache.read().metadata_by_id.contains_key(&id)
        });
        assert!(in_cache, "Built-in prompt should always be in cache");

        // Customize it: the custom body is stored and loaded back.
        let custom_content = "Custom commit message prompt";
        save_body(&store, id, custom_content, cx).await;
        let customized = load_body(&store, id, cx).await;
        assert_eq!(
            customized.trim(),
            custom_content.trim(),
            "Custom content should be loaded after saving"
        );
        assert!(
            store.read_with(cx, |store, _| store.metadata(id)).is_some(),
            "Built-in prompt should have metadata after customization"
        );

        // Saving the default body resets the prompt to its built-in state.
        save_body(&store, id, BuiltInPrompt::CommitMessage.default_content(), cx).await;
        let reset_metadata = store.read_with(cx, |store, _| store.metadata(id));
        assert!(
            reset_metadata.is_some(),
            "Built-in prompt should still have metadata after reset"
        );
        assert_eq!(
            reset_metadata
                .as_ref()
                .and_then(|m| m.title.as_ref().map(|t| t.as_ref())),
            Some(COMMIT_TITLE),
            "Built-in prompt should have default title after reset"
        );
        let after_reset = load_body(&store, id, cx).await;
        assert_eq!(
            after_reset.trim(),
            expected_default_content().trim(),
            "After saving default content, load should return default"
        );
    }
}