prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Metadata describing a stored prompt, kept separately from the prompt's body.
///
/// Persisted as JSON in the `metadata.v2` database, so the field names below are
/// part of the on-disk format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt; the same value is used as the database key.
    pub id: PromptId,
    /// Display title; `None` for an untitled prompt.
    pub title: Option<SharedString>,
    /// Marks the prompt as a default prompt: such prompts are returned by
    /// `PromptStore::default_prompt_metadata` and listed first by `PromptStore::search`.
    pub default: bool,
    /// Time of the most recent save, in UTC.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt in the store: either a user-authored prompt or a built-in one.
///
/// Serialized as an internally tagged JSON object (`"kind"` field) and used as the key of
/// the v2 databases, so variant and field names are part of the on-disk format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt authored by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// The built-in commit message prompt; its default text comes from
    /// `git_ui`'s `commit_message_prompt.txt`.
    CommitMessage,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        UserPromptId::new().into()
 64    }
 65
 66    pub fn user_id(&self) -> Option<UserPromptId> {
 67        match self {
 68            Self::User { uuid } => Some(*uuid),
 69            _ => None,
 70        }
 71    }
 72
 73    pub fn is_built_in(&self) -> bool {
 74        match self {
 75            Self::User { .. } => false,
 76            Self::CommitMessage => true,
 77        }
 78    }
 79
 80    pub fn can_edit(&self) -> bool {
 81        match self {
 82            Self::User { .. } | Self::CommitMessage => true,
 83        }
 84    }
 85
 86    pub fn default_content(&self) -> Option<&'static str> {
 87        match self {
 88            Self::User { .. } => None,
 89            Self::CommitMessage => Some(include_str!("../../git_ui/src/commit_message_prompt.txt")),
 90        }
 91    }
 92}
 93
 94impl From<UserPromptId> for PromptId {
 95    fn from(uuid: UserPromptId) -> Self {
 96        PromptId::User { uuid }
 97    }
 98}
 99
