prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
 46#[derive(Clone, Debug, Serialize, Deserialize)]
 47pub struct PromptMetadata {
 48    pub id: PromptId,
 49    pub title: Option<SharedString>,
 50    pub default: bool,
 51    pub saved_at: DateTime<Utc>,
 52}
 53
 54#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 55#[serde(tag = "kind")]
 56pub enum PromptId {
 57    User { uuid: UserPromptId },
 58    EditWorkflow,
 59}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        UserPromptId::new().into()
 64    }
 65
 66    pub fn is_built_in(&self) -> bool {
 67        !matches!(self, PromptId::User { .. })
 68    }
 69}
 70
 71impl From<UserPromptId> for PromptId {
 72    fn from(uuid: UserPromptId) -> Self {
 73        PromptId::User { uuid }
 74    }
 75}
 76
 77#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 78#[serde(transparent)]
 79pub struct UserPromptId(pub Uuid);
 80
 81impl UserPromptId {
 82    pub fn new() -> UserPromptId {
 83        UserPromptId(Uuid::new_v4())
 84    }
 85}
 86
 87impl From<Uuid> for UserPromptId {
 88    fn from(uuid: Uuid) -> Self {
 89        UserPromptId(uuid)
 90    }
 91}
 92
 93pub struct PromptStore {
 94    env: heed::Env,
 95    metadata_cache: RwLock<MetadataCache>,
 96    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
 97    bodies: Database<SerdeJson<PromptId>, Str>,
 98}
 99
