prompt_store.rs

  1mod prompts;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::{self, BoxFuture, Shared};
  8use fuzzy::StringMatchCandidate;
  9use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task};
 10use heed::{
 11    Database, RoTxn,
 12    types::{SerdeBincode, SerdeJson, Str},
 13};
 14use parking_lot::RwLock;
 15pub use prompts::*;
 16use rope::Rope;
 17use serde::{Deserialize, Serialize};
 18use std::{
 19    cmp::Reverse,
 20    future::Future,
 21    path::PathBuf,
 22    sync::{Arc, atomic::AtomicBool},
 23};
 24use text::LineEnding;
 25use util::ResultExt;
 26use uuid::Uuid;
 27
 28/// Init starts loading the PromptStore in the background and assigns
 29/// a shared future to a global.
 30pub fn init(cx: &mut App) {
 31    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 32    let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone())
 33        .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 34        .boxed()
 35        .shared();
 36    cx.set_global(GlobalPromptStore(prompt_store_future))
 37}
 38
/// Metadata describing a stored prompt.
///
/// Persisted as JSON (`SerdeJson`) in the `metadata.v2` database, keyed by `id`, and
/// mirrored in memory by the store's metadata cache.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt; `PromptStore::save` also uses it as the database key.
    pub id: PromptId,
    /// Display title. `None` means untitled; untitled prompts are never returned by a
    /// non-empty `PromptStore::search` query.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as default (see `PromptStore::default_prompt_metadata`).
    pub default: bool,
    /// Time of the last save. It is used as the secondary sort key and to resolve v1→v2
    /// migration conflicts.
    pub saved_at: DateTime<Utc>,
}
 46
/// Identifies a prompt in the store.
///
/// Serialized by serde with an internal `kind` tag. That JSON encoding is the key type of
/// both the `metadata.v2` and `bodies.v2` databases, so renaming variants or fields
/// changes the on-disk keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A user-authored prompt, identified by a UUID (random v4 when created via `PromptId::new`).
    User { uuid: Uuid },
    /// Legacy edit-workflow prompt; `PromptStore::new` deletes any stored copy of it.
    EditWorkflow,
}
 53
 54impl PromptId {
 55    pub fn new() -> PromptId {
 56        PromptId::User {
 57            uuid: Uuid::new_v4(),
 58        }
 59    }
 60
 61    pub fn is_built_in(&self) -> bool {
 62        !matches!(self, PromptId::User { .. })
 63    }
 64}
 65
