prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Metadata describing a single stored prompt.
///
/// Persisted as JSON in the `metadata.v2` database (opened in `PromptStore::new`)
/// and mirrored in memory by `MetadataCache`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Key under which this prompt's metadata and body are stored.
    pub id: PromptId,
    /// Display title; `None` for untitled prompts, which fuzzy search skips.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt
    /// (see `PromptStore::default_prompt_metadata`; search ranks these first).
    pub default: bool,
    /// Timestamp of the last save; `save`/`save_metadata` set it to the current time.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt in the store.
///
/// Serialized with serde's internal tagging (a `"kind"` field), so a user
/// prompt is stored as `{"kind":"User","uuid":...}`. This encoding is the
/// database key format, so changing it breaks existing data.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user, identified by a UUID.
    User { uuid: Uuid },
    /// The former built-in edit-workflow prompt; `PromptStore::new` deletes
    /// any persisted entry for it at startup.
    EditWorkflow,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        PromptId::User {
 64            uuid: Uuid::new_v4(),
 65        }
 66    }
 67
 68    pub fn is_built_in(&self) -> bool {
 69        !matches!(self, PromptId::User { .. })
 70    }
 71}
 72
/// Persistent prompt library backed by a `heed` (LMDB) environment, with an
/// in-memory metadata cache that answers queries synchronously.
pub struct PromptStore {
    /// The database environment; cloned into background tasks to open transactions.
    env: heed::Env,
    /// Copy of all entries in `metadata`; updated synchronously by `save`,
    /// `save_metadata` and `delete` before their database writes run.
    metadata_cache: RwLock<MetadataCache>,
    /// The `metadata.v2` database: JSON-encoded metadata keyed by prompt id.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// The `bodies.v2` database: prompt text keyed by prompt id.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
 79
/// Emitted by [`PromptStore`] after `save`, `save_metadata` or `delete` has
/// committed its change to the database (and the store entity is still alive).
pub struct PromptsUpdatedEvent;

// Lets `PromptStore` entities emit `PromptsUpdatedEvent` to subscribers.
impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
 83
/// In-memory mirror of the prompt metadata database.
///
/// Both collections hold the same entries, in different shapes.
#[derive(Default)]
struct MetadataCache {
    /// All entries, ordered by title and then most recent `saved_at`
    /// (maintained by `MetadataCache::sort`).
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
 89
 90impl MetadataCache {
 91    fn from_db(
 92        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
 93        txn: &RoTxn,
 94    ) -> Result<Self> {
 95        let mut cache = MetadataCache::default();
 96        for result in db.iter(txn)? {
 97            let (prompt_id, metadata) = result?;
 98            cache.metadata.push(metadata.clone());
 99            cache.metadata_by_id.insert(prompt_id, metadata);
100        }
101        cache.sort();
102        Ok(cache)
103    }
104
105    fn insert(&mut self, metadata: PromptMetadata) {
106        self.metadata_by_id.insert(metadata.id, metadata.clone());
107        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
108            *old_metadata = metadata;
109        } else {
110            self.metadata.push(metadata);
111        }
112        self.sort();
113    }
114
115    fn remove(&mut self, id: PromptId) {
116        self.metadata.retain(|metadata| metadata.id != id);
117        self.metadata_by_id.remove(&id);
118    }
119
120    fn sort(&mut self) {
121        self.metadata.sort_unstable_by(|a, b| {
122            a.title
123                .cmp(&b.title)
124                .then_with(|| b.saved_at.cmp(&a.saved_at))
125        });
126    }
127}
128
129impl PromptStore {
130    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
131        let store = GlobalPromptStore::global(cx).0.clone();
132        async move { store.await.map_err(|err| anyhow!(err)) }
133    }
134
135    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
136        cx.background_spawn(async move {
137            std::fs::create_dir_all(&db_path)?;
138
139            let db_env = unsafe {
140                heed::EnvOpenOptions::new()
141                    .map_size(1024 * 1024 * 1024) // 1GB
142                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
143                    .open(db_path)?
144            };
145
146            let mut txn = db_env.write_txn()?;
147            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
148            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
149
150            // Remove edit workflow prompt, as we decided to opt into it using
151            // a slash command instead.
152            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
153            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
154
155            txn.commit()?;
156
157            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
158
159            let txn = db_env.read_txn()?;
160            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
161            txn.commit()?;
162
163            Ok(PromptStore {
164                env: db_env,
165                metadata_cache: RwLock::new(metadata_cache),
166                metadata,
167                bodies,
168            })
169        })
170    }
171
172    fn upgrade_dbs(
173        env: &heed::Env,
174        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
175        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
176    ) -> Result<()> {
177        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
178        pub struct PromptIdV1(Uuid);
179
180        #[derive(Clone, Debug, Serialize, Deserialize)]
181        pub struct PromptMetadataV1 {
182            pub id: PromptIdV1,
183            pub title: Option<SharedString>,
184            pub default: bool,
185            pub saved_at: DateTime<Utc>,
186        }
187
188        let mut txn = env.write_txn()?;
189        let Some(bodies_v1_db) = env
190            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
191                &txn,
192                Some("bodies"),
193            )?
194        else {
195            return Ok(());
196        };
197        let mut bodies_v1 = bodies_v1_db
198            .iter(&txn)?
199            .collect::<heed::Result<HashMap<_, _>>>()?;
200
201        let Some(metadata_v1_db) = env
202            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
203                &txn,
204                Some("metadata"),
205            )?
206        else {
207            return Ok(());
208        };
209        let metadata_v1 = metadata_v1_db
210            .iter(&txn)?
211            .collect::<heed::Result<HashMap<_, _>>>()?;
212
213        for (prompt_id_v1, metadata_v1) in metadata_v1 {
214            let prompt_id_v2 = PromptId::User {
215                uuid: prompt_id_v1.0,
216            };
217            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
218                continue;
219            };
220
221            if metadata_db
222                .get(&txn, &prompt_id_v2)?
223                .map_or(true, |metadata_v2| {
224                    metadata_v1.saved_at > metadata_v2.saved_at
225                })
226            {
227                metadata_db.put(
228                    &mut txn,
229                    &prompt_id_v2,
230                    &PromptMetadata {
231                        id: prompt_id_v2,
232                        title: metadata_v1.title.clone(),
233                        default: metadata_v1.default,
234                        saved_at: metadata_v1.saved_at,
235                    },
236                )?;
237                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
238            }
239        }
240
241        txn.commit()?;
242
243        Ok(())
244    }
245
246    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
247        let env = self.env.clone();
248        let bodies = self.bodies;
249        cx.background_spawn(async move {
250            let txn = env.read_txn()?;
251            let mut prompt = bodies
252                .get(&txn, &id)?
253                .ok_or_else(|| anyhow!("prompt not found"))?
254                .into();
255            LineEnding::normalize(&mut prompt);
256            Ok(prompt)
257        })
258    }
259
260    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
261        return self
262            .metadata_cache
263            .read()
264            .metadata
265            .iter()
266            .filter(|metadata| metadata.default)
267            .cloned()
268            .collect::<Vec<_>>();
269    }
270
271    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
272        self.metadata_cache.write().remove(id);
273
274        let db_connection = self.env.clone();
275        let bodies = self.bodies;
276        let metadata = self.metadata;
277
278        let task = cx.background_spawn(async move {
279            let mut txn = db_connection.write_txn()?;
280
281            metadata.delete(&mut txn, &id)?;
282            bodies.delete(&mut txn, &id)?;
283
284            txn.commit()?;
285            anyhow::Ok(())
286        });
287
288        cx.spawn(async move |this, cx| {
289            task.await?;
290            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
291            anyhow::Ok(())
292        })
293    }
294
295    /// Returns the number of prompts in the store.
296    pub fn prompt_count(&self) -> usize {
297        self.metadata_cache.read().metadata.len()
298    }
299
300    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
301        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
302    }
303
304    pub fn first(&self) -> Option<PromptMetadata> {
305        self.metadata_cache.read().metadata.first().cloned()
306    }
307
308    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
309        let metadata_cache = self.metadata_cache.read();
310        let metadata = metadata_cache
311            .metadata
312            .iter()
313            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
314        Some(metadata.id)
315    }
316
317    pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> {
318        let cached_metadata = self.metadata_cache.read().metadata.clone();
319        let executor = cx.background_executor().clone();
320        cx.background_spawn(async move {
321            let mut matches = if query.is_empty() {
322                cached_metadata
323            } else {
324                let candidates = cached_metadata
325                    .iter()
326                    .enumerate()
327                    .filter_map(|(ix, metadata)| {
328                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
329                    })
330                    .collect::<Vec<_>>();
331                let matches = fuzzy::match_strings(
332                    &candidates,
333                    &query,
334                    false,
335                    100,
336                    &AtomicBool::default(),
337                    executor,
338                )
339                .await;
340                matches
341                    .into_iter()
342                    .map(|mat| cached_metadata[mat.candidate_id].clone())
343                    .collect()
344            };
345            matches.sort_by_key(|metadata| Reverse(metadata.default));
346            matches
347        })
348    }
349
350    pub fn save(
351        &self,
352        id: PromptId,
353        title: Option<SharedString>,
354        default: bool,
355        body: Rope,
356        cx: &Context<Self>,
357    ) -> Task<Result<()>> {
358        if id.is_built_in() {
359            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
360        }
361
362        let prompt_metadata = PromptMetadata {
363            id,
364            title,
365            default,
366            saved_at: Utc::now(),
367        };
368        self.metadata_cache.write().insert(prompt_metadata.clone());
369
370        let db_connection = self.env.clone();
371        let bodies = self.bodies;
372        let metadata = self.metadata;
373
374        let task = cx.background_spawn(async move {
375            let mut txn = db_connection.write_txn()?;
376
377            metadata.put(&mut txn, &id, &prompt_metadata)?;
378            bodies.put(&mut txn, &id, &body.to_string())?;
379
380            txn.commit()?;
381
382            anyhow::Ok(())
383        });
384
385        cx.spawn(async move |this, cx| {
386            task.await?;
387            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
388            anyhow::Ok(())
389        })
390    }
391
392    pub fn save_metadata(
393        &self,
394        id: PromptId,
395        mut title: Option<SharedString>,
396        default: bool,
397        cx: &Context<Self>,
398    ) -> Task<Result<()>> {
399        let mut cache = self.metadata_cache.write();
400
401        if id.is_built_in() {
402            title = cache
403                .metadata_by_id
404                .get(&id)
405                .and_then(|metadata| metadata.title.clone());
406        }
407
408        let prompt_metadata = PromptMetadata {
409            id,
410            title,
411            default,
412            saved_at: Utc::now(),
413        };
414
415        cache.insert(prompt_metadata.clone());
416
417        let db_connection = self.env.clone();
418        let metadata = self.metadata;
419
420        let task = cx.background_spawn(async move {
421            let mut txn = db_connection.write_txn()?;
422            metadata.put(&mut txn, &id, &prompt_metadata)?;
423            txn.commit()?;
424
425            anyhow::Ok(())
426        });
427
428        cx.spawn(async move |this, cx| {
429            task.await?;
430            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
431            anyhow::Ok(())
432        })
433    }
434}
435
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// The future is created by [`init`]. It resolves to the loaded [`PromptStore`]
/// entity, or to the load error wrapped in an `Arc` so that the output is
/// `Clone`, as `Shared` requires.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

// Registers `GlobalPromptStore` as a gpui global: `init` sets it and
// `PromptStore::global` reads it.
impl Global for GlobalPromptStore {}