prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Metadata describing one stored prompt.
///
/// Stored as JSON in the metadata database (keyed by [`PromptId`]) and mirrored
/// in the store's in-memory metadata cache.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt; the store writes this record under the same id.
    pub id: PromptId,
    /// Display title. `None` for untitled prompts, which title search skips.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt.
    pub default: bool,
    /// Time of the last save. Used to order prompts that share a title (newest first).
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt: either a user-authored prompt or one of the built-ins.
///
/// Serialized with an internal `kind` tag. The JSON form is used as the
/// database key, so renaming variants or the tag changes the on-disk keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt written by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// Built-in prompt that is no longer stored (removed when the store opens)
    /// and cannot be edited.
    EditWorkflow,
    /// Built-in, editable prompt seeded with default commit-message content.
    CommitMessage,
}
 61
 62impl PromptId {
 63    pub fn new() -> PromptId {
 64        UserPromptId::new().into()
 65    }
 66
 67    pub fn user_id(&self) -> Option<UserPromptId> {
 68        match self {
 69            Self::User { uuid } => Some(*uuid),
 70            _ => None,
 71        }
 72    }
 73
 74    pub fn is_built_in(&self) -> bool {
 75        match self {
 76            Self::User { .. } => false,
 77            Self::EditWorkflow | Self::CommitMessage => true,
 78        }
 79    }
 80
 81    pub fn can_edit(&self) -> bool {
 82        match self {
 83            Self::User { .. } | Self::CommitMessage => true,
 84            Self::EditWorkflow => false,
 85        }
 86    }
 87
 88    pub fn default_content(&self) -> Option<&'static str> {
 89        match self {
 90            Self::User { .. } | Self::EditWorkflow => None,
 91            Self::CommitMessage => Some(include_str!("../../git_ui/src/commit_message_prompt.txt")),
 92        }
 93    }
 94}
 95
 96impl From<UserPromptId> for PromptId {
 97    fn from(uuid: UserPromptId) -> Self {
 98        PromptId::User { uuid }
 99    }
100}
101
/// Newtype around the UUID of a user-authored prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the inner `Uuid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
105
106impl UserPromptId {
107    pub fn new() -> UserPromptId {
108        UserPromptId(Uuid::new_v4())
109    }
110}
111
112impl From<Uuid> for UserPromptId {
113    fn from(uuid: Uuid) -> Self {
114        UserPromptId(uuid)
115    }
116}
117
118impl std::fmt::Display for PromptId {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        match self {
121            PromptId::User { uuid } => write!(f, "{}", uuid.0),
122            PromptId::EditWorkflow => write!(f, "Edit workflow"),
123            PromptId::CommitMessage => write!(f, "Commit message"),
124        }
125    }
126}
127
/// Persistent prompt library backed by a `heed` database environment, plus an
/// in-memory cache of every prompt's metadata.
pub struct PromptStore {
    /// Database environment handle; cloned into background tasks for each read/write.
    env: heed::Env,
    /// Sorted copy of all prompt metadata. Writers update it before the database write.
    metadata_cache: RwLock<MetadataCache>,
    /// The `metadata.v2` database: prompt id -> metadata, both JSON-encoded.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// The `bodies.v2` database: prompt id (JSON) -> prompt text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
134
/// Emitted by [`PromptStore`] after a save, metadata save, or delete has been
/// committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
138
/// In-memory mirror of the metadata database.
#[derive(Default)]
struct MetadataCache {
    /// All metadata, kept sorted by title (ascending) and then by `saved_at` (newest first).
    metadata: Vec<PromptMetadata>,
    /// The same metadata indexed by prompt id for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
144
145impl MetadataCache {
146    fn from_db(
147        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
148        txn: &RoTxn,
149    ) -> Result<Self> {
150        let mut cache = MetadataCache::default();
151        for result in db.iter(txn)? {
152            let (prompt_id, metadata) = result?;
153            cache.metadata.push(metadata.clone());
154            cache.metadata_by_id.insert(prompt_id, metadata);
155        }
156        cache.sort();
157        Ok(cache)
158    }
159
160    fn insert(&mut self, metadata: PromptMetadata) {
161        self.metadata_by_id.insert(metadata.id, metadata.clone());
162        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
163            *old_metadata = metadata;
164        } else {
165            self.metadata.push(metadata);
166        }
167        self.sort();
168    }
169
170    fn remove(&mut self, id: PromptId) {
171        self.metadata.retain(|metadata| metadata.id != id);
172        self.metadata_by_id.remove(&id);
173    }
174
175    fn sort(&mut self) {
176        self.metadata.sort_unstable_by(|a, b| {
177            a.title
178                .cmp(&b.title)
179                .then_with(|| b.saved_at.cmp(&a.saved_at))
180        });
181    }
182}
183
184impl PromptStore {
185    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
186        let store = GlobalPromptStore::global(cx).0.clone();
187        async move { store.await.map_err(|err| anyhow!(err)) }
188    }
189
190    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
191        cx.background_spawn(async move {
192            std::fs::create_dir_all(&db_path)?;
193
194            let db_env = unsafe {
195                heed::EnvOpenOptions::new()
196                    .map_size(1024 * 1024 * 1024) // 1GB
197                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
198                    .open(db_path)?
199            };
200
201            let mut txn = db_env.write_txn()?;
202            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
203            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
204
205            // Remove edit workflow prompt, as we decided to opt into it using
206            // a slash command instead.
207            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
208            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
209
210            // Insert default commit message prompt if not present
211            if metadata.get(&txn, &PromptId::CommitMessage)?.is_none() {
212                metadata.put(
213                    &mut txn,
214                    &PromptId::CommitMessage,
215                    &PromptMetadata {
216                        id: PromptId::CommitMessage,
217                        title: Some("Git Commit Message".into()),
218                        default: false,
219                        saved_at: Utc::now(),
220                    },
221                )?;
222            }
223            if bodies.get(&txn, &PromptId::CommitMessage)?.is_none() {
224                let commit_message_prompt =
225                    include_str!("../../git_ui/src/commit_message_prompt.txt");
226                bodies.put(&mut txn, &PromptId::CommitMessage, commit_message_prompt)?;
227            }
228
229            txn.commit()?;
230
231            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
232
233            let txn = db_env.read_txn()?;
234            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
235            txn.commit()?;
236
237            Ok(PromptStore {
238                env: db_env,
239                metadata_cache: RwLock::new(metadata_cache),
240                metadata,
241                bodies,
242            })
243        })
244    }
245
246    fn upgrade_dbs(
247        env: &heed::Env,
248        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
249        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
250    ) -> Result<()> {
251        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
252        pub struct PromptIdV1(Uuid);
253
254        #[derive(Clone, Debug, Serialize, Deserialize)]
255        pub struct PromptMetadataV1 {
256            pub id: PromptIdV1,
257            pub title: Option<SharedString>,
258            pub default: bool,
259            pub saved_at: DateTime<Utc>,
260        }
261
262        let mut txn = env.write_txn()?;
263        let Some(bodies_v1_db) = env
264            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
265                &txn,
266                Some("bodies"),
267            )?
268        else {
269            return Ok(());
270        };
271        let mut bodies_v1 = bodies_v1_db
272            .iter(&txn)?
273            .collect::<heed::Result<HashMap<_, _>>>()?;
274
275        let Some(metadata_v1_db) = env
276            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
277                &txn,
278                Some("metadata"),
279            )?
280        else {
281            return Ok(());
282        };
283        let metadata_v1 = metadata_v1_db
284            .iter(&txn)?
285            .collect::<heed::Result<HashMap<_, _>>>()?;
286
287        for (prompt_id_v1, metadata_v1) in metadata_v1 {
288            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
289            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
290                continue;
291            };
292
293            if metadata_db
294                .get(&txn, &prompt_id_v2)?
295                .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
296            {
297                metadata_db.put(
298                    &mut txn,
299                    &prompt_id_v2,
300                    &PromptMetadata {
301                        id: prompt_id_v2,
302                        title: metadata_v1.title.clone(),
303                        default: metadata_v1.default,
304                        saved_at: metadata_v1.saved_at,
305                    },
306                )?;
307                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
308            }
309        }
310
311        txn.commit()?;
312
313        Ok(())
314    }
315
316    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
317        let env = self.env.clone();
318        let bodies = self.bodies;
319        cx.background_spawn(async move {
320            let txn = env.read_txn()?;
321            let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
322            LineEnding::normalize(&mut prompt);
323            Ok(prompt)
324        })
325    }
326
327    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
328        self.metadata_cache.read().metadata.clone()
329    }
330
331    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
332        return self
333            .metadata_cache
334            .read()
335            .metadata
336            .iter()
337            .filter(|metadata| metadata.default)
338            .cloned()
339            .collect::<Vec<_>>();
340    }
341
342    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
343        self.metadata_cache.write().remove(id);
344
345        let db_connection = self.env.clone();
346        let bodies = self.bodies;
347        let metadata = self.metadata;
348
349        let task = cx.background_spawn(async move {
350            let mut txn = db_connection.write_txn()?;
351
352            metadata.delete(&mut txn, &id)?;
353            bodies.delete(&mut txn, &id)?;
354
355            txn.commit()?;
356            anyhow::Ok(())
357        });
358
359        cx.spawn(async move |this, cx| {
360            task.await?;
361            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
362            anyhow::Ok(())
363        })
364    }
365
366    /// Returns the number of prompts in the store.
367    pub fn prompt_count(&self) -> usize {
368        self.metadata_cache.read().metadata.len()
369    }
370
371    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
372        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
373    }
374
375    pub fn first(&self) -> Option<PromptMetadata> {
376        self.metadata_cache.read().metadata.first().cloned()
377    }
378
379    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
380        let metadata_cache = self.metadata_cache.read();
381        let metadata = metadata_cache
382            .metadata
383            .iter()
384            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
385        Some(metadata.id)
386    }
387
388    pub fn search(
389        &self,
390        query: String,
391        cancellation_flag: Arc<AtomicBool>,
392        cx: &App,
393    ) -> Task<Vec<PromptMetadata>> {
394        let cached_metadata = self.metadata_cache.read().metadata.clone();
395        let executor = cx.background_executor().clone();
396        cx.background_spawn(async move {
397            let mut matches = if query.is_empty() {
398                cached_metadata
399            } else {
400                let candidates = cached_metadata
401                    .iter()
402                    .enumerate()
403                    .filter_map(|(ix, metadata)| {
404                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
405                    })
406                    .collect::<Vec<_>>();
407                let matches = fuzzy::match_strings(
408                    &candidates,
409                    &query,
410                    false,
411                    true,
412                    100,
413                    &cancellation_flag,
414                    executor,
415                )
416                .await;
417                matches
418                    .into_iter()
419                    .map(|mat| cached_metadata[mat.candidate_id].clone())
420                    .collect()
421            };
422            matches.sort_by_key(|metadata| Reverse(metadata.default));
423            matches
424        })
425    }
426
427    pub fn save(
428        &self,
429        id: PromptId,
430        title: Option<SharedString>,
431        default: bool,
432        body: Rope,
433        cx: &Context<Self>,
434    ) -> Task<Result<()>> {
435        if !id.can_edit() {
436            return Task::ready(Err(anyhow!("this prompt cannot be edited")));
437        }
438
439        let prompt_metadata = PromptMetadata {
440            id,
441            title,
442            default,
443            saved_at: Utc::now(),
444        };
445        self.metadata_cache.write().insert(prompt_metadata.clone());
446
447        let db_connection = self.env.clone();
448        let bodies = self.bodies;
449        let metadata = self.metadata;
450
451        let task = cx.background_spawn(async move {
452            let mut txn = db_connection.write_txn()?;
453
454            metadata.put(&mut txn, &id, &prompt_metadata)?;
455            bodies.put(&mut txn, &id, &body.to_string())?;
456
457            txn.commit()?;
458
459            anyhow::Ok(())
460        });
461
462        cx.spawn(async move |this, cx| {
463            task.await?;
464            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
465            anyhow::Ok(())
466        })
467    }
468
469    pub fn save_metadata(
470        &self,
471        id: PromptId,
472        mut title: Option<SharedString>,
473        default: bool,
474        cx: &Context<Self>,
475    ) -> Task<Result<()>> {
476        let mut cache = self.metadata_cache.write();
477
478        if !id.can_edit() {
479            title = cache
480                .metadata_by_id
481                .get(&id)
482                .and_then(|metadata| metadata.title.clone());
483        }
484
485        let prompt_metadata = PromptMetadata {
486            id,
487            title,
488            default,
489            saved_at: Utc::now(),
490        };
491
492        cache.insert(prompt_metadata.clone());
493
494        let db_connection = self.env.clone();
495        let metadata = self.metadata;
496
497        let task = cx.background_spawn(async move {
498            let mut txn = db_connection.write_txn()?;
499            metadata.put(&mut txn, &id, &prompt_metadata)?;
500            txn.commit()?;
501
502            anyhow::Ok(())
503        });
504
505        cx.spawn(async move |this, cx| {
506            task.await?;
507            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
508            anyhow::Ok(())
509        })
510    }
511}
512
/// Global wrapper around the shared task that loads the [`PromptStore`].
///
/// The task is `Shared`, so every caller of [`PromptStore::global`] awaits the
/// same load. Errors are wrapped in `Arc` because `Shared` hands each awaiter a
/// clone of the output, and `anyhow::Error` itself is not `Clone`.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}