/// UUID identifying a user-authored prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the wrapped [`Uuid`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
103
104impl UserPromptId {
105    pub fn new() -> UserPromptId {
106        UserPromptId(Uuid::new_v4())
107    }
108}
109
110impl From<Uuid> for UserPromptId {
111    fn from(uuid: Uuid) -> Self {
112        UserPromptId(uuid)
113    }
114}
115
116impl std::fmt::Display for PromptId {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            PromptId::User { uuid } => write!(f, "{}", uuid.0),
120            PromptId::CommitMessage => write!(f, "Commit message"),
121        }
122    }
123}
124
125pub struct PromptStore {
126    env: heed::Env,
127    metadata_cache: RwLock<MetadataCache>,
128    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
129    bodies: Database<SerdeJson<PromptId>, Str>,
130}
131
/// Emitted by [`PromptStore`] after a save, metadata save, or delete has been committed
/// to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
135
136#[derive(Default)]
137struct MetadataCache {
138    metadata: Vec<PromptMetadata>,
139    metadata_by_id: HashMap<PromptId, PromptMetadata>,
140}
141
142impl MetadataCache {
143    fn from_db(
144        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
145        txn: &RoTxn,
146    ) -> Result<Self> {
147        let mut cache = MetadataCache::default();
148        for result in db.iter(txn)? {
149            let (prompt_id, metadata) = result?;
150            cache.metadata.push(metadata.clone());
151            cache.metadata_by_id.insert(prompt_id, metadata);
152        }
153        cache.sort();
154        Ok(cache)
155    }
156
157    fn insert(&mut self, metadata: PromptMetadata) {
158        self.metadata_by_id.insert(metadata.id, metadata.clone());
159        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
160            *old_metadata = metadata;
161        } else {
162            self.metadata.push(metadata);
163        }
164        self.sort();
165    }
166
167    fn remove(&mut self, id: PromptId) {
168        self.metadata.retain(|metadata| metadata.id != id);
169        self.metadata_by_id.remove(&id);
170    }
171
172    fn sort(&mut self) {
173        self.metadata.sort_unstable_by(|a, b| {
174            a.title
175                .cmp(&b.title)
176                .then_with(|| b.saved_at.cmp(&a.saved_at))
177        });
178    }
179}
180
181impl PromptStore {
182    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
183        let store = GlobalPromptStore::global(cx).0.clone();
184        async move { store.await.map_err(|err| anyhow!(err)) }
185    }
186
187    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
188        cx.background_spawn(async move {
189            std::fs::create_dir_all(&db_path)?;
190
191            let db_env = unsafe {
192                heed::EnvOpenOptions::new()
193                    .map_size(1024 * 1024 * 1024) // 1GB
194                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
195                    .open(db_path)?
196            };
197
198            let mut txn = db_env.write_txn()?;
199            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
200            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
201
202            metadata.delete(&mut txn, &PromptId::CommitMessage)?;
203            bodies.delete(&mut txn, &PromptId::CommitMessage)?;
204
205            txn.commit()?;
206
207            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
208
209            let txn = db_env.read_txn()?;
210            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
211            txn.commit()?;
212
213            Ok(PromptStore {
214                env: db_env,
215                metadata_cache: RwLock::new(metadata_cache),
216                metadata,
217                bodies,
218            })
219        })
220    }
221
222    fn upgrade_dbs(
223        env: &heed::Env,
224        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
225        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
226    ) -> Result<()> {
227        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
228        pub struct PromptIdV1(Uuid);
229
230        #[derive(Clone, Debug, Serialize, Deserialize)]
231        pub struct PromptMetadataV1 {
232            pub id: PromptIdV1,
233            pub title: Option<SharedString>,
234            pub default: bool,
235            pub saved_at: DateTime<Utc>,
236        }
237
238        let mut txn = env.write_txn()?;
239        let Some(bodies_v1_db) = env
240            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
241                &txn,
242                Some("bodies"),
243            )?
244        else {
245            return Ok(());
246        };
247        let mut bodies_v1 = bodies_v1_db
248            .iter(&txn)?
249            .collect::<heed::Result<HashMap<_, _>>>()?;
250
251        let Some(metadata_v1_db) = env
252            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
253                &txn,
254                Some("metadata"),
255            )?
256        else {
257            return Ok(());
258        };
259        let metadata_v1 = metadata_v1_db
260            .iter(&txn)?
261            .collect::<heed::Result<HashMap<_, _>>>()?;
262
263        for (prompt_id_v1, metadata_v1) in metadata_v1 {
264            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
265            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
266                continue;
267            };
268
269            if metadata_db
270                .get(&txn, &prompt_id_v2)?
271                .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
272            {
273                metadata_db.put(
274                    &mut txn,
275                    &prompt_id_v2,
276                    &PromptMetadata {
277                        id: prompt_id_v2,
278                        title: metadata_v1.title.clone(),
279                        default: metadata_v1.default,
280                        saved_at: metadata_v1.saved_at,
281                    },
282                )?;
283                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
284            }
285        }
286
287        txn.commit()?;
288
289        Ok(())
290    }
291
292    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
293        let env = self.env.clone();
294        let bodies = self.bodies;
295        cx.background_spawn(async move {
296            let txn = env.read_txn()?;
297            let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
298            LineEnding::normalize(&mut prompt);
299            Ok(prompt)
300        })
301    }
302
303    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
304        self.metadata_cache.read().metadata.clone()
305    }
306
307    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
308        return self
309            .metadata_cache
310            .read()
311            .metadata
312            .iter()
313            .filter(|metadata| metadata.default)
314            .cloned()
315            .collect::<Vec<_>>();
316    }
317
318    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
319        self.metadata_cache.write().remove(id);
320
321        let db_connection = self.env.clone();
322        let bodies = self.bodies;
323        let metadata = self.metadata;
324
325        let task = cx.background_spawn(async move {
326            let mut txn = db_connection.write_txn()?;
327
328            metadata.delete(&mut txn, &id)?;
329            bodies.delete(&mut txn, &id)?;
330
331            txn.commit()?;
332            anyhow::Ok(())
333        });
334
335        cx.spawn(async move |this, cx| {
336            task.await?;
337            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
338            anyhow::Ok(())
339        })
340    }
341
342    /// Returns the number of prompts in the store.
343    pub fn prompt_count(&self) -> usize {
344        self.metadata_cache.read().metadata.len()
345    }
346
347    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
348        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
349    }
350
351    pub fn first(&self) -> Option<PromptMetadata> {
352        self.metadata_cache.read().metadata.first().cloned()
353    }
354
355    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
356        let metadata_cache = self.metadata_cache.read();
357        let metadata = metadata_cache
358            .metadata
359            .iter()
360            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
361        Some(metadata.id)
362    }
363
364    pub fn search(
365        &self,
366        query: String,
367        cancellation_flag: Arc<AtomicBool>,
368        cx: &App,
369    ) -> Task<Vec<PromptMetadata>> {
370        let cached_metadata = self.metadata_cache.read().metadata.clone();
371        let executor = cx.background_executor().clone();
372        cx.background_spawn(async move {
373            let mut matches = if query.is_empty() {
374                cached_metadata
375            } else {
376                let candidates = cached_metadata
377                    .iter()
378                    .enumerate()
379                    .filter_map(|(ix, metadata)| {
380                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
381                    })
382                    .collect::<Vec<_>>();
383                let matches = fuzzy::match_strings(
384                    &candidates,
385                    &query,
386                    false,
387                    true,
388                    100,
389                    &cancellation_flag,
390                    executor,
391                )
392                .await;
393                matches
394                    .into_iter()
395                    .map(|mat| cached_metadata[mat.candidate_id].clone())
396                    .collect()
397            };
398            matches.sort_by_key(|metadata| Reverse(metadata.default));
399            matches
400        })
401    }
402
403    pub fn save(
404        &self,
405        id: PromptId,
406        title: Option<SharedString>,
407        default: bool,
408        body: Rope,
409        cx: &Context<Self>,
410    ) -> Task<Result<()>> {
411        if !id.can_edit() {
412            return Task::ready(Err(anyhow!("this prompt cannot be edited")));
413        }
414
415        let prompt_metadata = PromptMetadata {
416            id,
417            title,
418            default,
419            saved_at: Utc::now(),
420        };
421        self.metadata_cache.write().insert(prompt_metadata.clone());
422
423        let db_connection = self.env.clone();
424        let bodies = self.bodies;
425        let metadata = self.metadata;
426
427        let task = cx.background_spawn(async move {
428            let mut txn = db_connection.write_txn()?;
429
430            metadata.put(&mut txn, &id, &prompt_metadata)?;
431            bodies.put(&mut txn, &id, &body.to_string())?;
432
433            txn.commit()?;
434
435            anyhow::Ok(())
436        });
437
438        cx.spawn(async move |this, cx| {
439            task.await?;
440            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
441            anyhow::Ok(())
442        })
443    }
444
445    pub fn save_metadata(
446        &self,
447        id: PromptId,
448        mut title: Option<SharedString>,
449        default: bool,
450        cx: &Context<Self>,
451    ) -> Task<Result<()>> {
452        let mut cache = self.metadata_cache.write();
453
454        if !id.can_edit() {
455            title = cache
456                .metadata_by_id
457                .get(&id)
458                .and_then(|metadata| metadata.title.clone());
459        }
460
461        let prompt_metadata = PromptMetadata {
462            id,
463            title,
464            default,
465            saved_at: Utc::now(),
466        };
467
468        cache.insert(prompt_metadata.clone());
469
470        let db_connection = self.env.clone();
471        let metadata = self.metadata;
472
473        let task = cx.background_spawn(async move {
474            let mut txn = db_connection.write_txn()?;
475            metadata.put(&mut txn, &id, &prompt_metadata)?;
476            txn.commit()?;
477
478            anyhow::Ok(())
479        });
480
481        cx.spawn(async move |this, cx| {
482            task.await?;
483            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
484            anyhow::Ok(())
485        })
486    }
487}
488
/// Global holding a shared future that resolves to the [`PromptStore`] entity.
///
/// Set by [`init`] and read by [`PromptStore::global`]. The error is wrapped in `Arc`
/// because the output of a `Shared` future must be `Clone`.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}