prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Metadata for a stored prompt.
///
/// Stored as JSON (`SerdeJson`) in the `metadata.v2` database. Field order
/// determines the serialized JSON layout.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata describes.
    pub id: PromptId,
    /// Display title. Prompts with `None` are skipped by fuzzy title search.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt.
    pub default: bool,
    /// When the prompt (or its metadata) was last saved.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt in the store.
///
/// Serialized with an internal `kind` tag, and that JSON form is the key in both
/// v2 databases. Renaming variants or fields therefore changes the on-disk keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A user-created prompt, identified by a UUID.
    User { uuid: UserPromptId },
    /// The built-in git commit message prompt.
    CommitMessage,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        UserPromptId::new().into()
 64    }
 65
 66    pub fn user_id(&self) -> Option<UserPromptId> {
 67        match self {
 68            Self::User { uuid } => Some(*uuid),
 69            _ => None,
 70        }
 71    }
 72
 73    pub fn is_built_in(&self) -> bool {
 74        match self {
 75            Self::User { .. } => false,
 76            Self::CommitMessage => true,
 77        }
 78    }
 79
 80    pub fn can_edit(&self) -> bool {
 81        match self {
 82            Self::User { .. } | Self::CommitMessage => true,
 83        }
 84    }
 85
 86    pub fn default_content(&self) -> Option<&'static str> {
 87        match self {
 88            Self::User { .. } => None,
 89            Self::CommitMessage => Some(include_str!("../../git_ui/src/commit_message_prompt.txt")),
 90        }
 91    }
 92}
 93
 94impl From<UserPromptId> for PromptId {
 95    fn from(uuid: UserPromptId) -> Self {
 96        PromptId::User { uuid }
 97    }
 98}
 99