100pub struct PromptsUpdatedEvent;
101
102impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
103
104#[derive(Default)]
105struct MetadataCache {
106    metadata: Vec<PromptMetadata>,
107    metadata_by_id: HashMap<PromptId, PromptMetadata>,
108}
109
110impl MetadataCache {
111    fn from_db(
112        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
113        txn: &RoTxn,
114    ) -> Result<Self> {
115        let mut cache = MetadataCache::default();
116        for result in db.iter(txn)? {
117            let (prompt_id, metadata) = result?;
118            cache.metadata.push(metadata.clone());
119            cache.metadata_by_id.insert(prompt_id, metadata);
120        }
121        cache.sort();
122        Ok(cache)
123    }
124
125    fn insert(&mut self, metadata: PromptMetadata) {
126        self.metadata_by_id.insert(metadata.id, metadata.clone());
127        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
128            *old_metadata = metadata;
129        } else {
130            self.metadata.push(metadata);
131        }
132        self.sort();
133    }
134
135    fn remove(&mut self, id: PromptId) {
136        self.metadata.retain(|metadata| metadata.id != id);
137        self.metadata_by_id.remove(&id);
138    }
139
140    fn sort(&mut self) {
141        self.metadata.sort_unstable_by(|a, b| {
142            a.title
143                .cmp(&b.title)
144                .then_with(|| b.saved_at.cmp(&a.saved_at))
145        });
146    }
147}
148
149impl PromptStore {
150    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
151        let store = GlobalPromptStore::global(cx).0.clone();
152        async move { store.await.map_err(|err| anyhow!(err)) }
153    }
154
155    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
156        cx.background_spawn(async move {
157            std::fs::create_dir_all(&db_path)?;
158
159            let db_env = unsafe {
160                heed::EnvOpenOptions::new()
161                    .map_size(1024 * 1024 * 1024) // 1GB
162                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
163                    .open(db_path)?
164            };
165
166            let mut txn = db_env.write_txn()?;
167            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
168            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
169
170            // Remove edit workflow prompt, as we decided to opt into it using
171            // a slash command instead.
172            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
173            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
174
175            txn.commit()?;
176
177            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
178
179            let txn = db_env.read_txn()?;
180            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
181            txn.commit()?;
182
183            Ok(PromptStore {
184                env: db_env,
185                metadata_cache: RwLock::new(metadata_cache),
186                metadata,
187                bodies,
188            })
189        })
190    }
191
192    fn upgrade_dbs(
193        env: &heed::Env,
194        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
195        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
196    ) -> Result<()> {
197        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
198        pub struct PromptIdV1(Uuid);
199
200        #[derive(Clone, Debug, Serialize, Deserialize)]
201        pub struct PromptMetadataV1 {
202            pub id: PromptIdV1,
203            pub title: Option<SharedString>,
204            pub default: bool,
205            pub saved_at: DateTime<Utc>,
206        }
207
208        let mut txn = env.write_txn()?;
209        let Some(bodies_v1_db) = env
210            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
211                &txn,
212                Some("bodies"),
213            )?
214        else {
215            return Ok(());
216        };
217        let mut bodies_v1 = bodies_v1_db
218            .iter(&txn)?
219            .collect::<heed::Result<HashMap<_, _>>>()?;
220
221        let Some(metadata_v1_db) = env
222            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
223                &txn,
224                Some("metadata"),
225            )?
226        else {
227            return Ok(());
228        };
229        let metadata_v1 = metadata_v1_db
230            .iter(&txn)?
231            .collect::<heed::Result<HashMap<_, _>>>()?;
232
233        for (prompt_id_v1, metadata_v1) in metadata_v1 {
234            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
235            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
236                continue;
237            };
238
239            if metadata_db
240                .get(&txn, &prompt_id_v2)?
241                .map_or(true, |metadata_v2| {
242                    metadata_v1.saved_at > metadata_v2.saved_at
243                })
244            {
245                metadata_db.put(
246                    &mut txn,
247                    &prompt_id_v2,
248                    &PromptMetadata {
249                        id: prompt_id_v2,
250                        title: metadata_v1.title.clone(),
251                        default: metadata_v1.default,
252                        saved_at: metadata_v1.saved_at,
253                    },
254                )?;
255                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
256            }
257        }
258
259        txn.commit()?;
260
261        Ok(())
262    }
263
264    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
265        let env = self.env.clone();
266        let bodies = self.bodies;
267        cx.background_spawn(async move {
268            let txn = env.read_txn()?;
269            let mut prompt = bodies
270                .get(&txn, &id)?
271                .ok_or_else(|| anyhow!("prompt not found"))?
272                .into();
273            LineEnding::normalize(&mut prompt);
274            Ok(prompt)
275        })
276    }
277
278    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
279        self.metadata_cache.read().metadata.clone()
280    }
281
282    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
283        return self
284            .metadata_cache
285            .read()
286            .metadata
287            .iter()
288            .filter(|metadata| metadata.default)
289            .cloned()
290            .collect::<Vec<_>>();
291    }
292
293    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
294        self.metadata_cache.write().remove(id);
295
296        let db_connection = self.env.clone();
297        let bodies = self.bodies;
298        let metadata = self.metadata;
299
300        let task = cx.background_spawn(async move {
301            let mut txn = db_connection.write_txn()?;
302
303            metadata.delete(&mut txn, &id)?;
304            bodies.delete(&mut txn, &id)?;
305
306            txn.commit()?;
307            anyhow::Ok(())
308        });
309
310        cx.spawn(async move |this, cx| {
311            task.await?;
312            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
313            anyhow::Ok(())
314        })
315    }
316
317    /// Returns the number of prompts in the store.
318    pub fn prompt_count(&self) -> usize {
319        self.metadata_cache.read().metadata.len()
320    }
321
322    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
323        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
324    }
325
326    pub fn first(&self) -> Option<PromptMetadata> {
327        self.metadata_cache.read().metadata.first().cloned()
328    }
329
330    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
331        let metadata_cache = self.metadata_cache.read();
332        let metadata = metadata_cache
333            .metadata
334            .iter()
335            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
336        Some(metadata.id)
337    }
338
339    pub fn search(
340        &self,
341        query: String,
342        cancellation_flag: Arc<AtomicBool>,
343        cx: &App,
344    ) -> Task<Vec<PromptMetadata>> {
345        let cached_metadata = self.metadata_cache.read().metadata.clone();
346        let executor = cx.background_executor().clone();
347        cx.background_spawn(async move {
348            let mut matches = if query.is_empty() {
349                cached_metadata
350            } else {
351                let candidates = cached_metadata
352                    .iter()
353                    .enumerate()
354                    .filter_map(|(ix, metadata)| {
355                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
356                    })
357                    .collect::<Vec<_>>();
358                let matches = fuzzy::match_strings(
359                    &candidates,
360                    &query,
361                    false,
362                    100,
363                    &cancellation_flag,
364                    executor,
365                )
366                .await;
367                matches
368                    .into_iter()
369                    .map(|mat| cached_metadata[mat.candidate_id].clone())
370                    .collect()
371            };
372            matches.sort_by_key(|metadata| Reverse(metadata.default));
373            matches
374        })
375    }
376
377    pub fn save(
378        &self,
379        id: PromptId,
380        title: Option<SharedString>,
381        default: bool,
382        body: Rope,
383        cx: &Context<Self>,
384    ) -> Task<Result<()>> {
385        if id.is_built_in() {
386            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
387        }
388
389        let prompt_metadata = PromptMetadata {
390            id,
391            title,
392            default,
393            saved_at: Utc::now(),
394        };
395        self.metadata_cache.write().insert(prompt_metadata.clone());
396
397        let db_connection = self.env.clone();
398        let bodies = self.bodies;
399        let metadata = self.metadata;
400
401        let task = cx.background_spawn(async move {
402            let mut txn = db_connection.write_txn()?;
403
404            metadata.put(&mut txn, &id, &prompt_metadata)?;
405            bodies.put(&mut txn, &id, &body.to_string())?;
406
407            txn.commit()?;
408
409            anyhow::Ok(())
410        });
411
412        cx.spawn(async move |this, cx| {
413            task.await?;
414            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
415            anyhow::Ok(())
416        })
417    }
418
419    pub fn save_metadata(
420        &self,
421        id: PromptId,
422        mut title: Option<SharedString>,
423        default: bool,
424        cx: &Context<Self>,
425    ) -> Task<Result<()>> {
426        let mut cache = self.metadata_cache.write();
427
428        if id.is_built_in() {
429            title = cache
430                .metadata_by_id
431                .get(&id)
432                .and_then(|metadata| metadata.title.clone());
433        }
434
435        let prompt_metadata = PromptMetadata {
436            id,
437            title,
438            default,
439            saved_at: Utc::now(),
440        };
441
442        cache.insert(prompt_metadata.clone());
443
444        let db_connection = self.env.clone();
445        let metadata = self.metadata;
446
447        let task = cx.background_spawn(async move {
448            let mut txn = db_connection.write_txn()?;
449            metadata.put(&mut txn, &id, &prompt_metadata)?;
450            txn.commit()?;
451
452            anyhow::Ok(())
453        });
454
455        cx.spawn(async move |this, cx| {
456            task.await?;
457            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
458            anyhow::Ok(())
459        })
460    }
461}
462
463/// Wraps a shared future to a prompt store so it can be assigned as a context global.
464pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
465
466impl Global for GlobalPromptStore {}