/// Persistent prompt library backed by an LMDB environment (via `heed`).
///
/// An in-memory metadata cache serves synchronous lookups; database work runs on the
/// background executor.
pub struct PromptStore {
    /// Executor on which database reads/writes and searches are spawned.
    executor: BackgroundExecutor,
    /// LMDB environment opened in `PromptStore::new`.
    env: heed::Env,
    /// In-memory copy of the metadata table. Mutating methods update it eagerly,
    /// before their database write is committed.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: JSON prompt id -> JSON metadata.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: JSON prompt id -> prompt text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
 73
 74#[derive(Default)]
 75struct MetadataCache {
 76    metadata: Vec<PromptMetadata>,
 77    metadata_by_id: HashMap<PromptId, PromptMetadata>,
 78}
 79
 80impl MetadataCache {
 81    fn from_db(
 82        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
 83        txn: &RoTxn,
 84    ) -> Result<Self> {
 85        let mut cache = MetadataCache::default();
 86        for result in db.iter(txn)? {
 87            let (prompt_id, metadata) = result?;
 88            cache.metadata.push(metadata.clone());
 89            cache.metadata_by_id.insert(prompt_id, metadata);
 90        }
 91        cache.sort();
 92        Ok(cache)
 93    }
 94
 95    fn insert(&mut self, metadata: PromptMetadata) {
 96        self.metadata_by_id.insert(metadata.id, metadata.clone());
 97        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
 98            *old_metadata = metadata;
 99        } else {
100            self.metadata.push(metadata);
101        }
102        self.sort();
103    }
104
105    fn remove(&mut self, id: PromptId) {
106        self.metadata.retain(|metadata| metadata.id != id);
107        self.metadata_by_id.remove(&id);
108    }
109
110    fn sort(&mut self) {
111        self.metadata.sort_unstable_by(|a, b| {
112            a.title
113                .cmp(&b.title)
114                .then_with(|| b.saved_at.cmp(&a.saved_at))
115        });
116    }
117}
118
119impl PromptStore {
120    pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> + use<> {
121        let store = GlobalPromptStore::global(cx).0.clone();
122        async move { store.await.map_err(|err| anyhow!(err)) }
123    }
124
125    pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> {
126        executor.spawn({
127            let executor = executor.clone();
128            async move {
129                std::fs::create_dir_all(&db_path)?;
130
131                let db_env = unsafe {
132                    heed::EnvOpenOptions::new()
133                        .map_size(1024 * 1024 * 1024) // 1GB
134                        .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
135                        .open(db_path)?
136                };
137
138                let mut txn = db_env.write_txn()?;
139                let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
140                let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
141
142                // Remove edit workflow prompt, as we decided to opt into it using
143                // a slash command instead.
144                metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
145                bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
146
147                txn.commit()?;
148
149                Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
150
151                let txn = db_env.read_txn()?;
152                let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
153                txn.commit()?;
154
155                Ok(PromptStore {
156                    executor,
157                    env: db_env,
158                    metadata_cache: RwLock::new(metadata_cache),
159                    metadata,
160                    bodies,
161                })
162            }
163        })
164    }
165
166    fn upgrade_dbs(
167        env: &heed::Env,
168        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
169        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
170    ) -> Result<()> {
171        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
172        pub struct PromptIdV1(Uuid);
173
174        #[derive(Clone, Debug, Serialize, Deserialize)]
175        pub struct PromptMetadataV1 {
176            pub id: PromptIdV1,
177            pub title: Option<SharedString>,
178            pub default: bool,
179            pub saved_at: DateTime<Utc>,
180        }
181
182        let mut txn = env.write_txn()?;
183        let Some(bodies_v1_db) = env
184            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
185                &txn,
186                Some("bodies"),
187            )?
188        else {
189            return Ok(());
190        };
191        let mut bodies_v1 = bodies_v1_db
192            .iter(&txn)?
193            .collect::<heed::Result<HashMap<_, _>>>()?;
194
195        let Some(metadata_v1_db) = env
196            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
197                &txn,
198                Some("metadata"),
199            )?
200        else {
201            return Ok(());
202        };
203        let metadata_v1 = metadata_v1_db
204            .iter(&txn)?
205            .collect::<heed::Result<HashMap<_, _>>>()?;
206
207        for (prompt_id_v1, metadata_v1) in metadata_v1 {
208            let prompt_id_v2 = PromptId::User {
209                uuid: prompt_id_v1.0,
210            };
211            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
212                continue;
213            };
214
215            if metadata_db
216                .get(&txn, &prompt_id_v2)?
217                .map_or(true, |metadata_v2| {
218                    metadata_v1.saved_at > metadata_v2.saved_at
219                })
220            {
221                metadata_db.put(
222                    &mut txn,
223                    &prompt_id_v2,
224                    &PromptMetadata {
225                        id: prompt_id_v2,
226                        title: metadata_v1.title.clone(),
227                        default: metadata_v1.default,
228                        saved_at: metadata_v1.saved_at,
229                    },
230                )?;
231                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
232            }
233        }
234
235        txn.commit()?;
236
237        Ok(())
238    }
239
240    pub fn load(&self, id: PromptId) -> Task<Result<String>> {
241        let env = self.env.clone();
242        let bodies = self.bodies;
243        self.executor.spawn(async move {
244            let txn = env.read_txn()?;
245            let mut prompt = bodies
246                .get(&txn, &id)?
247                .ok_or_else(|| anyhow!("prompt not found"))?
248                .into();
249            LineEnding::normalize(&mut prompt);
250            Ok(prompt)
251        })
252    }
253
254    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
255        return self
256            .metadata_cache
257            .read()
258            .metadata
259            .iter()
260            .filter(|metadata| metadata.default)
261            .cloned()
262            .collect::<Vec<_>>();
263    }
264
265    pub fn delete(&self, id: PromptId) -> Task<Result<()>> {
266        self.metadata_cache.write().remove(id);
267
268        let db_connection = self.env.clone();
269        let bodies = self.bodies;
270        let metadata = self.metadata;
271
272        self.executor.spawn(async move {
273            let mut txn = db_connection.write_txn()?;
274
275            metadata.delete(&mut txn, &id)?;
276            bodies.delete(&mut txn, &id)?;
277
278            txn.commit()?;
279            Ok(())
280        })
281    }
282
283    /// Returns the number of prompts in the store.
284    pub fn prompt_count(&self) -> usize {
285        self.metadata_cache.read().metadata.len()
286    }
287
288    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
289        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
290    }
291
292    pub fn first(&self) -> Option<PromptMetadata> {
293        self.metadata_cache.read().metadata.first().cloned()
294    }
295
296    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
297        let metadata_cache = self.metadata_cache.read();
298        let metadata = metadata_cache
299            .metadata
300            .iter()
301            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
302        Some(metadata.id)
303    }
304
305    pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> {
306        let cached_metadata = self.metadata_cache.read().metadata.clone();
307        let executor = self.executor.clone();
308        self.executor.spawn(async move {
309            let mut matches = if query.is_empty() {
310                cached_metadata
311            } else {
312                let candidates = cached_metadata
313                    .iter()
314                    .enumerate()
315                    .filter_map(|(ix, metadata)| {
316                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
317                    })
318                    .collect::<Vec<_>>();
319                let matches = fuzzy::match_strings(
320                    &candidates,
321                    &query,
322                    false,
323                    100,
324                    &AtomicBool::default(),
325                    executor,
326                )
327                .await;
328                matches
329                    .into_iter()
330                    .map(|mat| cached_metadata[mat.candidate_id].clone())
331                    .collect()
332            };
333            matches.sort_by_key(|metadata| Reverse(metadata.default));
334            matches
335        })
336    }
337
338    pub fn save(
339        &self,
340        id: PromptId,
341        title: Option<SharedString>,
342        default: bool,
343        body: Rope,
344    ) -> Task<Result<()>> {
345        if id.is_built_in() {
346            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
347        }
348
349        let prompt_metadata = PromptMetadata {
350            id,
351            title,
352            default,
353            saved_at: Utc::now(),
354        };
355        self.metadata_cache.write().insert(prompt_metadata.clone());
356
357        let db_connection = self.env.clone();
358        let bodies = self.bodies;
359        let metadata = self.metadata;
360
361        self.executor.spawn(async move {
362            let mut txn = db_connection.write_txn()?;
363
364            metadata.put(&mut txn, &id, &prompt_metadata)?;
365            bodies.put(&mut txn, &id, &body.to_string())?;
366
367            txn.commit()?;
368
369            Ok(())
370        })
371    }
372
373    pub fn save_metadata(
374        &self,
375        id: PromptId,
376        mut title: Option<SharedString>,
377        default: bool,
378    ) -> Task<Result<()>> {
379        let mut cache = self.metadata_cache.write();
380
381        if id.is_built_in() {
382            title = cache
383                .metadata_by_id
384                .get(&id)
385                .and_then(|metadata| metadata.title.clone());
386        }
387
388        let prompt_metadata = PromptMetadata {
389            id,
390            title,
391            default,
392            saved_at: Utc::now(),
393        };
394
395        cache.insert(prompt_metadata.clone());
396
397        let db_connection = self.env.clone();
398        let metadata = self.metadata;
399
400        self.executor.spawn(async move {
401            let mut txn = db_connection.write_txn()?;
402            metadata.put(&mut txn, &id, &prompt_metadata)?;
403            txn.commit()?;
404
405            Ok(())
406        })
407    }
408}
409
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// Set by `init` and read by `PromptStore::global`. Both the success and the error value
/// are wrapped in `Arc` so the `Shared` future's output can be cloned for every awaiter.
pub struct GlobalPromptStore(
    Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
);

// Marker impl that lets the wrapper be stored via `cx.set_global` in `init`.
impl Global for GlobalPromptStore {}