prompt_store.rs

  1use anyhow::{anyhow, Result};
  2use chrono::{DateTime, Utc};
  3use collections::HashMap;
  4use futures::future::{self, BoxFuture, Shared};
  5use futures::FutureExt as _;
  6use fuzzy::StringMatchCandidate;
  7use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task};
  8use heed::{
  9    types::{SerdeBincode, SerdeJson, Str},
 10    Database, RoTxn,
 11};
 12use parking_lot::RwLock;
 13use rope::Rope;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    cmp::Reverse,
 17    future::Future,
 18    path::PathBuf,
 19    sync::{atomic::AtomicBool, Arc},
 20};
 21use text::LineEnding;
 22use util::ResultExt;
 23use uuid::Uuid;
 24
 25/// Init starts loading the PromptStore in the background and assigns
 26/// a shared future to a global.
 27pub fn init(cx: &mut App) {
 28    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 29    let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone())
 30        .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 31        .boxed()
 32        .shared();
 33    cx.set_global(GlobalPromptStore(prompt_store_future))
 34}
 35
 36#[derive(Clone, Debug, Serialize, Deserialize)]
 37pub struct PromptMetadata {
 38    pub id: PromptId,
 39    pub title: Option<SharedString>,
 40    pub default: bool,
 41    pub saved_at: DateTime<Utc>,
 42}
 43
 44#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 45#[serde(tag = "kind")]
 46pub enum PromptId {
 47    User { uuid: Uuid },
 48    EditWorkflow,
 49}
 50
 51impl PromptId {
 52    pub fn new() -> PromptId {
 53        PromptId::User {
 54            uuid: Uuid::new_v4(),
 55        }
 56    }
 57
 58    pub fn is_built_in(&self) -> bool {
 59        !matches!(self, PromptId::User { .. })
 60    }
 61}
 62
 63pub struct PromptStore {
 64    executor: BackgroundExecutor,
 65    env: heed::Env,
 66    metadata_cache: RwLock<MetadataCache>,
 67    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
 68    bodies: Database<SerdeJson<PromptId>, Str>,
 69}
 70
 71#[derive(Default)]
 72struct MetadataCache {
 73    metadata: Vec<PromptMetadata>,
 74    metadata_by_id: HashMap<PromptId, PromptMetadata>,
 75}
 76
 77impl MetadataCache {
 78    fn from_db(
 79        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
 80        txn: &RoTxn,
 81    ) -> Result<Self> {
 82        let mut cache = MetadataCache::default();
 83        for result in db.iter(txn)? {
 84            let (prompt_id, metadata) = result?;
 85            cache.metadata.push(metadata.clone());
 86            cache.metadata_by_id.insert(prompt_id, metadata);
 87        }
 88        cache.sort();
 89        Ok(cache)
 90    }
 91
 92    fn insert(&mut self, metadata: PromptMetadata) {
 93        self.metadata_by_id.insert(metadata.id, metadata.clone());
 94        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
 95            *old_metadata = metadata;
 96        } else {
 97            self.metadata.push(metadata);
 98        }
 99        self.sort();
100    }
101
102    fn remove(&mut self, id: PromptId) {
103        self.metadata.retain(|metadata| metadata.id != id);
104        self.metadata_by_id.remove(&id);
105    }
106
107    fn sort(&mut self) {
108        self.metadata.sort_unstable_by(|a, b| {
109            a.title
110                .cmp(&b.title)
111                .then_with(|| b.saved_at.cmp(&a.saved_at))
112        });
113    }
114}
115
116impl PromptStore {
117    pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> {
118        let store = GlobalPromptStore::global(cx).0.clone();
119        async move { store.await.map_err(|err| anyhow!(err)) }
120    }
121
122    pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> {
123        executor.spawn({
124            let executor = executor.clone();
125            async move {
126                std::fs::create_dir_all(&db_path)?;
127
128                let db_env = unsafe {
129                    heed::EnvOpenOptions::new()
130                        .map_size(1024 * 1024 * 1024) // 1GB
131                        .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
132                        .open(db_path)?
133                };
134
135                let mut txn = db_env.write_txn()?;
136                let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
137                let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
138
139                // Remove edit workflow prompt, as we decided to opt into it using
140                // a slash command instead.
141                metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
142                bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
143
144                txn.commit()?;
145
146                Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
147
148                let txn = db_env.read_txn()?;
149                let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
150                txn.commit()?;
151
152                Ok(PromptStore {
153                    executor,
154                    env: db_env,
155                    metadata_cache: RwLock::new(metadata_cache),
156                    metadata,
157                    bodies,
158                })
159            }
160        })
161    }
162
163    fn upgrade_dbs(
164        env: &heed::Env,
165        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
166        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
167    ) -> Result<()> {
168        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
169        pub struct PromptIdV1(Uuid);
170
171        #[derive(Clone, Debug, Serialize, Deserialize)]
172        pub struct PromptMetadataV1 {
173            pub id: PromptIdV1,
174            pub title: Option<SharedString>,
175            pub default: bool,
176            pub saved_at: DateTime<Utc>,
177        }
178
179        let mut txn = env.write_txn()?;
180        let Some(bodies_v1_db) = env
181            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
182                &txn,
183                Some("bodies"),
184            )?
185        else {
186            return Ok(());
187        };
188        let mut bodies_v1 = bodies_v1_db
189            .iter(&txn)?
190            .collect::<heed::Result<HashMap<_, _>>>()?;
191
192        let Some(metadata_v1_db) = env
193            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
194                &txn,
195                Some("metadata"),
196            )?
197        else {
198            return Ok(());
199        };
200        let metadata_v1 = metadata_v1_db
201            .iter(&txn)?
202            .collect::<heed::Result<HashMap<_, _>>>()?;
203
204        for (prompt_id_v1, metadata_v1) in metadata_v1 {
205            let prompt_id_v2 = PromptId::User {
206                uuid: prompt_id_v1.0,
207            };
208            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
209                continue;
210            };
211
212            if metadata_db
213                .get(&txn, &prompt_id_v2)?
214                .map_or(true, |metadata_v2| {
215                    metadata_v1.saved_at > metadata_v2.saved_at
216                })
217            {
218                metadata_db.put(
219                    &mut txn,
220                    &prompt_id_v2,
221                    &PromptMetadata {
222                        id: prompt_id_v2,
223                        title: metadata_v1.title.clone(),
224                        default: metadata_v1.default,
225                        saved_at: metadata_v1.saved_at,
226                    },
227                )?;
228                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
229            }
230        }
231
232        txn.commit()?;
233
234        Ok(())
235    }
236
237    pub fn load(&self, id: PromptId) -> Task<Result<String>> {
238        let env = self.env.clone();
239        let bodies = self.bodies;
240        self.executor.spawn(async move {
241            let txn = env.read_txn()?;
242            let mut prompt = bodies
243                .get(&txn, &id)?
244                .ok_or_else(|| anyhow!("prompt not found"))?
245                .into();
246            LineEnding::normalize(&mut prompt);
247            Ok(prompt)
248        })
249    }
250
251    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
252        return self
253            .metadata_cache
254            .read()
255            .metadata
256            .iter()
257            .filter(|metadata| metadata.default)
258            .cloned()
259            .collect::<Vec<_>>();
260    }
261
262    pub fn delete(&self, id: PromptId) -> Task<Result<()>> {
263        self.metadata_cache.write().remove(id);
264
265        let db_connection = self.env.clone();
266        let bodies = self.bodies;
267        let metadata = self.metadata;
268
269        self.executor.spawn(async move {
270            let mut txn = db_connection.write_txn()?;
271
272            metadata.delete(&mut txn, &id)?;
273            bodies.delete(&mut txn, &id)?;
274
275            txn.commit()?;
276            Ok(())
277        })
278    }
279
280    /// Returns the number of prompts in the store.
281    pub fn prompt_count(&self) -> usize {
282        self.metadata_cache.read().metadata.len()
283    }
284
285    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
286        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
287    }
288
289    pub fn first(&self) -> Option<PromptMetadata> {
290        self.metadata_cache.read().metadata.first().cloned()
291    }
292
293    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
294        let metadata_cache = self.metadata_cache.read();
295        let metadata = metadata_cache
296            .metadata
297            .iter()
298            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
299        Some(metadata.id)
300    }
301
302    pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> {
303        let cached_metadata = self.metadata_cache.read().metadata.clone();
304        let executor = self.executor.clone();
305        self.executor.spawn(async move {
306            let mut matches = if query.is_empty() {
307                cached_metadata
308            } else {
309                let candidates = cached_metadata
310                    .iter()
311                    .enumerate()
312                    .filter_map(|(ix, metadata)| {
313                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
314                    })
315                    .collect::<Vec<_>>();
316                let matches = fuzzy::match_strings(
317                    &candidates,
318                    &query,
319                    false,
320                    100,
321                    &AtomicBool::default(),
322                    executor,
323                )
324                .await;
325                matches
326                    .into_iter()
327                    .map(|mat| cached_metadata[mat.candidate_id].clone())
328                    .collect()
329            };
330            matches.sort_by_key(|metadata| Reverse(metadata.default));
331            matches
332        })
333    }
334
335    pub fn save(
336        &self,
337        id: PromptId,
338        title: Option<SharedString>,
339        default: bool,
340        body: Rope,
341    ) -> Task<Result<()>> {
342        if id.is_built_in() {
343            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
344        }
345
346        let prompt_metadata = PromptMetadata {
347            id,
348            title,
349            default,
350            saved_at: Utc::now(),
351        };
352        self.metadata_cache.write().insert(prompt_metadata.clone());
353
354        let db_connection = self.env.clone();
355        let bodies = self.bodies;
356        let metadata = self.metadata;
357
358        self.executor.spawn(async move {
359            let mut txn = db_connection.write_txn()?;
360
361            metadata.put(&mut txn, &id, &prompt_metadata)?;
362            bodies.put(&mut txn, &id, &body.to_string())?;
363
364            txn.commit()?;
365
366            Ok(())
367        })
368    }
369
370    pub fn save_metadata(
371        &self,
372        id: PromptId,
373        mut title: Option<SharedString>,
374        default: bool,
375    ) -> Task<Result<()>> {
376        let mut cache = self.metadata_cache.write();
377
378        if id.is_built_in() {
379            title = cache
380                .metadata_by_id
381                .get(&id)
382                .and_then(|metadata| metadata.title.clone());
383        }
384
385        let prompt_metadata = PromptMetadata {
386            id,
387            title,
388            default,
389            saved_at: Utc::now(),
390        };
391
392        cache.insert(prompt_metadata.clone());
393
394        let db_connection = self.env.clone();
395        let metadata = self.metadata;
396
397        self.executor.spawn(async move {
398            let mut txn = db_connection.write_txn()?;
399            metadata.put(&mut txn, &id, &prompt_metadata)?;
400            txn.commit()?;
401
402            Ok(())
403        })
404    }
405}
406
407/// Wraps a shared future to a prompt store so it can be assigned as a context global.
408pub struct GlobalPromptStore(
409    Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
410);
411
412impl Global for GlobalPromptStore {}