prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Summary information about a stored prompt.
///
/// Entries are persisted in the `metadata.v2` database (JSON-encoded via
/// `SerdeJson` in [`PromptStore`]) and mirrored in memory by `MetadataCache`.
/// Because this type is serialized with serde, renaming or removing fields
/// affects previously persisted data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt; the store also uses it as the database key.
    pub id: PromptId,
    /// Display title; `None` for untitled prompts.
    pub title: Option<SharedString>,
    /// Whether the prompt is marked as default. `PromptStore::default_prompt_metadata`
    /// returns exactly the entries with this flag set.
    pub default: bool,
    /// Time of the last save; `PromptStore::save` and `PromptStore::save_metadata`
    /// set it to `Utc::now()`.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt in the store.
///
/// Serialized as an internally tagged enum (`#[serde(tag = "kind")]`) and used
/// as the key of both the metadata and bodies databases. Variant and field
/// names are therefore part of the persisted format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// The built-in edit-workflow prompt. `PromptStore::new` deletes any stored
    /// entry for it when the store is opened.
    EditWorkflow,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        UserPromptId::new().into()
 64    }
 65
 66    pub fn is_built_in(&self) -> bool {
 67        !matches!(self, PromptId::User { .. })
 68    }
 69}
 70
 71impl From<UserPromptId> for PromptId {
 72    fn from(uuid: UserPromptId) -> Self {
 73        PromptId::User { uuid }
 74    }
 75}
 76
/// Identifier of a user-created prompt: a newtype around a [`Uuid`].
///
/// `#[serde(transparent)]` makes it serialize exactly like the wrapped UUID.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
 80
 81impl UserPromptId {
 82    pub fn new() -> UserPromptId {
 83        UserPromptId(Uuid::new_v4())
 84    }
 85}
 86
 87impl From<Uuid> for UserPromptId {
 88    fn from(uuid: Uuid) -> Self {
 89        UserPromptId(uuid)
 90    }
 91}
 92