/// UUID-based identifier of a user-created prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the inner [`Uuid`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
103
104impl UserPromptId {
105    pub fn new() -> UserPromptId {
106        UserPromptId(Uuid::new_v4())
107    }
108}
109
110impl From<Uuid> for UserPromptId {
111    fn from(uuid: Uuid) -> Self {
112        UserPromptId(uuid)
113    }
114}
115
116impl std::fmt::Display for PromptId {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            PromptId::User { uuid } => write!(f, "{}", uuid.0),
120            PromptId::CommitMessage => write!(f, "Commit message"),
121        }
122    }
123}
124
/// Persistent prompt library backed by a heed (LMDB) environment.
///
/// Keeps an in-memory copy of all prompt metadata.
pub struct PromptStore {
    /// Environment that holds both databases below.
    env: heed::Env,
    /// In-memory copy of every metadata entry, kept sorted.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: prompt id -> metadata (both JSON-encoded).
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: prompt id (JSON) -> prompt text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
131
/// Emitted by [`PromptStore`] after a save, metadata save, or delete has been
/// committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
135
/// In-memory index of prompt metadata, mirrored from the `metadata.v2` database.
#[derive(Default)]
struct MetadataCache {
    /// All entries, ordered by title and then most recently saved first.
    metadata: Vec<PromptMetadata>,
    /// The same entries keyed by prompt id, for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
141
142impl MetadataCache {
143    fn from_db(
144        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
145        txn: &RoTxn,
146    ) -> Result<Self> {
147        let mut cache = MetadataCache::default();
148        for result in db.iter(txn)? {
149            let (prompt_id, metadata) = result?;
150            cache.metadata.push(metadata.clone());
151            cache.metadata_by_id.insert(prompt_id, metadata);
152        }
153        cache.sort();
154        Ok(cache)
155    }
156
157    fn insert(&mut self, metadata: PromptMetadata) {
158        self.metadata_by_id.insert(metadata.id, metadata.clone());
159        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
160            *old_metadata = metadata;
161        } else {
162            self.metadata.push(metadata);
163        }
164        self.sort();
165    }
166
167    fn remove(&mut self, id: PromptId) {
168        self.metadata.retain(|metadata| metadata.id != id);
169        self.metadata_by_id.remove(&id);
170    }
171
172    fn sort(&mut self) {
173        self.metadata.sort_unstable_by(|a, b| {
174            a.title
175                .cmp(&b.title)
176                .then_with(|| b.saved_at.cmp(&a.saved_at))
177        });
178    }
179}
180
181impl PromptStore {
182    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
183        let store = GlobalPromptStore::global(cx).0.clone();
184        async move { store.await.map_err(|err| anyhow!(err)) }
185    }
186
187    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
188        cx.background_spawn(async move {
189            std::fs::create_dir_all(&db_path)?;
190
191            let db_env = unsafe {
192                heed::EnvOpenOptions::new()
193                    .map_size(1024 * 1024 * 1024) // 1GB
194                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
195                    .open(db_path)?
196            };
197
198            let mut txn = db_env.write_txn()?;
199            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
200            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
201
202            // Insert default commit message prompt if not present
203            if metadata.get(&txn, &PromptId::CommitMessage)?.is_none() {
204                metadata.put(
205                    &mut txn,
206                    &PromptId::CommitMessage,
207                    &PromptMetadata {
208                        id: PromptId::CommitMessage,
209                        title: Some("Git Commit Message".into()),
210                        default: false,
211                        saved_at: Utc::now(),
212                    },
213                )?;
214            }
215            if bodies.get(&txn, &PromptId::CommitMessage)?.is_none() {
216                let commit_message_prompt =
217                    include_str!("../../git_ui/src/commit_message_prompt.txt");
218                bodies.put(&mut txn, &PromptId::CommitMessage, commit_message_prompt)?;
219            }
220
221            txn.commit()?;
222
223            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
224
225            let txn = db_env.read_txn()?;
226            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
227            txn.commit()?;
228
229            Ok(PromptStore {
230                env: db_env,
231                metadata_cache: RwLock::new(metadata_cache),
232                metadata,
233                bodies,
234            })
235        })
236    }
237
238    fn upgrade_dbs(
239        env: &heed::Env,
240        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
241        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
242    ) -> Result<()> {
243        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
244        pub struct PromptIdV1(Uuid);
245
246        #[derive(Clone, Debug, Serialize, Deserialize)]
247        pub struct PromptMetadataV1 {
248            pub id: PromptIdV1,
249            pub title: Option<SharedString>,
250            pub default: bool,
251            pub saved_at: DateTime<Utc>,
252        }
253
254        let mut txn = env.write_txn()?;
255        let Some(bodies_v1_db) = env
256            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
257                &txn,
258                Some("bodies"),
259            )?
260        else {
261            return Ok(());
262        };
263        let mut bodies_v1 = bodies_v1_db
264            .iter(&txn)?
265            .collect::<heed::Result<HashMap<_, _>>>()?;
266
267        let Some(metadata_v1_db) = env
268            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
269                &txn,
270                Some("metadata"),
271            )?
272        else {
273            return Ok(());
274        };
275        let metadata_v1 = metadata_v1_db
276            .iter(&txn)?
277            .collect::<heed::Result<HashMap<_, _>>>()?;
278
279        for (prompt_id_v1, metadata_v1) in metadata_v1 {
280            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
281            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
282                continue;
283            };
284
285            if metadata_db
286                .get(&txn, &prompt_id_v2)?
287                .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
288            {
289                metadata_db.put(
290                    &mut txn,
291                    &prompt_id_v2,
292                    &PromptMetadata {
293                        id: prompt_id_v2,
294                        title: metadata_v1.title.clone(),
295                        default: metadata_v1.default,
296                        saved_at: metadata_v1.saved_at,
297                    },
298                )?;
299                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
300            }
301        }
302
303        txn.commit()?;
304
305        Ok(())
306    }
307
308    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
309        let env = self.env.clone();
310        let bodies = self.bodies;
311        cx.background_spawn(async move {
312            let txn = env.read_txn()?;
313            let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
314            LineEnding::normalize(&mut prompt);
315            Ok(prompt)
316        })
317    }
318
319    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
320        self.metadata_cache.read().metadata.clone()
321    }
322
323    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
324        return self
325            .metadata_cache
326            .read()
327            .metadata
328            .iter()
329            .filter(|metadata| metadata.default)
330            .cloned()
331            .collect::<Vec<_>>();
332    }
333
334    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
335        self.metadata_cache.write().remove(id);
336
337        let db_connection = self.env.clone();
338        let bodies = self.bodies;
339        let metadata = self.metadata;
340
341        let task = cx.background_spawn(async move {
342            let mut txn = db_connection.write_txn()?;
343
344            metadata.delete(&mut txn, &id)?;
345            bodies.delete(&mut txn, &id)?;
346
347            txn.commit()?;
348            anyhow::Ok(())
349        });
350
351        cx.spawn(async move |this, cx| {
352            task.await?;
353            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
354            anyhow::Ok(())
355        })
356    }
357
358    /// Returns the number of prompts in the store.
359    pub fn prompt_count(&self) -> usize {
360        self.metadata_cache.read().metadata.len()
361    }
362
363    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
364        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
365    }
366
367    pub fn first(&self) -> Option<PromptMetadata> {
368        self.metadata_cache.read().metadata.first().cloned()
369    }
370
371    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
372        let metadata_cache = self.metadata_cache.read();
373        let metadata = metadata_cache
374            .metadata
375            .iter()
376            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
377        Some(metadata.id)
378    }
379
380    pub fn search(
381        &self,
382        query: String,
383        cancellation_flag: Arc<AtomicBool>,
384        cx: &App,
385    ) -> Task<Vec<PromptMetadata>> {
386        let cached_metadata = self.metadata_cache.read().metadata.clone();
387        let executor = cx.background_executor().clone();
388        cx.background_spawn(async move {
389            let mut matches = if query.is_empty() {
390                cached_metadata
391            } else {
392                let candidates = cached_metadata
393                    .iter()
394                    .enumerate()
395                    .filter_map(|(ix, metadata)| {
396                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
397                    })
398                    .collect::<Vec<_>>();
399                let matches = fuzzy::match_strings(
400                    &candidates,
401                    &query,
402                    false,
403                    true,
404                    100,
405                    &cancellation_flag,
406                    executor,
407                )
408                .await;
409                matches
410                    .into_iter()
411                    .map(|mat| cached_metadata[mat.candidate_id].clone())
412                    .collect()
413            };
414            matches.sort_by_key(|metadata| Reverse(metadata.default));
415            matches
416        })
417    }
418
419    pub fn save(
420        &self,
421        id: PromptId,
422        title: Option<SharedString>,
423        default: bool,
424        body: Rope,
425        cx: &Context<Self>,
426    ) -> Task<Result<()>> {
427        if !id.can_edit() {
428            return Task::ready(Err(anyhow!("this prompt cannot be edited")));
429        }
430
431        let prompt_metadata = PromptMetadata {
432            id,
433            title,
434            default,
435            saved_at: Utc::now(),
436        };
437        self.metadata_cache.write().insert(prompt_metadata.clone());
438
439        let db_connection = self.env.clone();
440        let bodies = self.bodies;
441        let metadata = self.metadata;
442
443        let task = cx.background_spawn(async move {
444            let mut txn = db_connection.write_txn()?;
445
446            metadata.put(&mut txn, &id, &prompt_metadata)?;
447            bodies.put(&mut txn, &id, &body.to_string())?;
448
449            txn.commit()?;
450
451            anyhow::Ok(())
452        });
453
454        cx.spawn(async move |this, cx| {
455            task.await?;
456            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
457            anyhow::Ok(())
458        })
459    }
460
461    pub fn save_metadata(
462        &self,
463        id: PromptId,
464        mut title: Option<SharedString>,
465        default: bool,
466        cx: &Context<Self>,
467    ) -> Task<Result<()>> {
468        let mut cache = self.metadata_cache.write();
469
470        if !id.can_edit() {
471            title = cache
472                .metadata_by_id
473                .get(&id)
474                .and_then(|metadata| metadata.title.clone());
475        }
476
477        let prompt_metadata = PromptMetadata {
478            id,
479            title,
480            default,
481            saved_at: Utc::now(),
482        };
483
484        cache.insert(prompt_metadata.clone());
485
486        let db_connection = self.env.clone();
487        let metadata = self.metadata;
488
489        let task = cx.background_spawn(async move {
490            let mut txn = db_connection.write_txn()?;
491            metadata.put(&mut txn, &id, &prompt_metadata)?;
492            txn.commit()?;
493
494            anyhow::Ok(())
495        });
496
497        cx.spawn(async move |this, cx| {
498            task.await?;
499            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
500            anyhow::Ok(())
501        })
502    }
503}
504
/// Wraps the shared future that resolves to the prompt store entity, so it can
/// be registered as a gpui global (done by [`init`]).
///
/// The error is stored as `Arc<anyhow::Error>`, presumably because a `Shared`
/// future needs a cloneable output and `anyhow::Error` is not `Clone`.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}