1mod prompts;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use fuzzy::StringMatchCandidate;
  9use gpui::{
 10    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
 11};
 12use heed::{
 13    Database, RoTxn,
 14    types::{SerdeBincode, SerdeJson, Str},
 15};
 16use parking_lot::RwLock;
 17pub use prompts::*;
 18use rope::Rope;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    cmp::Reverse,
 22    future::Future,
 23    path::PathBuf,
 24    sync::{Arc, atomic::AtomicBool},
 25};
 26use text::LineEnding;
 27use util::ResultExt;
 28use uuid::Uuid;
 29
 30/// Init starts loading the PromptStore in the background and assigns
 31/// a shared future to a global.
 32pub fn init(cx: &mut App) {
 33    let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
 34    let prompt_store_task = PromptStore::new(db_path, cx);
 35    let prompt_store_entity_task = cx
 36        .spawn(async move |cx| {
 37            prompt_store_task
 38                .await
 39                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
 40                .map_err(Arc::new)
 41        })
 42        .shared();
 43    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 44}
 45
/// Metadata describing a stored prompt; persisted as JSON in the store's metadata database.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata belongs to.
    pub id: PromptId,
    /// Optional user-visible title; used for sorting, exact title lookup, and fuzzy search.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt.
    pub default: bool,
    /// Time of the last save; among prompts with equal titles, newer saves sort first.
    pub saved_at: DateTime<Utc>,
}
 53
/// Identifies a prompt: either a user-authored prompt keyed by UUID or a built-in prompt.
///
/// Serialized as an internally tagged enum (the variant name goes in a `kind` field); it is
/// used as the JSON-encoded key of the prompt databases, so variant names must not change.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user.
    User { uuid: UserPromptId },
    /// The built-in edit workflow prompt; `PromptStore::new` deletes any stored copy of it.
    EditWorkflow,
}
 60
 61impl PromptId {
 62    pub fn new() -> PromptId {
 63        UserPromptId::new().into()
 64    }
 65
 66    pub const fn is_built_in(&self) -> bool {
 67        !matches!(self, PromptId::User { .. })
 68    }
 69}
 70
 71impl From<UserPromptId> for PromptId {
 72    fn from(uuid: UserPromptId) -> Self {
 73        PromptId::User { uuid }
 74    }
 75}
 76
/// Identifier of a user-authored prompt; serialized transparently as its inner UUID.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
 80
 81impl UserPromptId {
 82    pub fn new() -> UserPromptId {
 83        UserPromptId(Uuid::new_v4())
 84    }
 85}
 86
 87impl From<Uuid> for UserPromptId {
 88    fn from(uuid: Uuid) -> Self {
 89        UserPromptId(uuid)
 90    }
 91}
 92
 93impl std::fmt::Display for PromptId {
 94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 95        match self {
 96            PromptId::User { uuid } => write!(f, "{}", uuid.0),
 97            PromptId::EditWorkflow => write!(f, "Edit workflow"),
 98        }
 99    }