/// Persistent prompt library backed by a `heed` (LMDB) environment.
///
/// Metadata for every prompt is also cached in memory, so listing, lookup and
/// search read the cache rather than the database.
pub struct PromptStore {
    /// The opened database environment; cloned into background tasks.
    env: heed::Env,
    /// In-memory copy of all prompt metadata, kept sorted.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: prompt id -> metadata (both JSON-encoded).
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: prompt id (JSON-encoded) -> prompt text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
 99
/// Emitted by [`PromptStore`] after a save, metadata update, or deletion has
/// been committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
103
/// In-memory mirror of the prompt metadata database.
#[derive(Default)]
struct MetadataCache {
    /// All metadata, sorted by title (ascending) and then by `saved_at`
    /// (newest first); see `MetadataCache::sort`.
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id, for lookup by id.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
109
110impl MetadataCache {
111    fn from_db(
112        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
113        txn: &RoTxn,
114    ) -> Result<Self> {
115        let mut cache = MetadataCache::default();
116        for result in db.iter(txn)? {
117            let (prompt_id, metadata) = result?;
118            cache.metadata.push(metadata.clone());
119            cache.metadata_by_id.insert(prompt_id, metadata);
120        }
121        cache.sort();
122        Ok(cache)
123    }
124
125    fn insert(&mut self, metadata: PromptMetadata) {
126        self.metadata_by_id.insert(metadata.id, metadata.clone());
127        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
128            *old_metadata = metadata;
129        } else {
130            self.metadata.push(metadata);
131        }
132        self.sort();
133    }
134
135    fn remove(&mut self, id: PromptId) {
136        self.metadata.retain(|metadata| metadata.id != id);
137        self.metadata_by_id.remove(&id);
138    }
139
140    fn sort(&mut self) {
141        self.metadata.sort_unstable_by(|a, b| {
142            a.title
143                .cmp(&b.title)
144                .then_with(|| b.saved_at.cmp(&a.saved_at))
145        });
146    }
147}
148
149impl PromptStore {
150    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
151        let store = GlobalPromptStore::global(cx).0.clone();
152        async move { store.await.map_err(|err| anyhow!(err)) }
153    }
154
155    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
156        cx.background_spawn(async move {
157            std::fs::create_dir_all(&db_path)?;
158
159            let db_env = unsafe {
160                heed::EnvOpenOptions::new()
161                    .map_size(1024 * 1024 * 1024) // 1GB
162                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
163                    .open(db_path)?
164            };
165
166            let mut txn = db_env.write_txn()?;
167            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
168            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
169
170            // Remove edit workflow prompt, as we decided to opt into it using
171            // a slash command instead.
172            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
173            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
174
175            txn.commit()?;
176
177            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
178
179            let txn = db_env.read_txn()?;
180            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
181            txn.commit()?;
182
183            Ok(PromptStore {
184                env: db_env,
185                metadata_cache: RwLock::new(metadata_cache),
186                metadata,
187                bodies,
188            })
189        })
190    }
191
192    fn upgrade_dbs(
193        env: &heed::Env,
194        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
195        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
196    ) -> Result<()> {
197        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
198        pub struct PromptIdV1(Uuid);
199
200        #[derive(Clone, Debug, Serialize, Deserialize)]
201        pub struct PromptMetadataV1 {
202            pub id: PromptIdV1,
203            pub title: Option<SharedString>,
204            pub default: bool,
205            pub saved_at: DateTime<Utc>,
206        }
207
208        let mut txn = env.write_txn()?;
209        let Some(bodies_v1_db) = env
210            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
211                &txn,
212                Some("bodies"),
213            )?
214        else {
215            return Ok(());
216        };
217        let mut bodies_v1 = bodies_v1_db
218            .iter(&txn)?
219            .collect::<heed::Result<HashMap<_, _>>>()?;
220
221        let Some(metadata_v1_db) = env
222            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
223                &txn,
224                Some("metadata"),
225            )?
226        else {
227            return Ok(());
228        };
229        let metadata_v1 = metadata_v1_db
230            .iter(&txn)?
231            .collect::<heed::Result<HashMap<_, _>>>()?;
232
233        for (prompt_id_v1, metadata_v1) in metadata_v1 {
234            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
235            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
236                continue;
237            };
238
239            if metadata_db
240                .get(&txn, &prompt_id_v2)?
241                .map_or(true, |metadata_v2| {
242                    metadata_v1.saved_at > metadata_v2.saved_at
243                })
244            {
245                metadata_db.put(
246                    &mut txn,
247                    &prompt_id_v2,
248                    &PromptMetadata {
249                        id: prompt_id_v2,
250                        title: metadata_v1.title.clone(),
251                        default: metadata_v1.default,
252                        saved_at: metadata_v1.saved_at,
253                    },
254                )?;
255                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
256            }
257        }
258
259        txn.commit()?;
260
261        Ok(())
262    }
263
264    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
265        let env = self.env.clone();
266        let bodies = self.bodies;
267        cx.background_spawn(async move {
268            let txn = env.read_txn()?;
269            let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
270            LineEnding::normalize(&mut prompt);
271            Ok(prompt)
272        })
273    }
274
275    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
276        self.metadata_cache.read().metadata.clone()
277    }
278
279    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
280        return self
281            .metadata_cache
282            .read()
283            .metadata
284            .iter()
285            .filter(|metadata| metadata.default)
286            .cloned()
287            .collect::<Vec<_>>();
288    }
289
290    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
291        self.metadata_cache.write().remove(id);
292
293        let db_connection = self.env.clone();
294        let bodies = self.bodies;
295        let metadata = self.metadata;
296
297        let task = cx.background_spawn(async move {
298            let mut txn = db_connection.write_txn()?;
299
300            metadata.delete(&mut txn, &id)?;
301            bodies.delete(&mut txn, &id)?;
302
303            txn.commit()?;
304            anyhow::Ok(())
305        });
306
307        cx.spawn(async move |this, cx| {
308            task.await?;
309            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
310            anyhow::Ok(())
311        })
312    }
313
314    /// Returns the number of prompts in the store.
315    pub fn prompt_count(&self) -> usize {
316        self.metadata_cache.read().metadata.len()
317    }
318
319    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
320        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
321    }
322
323    pub fn first(&self) -> Option<PromptMetadata> {
324        self.metadata_cache.read().metadata.first().cloned()
325    }
326
327    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
328        let metadata_cache = self.metadata_cache.read();
329        let metadata = metadata_cache
330            .metadata
331            .iter()
332            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
333        Some(metadata.id)
334    }
335
336    pub fn search(
337        &self,
338        query: String,
339        cancellation_flag: Arc<AtomicBool>,
340        cx: &App,
341    ) -> Task<Vec<PromptMetadata>> {
342        let cached_metadata = self.metadata_cache.read().metadata.clone();
343        let executor = cx.background_executor().clone();
344        cx.background_spawn(async move {
345            let mut matches = if query.is_empty() {
346                cached_metadata
347            } else {
348                let candidates = cached_metadata
349                    .iter()
350                    .enumerate()
351                    .filter_map(|(ix, metadata)| {
352                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
353                    })
354                    .collect::<Vec<_>>();
355                let matches = fuzzy::match_strings(
356                    &candidates,
357                    &query,
358                    false,
359                    true,
360                    100,
361                    &cancellation_flag,
362                    executor,
363                )
364                .await;
365                matches
366                    .into_iter()
367                    .map(|mat| cached_metadata[mat.candidate_id].clone())
368                    .collect()
369            };
370            matches.sort_by_key(|metadata| Reverse(metadata.default));
371            matches
372        })
373    }
374
375    pub fn save(
376        &self,
377        id: PromptId,
378        title: Option<SharedString>,
379        default: bool,
380        body: Rope,
381        cx: &Context<Self>,
382    ) -> Task<Result<()>> {
383        if id.is_built_in() {
384            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
385        }
386
387        let prompt_metadata = PromptMetadata {
388            id,
389            title,
390            default,
391            saved_at: Utc::now(),
392        };
393        self.metadata_cache.write().insert(prompt_metadata.clone());
394
395        let db_connection = self.env.clone();
396        let bodies = self.bodies;
397        let metadata = self.metadata;
398
399        let task = cx.background_spawn(async move {
400            let mut txn = db_connection.write_txn()?;
401
402            metadata.put(&mut txn, &id, &prompt_metadata)?;
403            bodies.put(&mut txn, &id, &body.to_string())?;
404
405            txn.commit()?;
406
407            anyhow::Ok(())
408        });
409
410        cx.spawn(async move |this, cx| {
411            task.await?;
412            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
413            anyhow::Ok(())
414        })
415    }
416
417    pub fn save_metadata(
418        &self,
419        id: PromptId,
420        mut title: Option<SharedString>,
421        default: bool,
422        cx: &Context<Self>,
423    ) -> Task<Result<()>> {
424        let mut cache = self.metadata_cache.write();
425
426        if id.is_built_in() {
427            title = cache
428                .metadata_by_id
429                .get(&id)
430                .and_then(|metadata| metadata.title.clone());
431        }
432
433        let prompt_metadata = PromptMetadata {
434            id,
435            title,
436            default,
437            saved_at: Utc::now(),
438        };
439
440        cache.insert(prompt_metadata.clone());
441
442        let db_connection = self.env.clone();
443        let metadata = self.metadata;
444
445        let task = cx.background_spawn(async move {
446            let mut txn = db_connection.write_txn()?;
447            metadata.put(&mut txn, &id, &prompt_metadata)?;
448            txn.commit()?;
449
450            anyhow::Ok(())
451        });
452
453        cx.spawn(async move |this, cx| {
454            task.await?;
455            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
456            anyhow::Ok(())
457        })
458    }
459}
460
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// `init` creates the future. It resolves to the store entity, or to the load
/// error wrapped in an `Arc`, so the shared future's output can be cloned for
/// every awaiter.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}