100}
101
/// Persistent prompt store backed by a `heed` environment, with an in-memory copy of all
/// prompt metadata.
pub struct PromptStore {
    /// Database environment shared (by clone) with background read/write tasks.
    env: heed::Env,
    /// Cached metadata; updated eagerly on save/delete, before the database write commits.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: prompt id -> metadata, both JSON-encoded.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: prompt id (JSON-encoded) -> prompt body text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
108
/// Emitted by [`PromptStore`] after a save, metadata save, or delete has been committed
/// to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
112
/// In-memory cache of all prompt metadata, in both sorted-list and by-id form.
#[derive(Default)]
struct MetadataCache {
    /// All metadata, kept sorted by title, then by `saved_at` (newest first).
    metadata: Vec<PromptMetadata>,
    /// The same metadata indexed by prompt id.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
118
119impl MetadataCache {
120    fn from_db(
121        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
122        txn: &RoTxn,
123    ) -> Result<Self> {
124        let mut cache = MetadataCache::default();
125        for result in db.iter(txn)? {
126            let (prompt_id, metadata) = result?;
127            cache.metadata.push(metadata.clone());
128            cache.metadata_by_id.insert(prompt_id, metadata);
129        }
130        cache.sort();
131        Ok(cache)
132    }
133
134    fn insert(&mut self, metadata: PromptMetadata) {
135        self.metadata_by_id.insert(metadata.id, metadata.clone());
136        if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
137            *old_metadata = metadata;
138        } else {
139            self.metadata.push(metadata);
140        }
141        self.sort();
142    }
143
144    fn remove(&mut self, id: PromptId) {
145        self.metadata.retain(|metadata| metadata.id != id);
146        self.metadata_by_id.remove(&id);
147    }
148
149    fn sort(&mut self) {
150        self.metadata.sort_unstable_by(|a, b| {
151            a.title
152                .cmp(&b.title)
153                .then_with(|| b.saved_at.cmp(&a.saved_at))
154        });
155    }
156}
157
158impl PromptStore {
159    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
160        let store = GlobalPromptStore::global(cx).0.clone();
161        async move { store.await.map_err(|err| anyhow!(err)) }
162    }
163
164    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
165        cx.background_spawn(async move {
166            std::fs::create_dir_all(&db_path)?;
167
168            let db_env = unsafe {
169                heed::EnvOpenOptions::new()
170                    .map_size(1024 * 1024 * 1024) // 1GB
171                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
172                    .open(db_path)?
173            };
174
175            let mut txn = db_env.write_txn()?;
176            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
177            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
178
179            // Remove edit workflow prompt, as we decided to opt into it using
180            // a slash command instead.
181            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
182            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
183
184            txn.commit()?;
185
186            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
187
188            let txn = db_env.read_txn()?;
189            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
190            txn.commit()?;
191
192            Ok(PromptStore {
193                env: db_env,
194                metadata_cache: RwLock::new(metadata_cache),
195                metadata,
196                bodies,
197            })
198        })
199    }
200
201    fn upgrade_dbs(
202        env: &heed::Env,
203        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
204        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
205    ) -> Result<()> {
206        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
207        pub struct PromptIdV1(Uuid);
208
209        #[derive(Clone, Debug, Serialize, Deserialize)]
210        pub struct PromptMetadataV1 {
211            pub id: PromptIdV1,
212            pub title: Option<SharedString>,
213            pub default: bool,
214            pub saved_at: DateTime<Utc>,
215        }
216
217        let mut txn = env.write_txn()?;
218        let Some(bodies_v1_db) = env
219            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
220                &txn,
221                Some("bodies"),
222            )?
223        else {
224            return Ok(());
225        };
226        let mut bodies_v1 = bodies_v1_db
227            .iter(&txn)?
228            .collect::<heed::Result<HashMap<_, _>>>()?;
229
230        let Some(metadata_v1_db) = env
231            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
232                &txn,
233                Some("metadata"),
234            )?
235        else {
236            return Ok(());
237        };
238        let metadata_v1 = metadata_v1_db
239            .iter(&txn)?
240            .collect::<heed::Result<HashMap<_, _>>>()?;
241
242        for (prompt_id_v1, metadata_v1) in metadata_v1 {
243            let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
244            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
245                continue;
246            };
247
248            if metadata_db
249                .get(&txn, &prompt_id_v2)?
250                .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
251            {
252                metadata_db.put(
253                    &mut txn,
254                    &prompt_id_v2,
255                    &PromptMetadata {
256                        id: prompt_id_v2,
257                        title: metadata_v1.title.clone(),
258                        default: metadata_v1.default,
259                        saved_at: metadata_v1.saved_at,
260                    },
261                )?;
262                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
263            }
264        }
265
266        txn.commit()?;
267
268        Ok(())
269    }
270
271    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
272        let env = self.env.clone();
273        let bodies = self.bodies;
274        cx.background_spawn(async move {
275            let txn = env.read_txn()?;
276            let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
277            LineEnding::normalize(&mut prompt);
278            Ok(prompt)
279        })
280    }
281
282    pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
283        self.metadata_cache.read().metadata.clone()
284    }
285
286    pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
287        return self
288            .metadata_cache
289            .read()
290            .metadata
291            .iter()
292            .filter(|metadata| metadata.default)
293            .cloned()
294            .collect::<Vec<_>>();
295    }
296
297    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
298        self.metadata_cache.write().remove(id);
299
300        let db_connection = self.env.clone();
301        let bodies = self.bodies;
302        let metadata = self.metadata;
303
304        let task = cx.background_spawn(async move {
305            let mut txn = db_connection.write_txn()?;
306
307            metadata.delete(&mut txn, &id)?;
308            bodies.delete(&mut txn, &id)?;
309
310            txn.commit()?;
311            anyhow::Ok(())
312        });
313
314        cx.spawn(async move |this, cx| {
315            task.await?;
316            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
317            anyhow::Ok(())
318        })
319    }
320
321    /// Returns the number of prompts in the store.
322    pub fn prompt_count(&self) -> usize {
323        self.metadata_cache.read().metadata.len()
324    }
325
326    pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
327        self.metadata_cache.read().metadata_by_id.get(&id).cloned()
328    }
329
330    pub fn first(&self) -> Option<PromptMetadata> {
331        self.metadata_cache.read().metadata.first().cloned()
332    }
333
334    pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
335        let metadata_cache = self.metadata_cache.read();
336        let metadata = metadata_cache
337            .metadata
338            .iter()
339            .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
340        Some(metadata.id)
341    }
342
343    pub fn search(
344        &self,
345        query: String,
346        cancellation_flag: Arc<AtomicBool>,
347        cx: &App,
348    ) -> Task<Vec<PromptMetadata>> {
349        let cached_metadata = self.metadata_cache.read().metadata.clone();
350        let executor = cx.background_executor().clone();
351        cx.background_spawn(async move {
352            let mut matches = if query.is_empty() {
353                cached_metadata
354            } else {
355                let candidates = cached_metadata
356                    .iter()
357                    .enumerate()
358                    .filter_map(|(ix, metadata)| {
359                        Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
360                    })
361                    .collect::<Vec<_>>();
362                let matches = fuzzy::match_strings(
363                    &candidates,
364                    &query,
365                    false,
366                    true,
367                    100,
368                    &cancellation_flag,
369                    executor,
370                )
371                .await;
372                matches
373                    .into_iter()
374                    .map(|mat| cached_metadata[mat.candidate_id].clone())
375                    .collect()
376            };
377            matches.sort_by_key(|metadata| Reverse(metadata.default));
378            matches
379        })
380    }
381
382    pub fn save(
383        &self,
384        id: PromptId,
385        title: Option<SharedString>,
386        default: bool,
387        body: Rope,
388        cx: &Context<Self>,
389    ) -> Task<Result<()>> {
390        if id.is_built_in() {
391            return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
392        }
393
394        let prompt_metadata = PromptMetadata {
395            id,
396            title,
397            default,
398            saved_at: Utc::now(),
399        };
400        self.metadata_cache.write().insert(prompt_metadata.clone());
401
402        let db_connection = self.env.clone();
403        let bodies = self.bodies;
404        let metadata = self.metadata;
405
406        let task = cx.background_spawn(async move {
407            let mut txn = db_connection.write_txn()?;
408
409            metadata.put(&mut txn, &id, &prompt_metadata)?;
410            bodies.put(&mut txn, &id, &body.to_string())?;
411
412            txn.commit()?;
413
414            anyhow::Ok(())
415        });
416
417        cx.spawn(async move |this, cx| {
418            task.await?;
419            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
420            anyhow::Ok(())
421        })
422    }
423
424    pub fn save_metadata(
425        &self,
426        id: PromptId,
427        mut title: Option<SharedString>,
428        default: bool,
429        cx: &Context<Self>,
430    ) -> Task<Result<()>> {
431        let mut cache = self.metadata_cache.write();
432
433        if id.is_built_in() {
434            title = cache
435                .metadata_by_id
436                .get(&id)
437                .and_then(|metadata| metadata.title.clone());
438        }
439
440        let prompt_metadata = PromptMetadata {
441            id,
442            title,
443            default,
444            saved_at: Utc::now(),
445        };
446
447        cache.insert(prompt_metadata.clone());
448
449        let db_connection = self.env.clone();
450        let metadata = self.metadata;
451
452        let task = cx.background_spawn(async move {
453            let mut txn = db_connection.write_txn()?;
454            metadata.put(&mut txn, &id, &prompt_metadata)?;
455            txn.commit()?;
456
457            anyhow::Ok(())
458        });
459
460        cx.spawn(async move |this, cx| {
461            task.await?;
462            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
463            anyhow::Ok(())
464        })
465    }
466}
467
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// The future is the background load started by `init`, and several callers can await it.
/// The error is wrapped in `Arc` because `Shared` requires its output to be `Clone`, which
/// `anyhow::Error` is not.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}