1mod prompts;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use strum::{EnumIter, IntoEnumIterator as _};
27use text::LineEnding;
28use util::ResultExt;
29use uuid::Uuid;
30
31/// Init starts loading the PromptStore in the background and assigns
32/// a shared future to a global.
33pub fn init(cx: &mut App) {
34 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
35 let prompt_store_task = PromptStore::new(db_path, cx);
36 let prompt_store_entity_task = cx
37 .spawn(async move |cx| {
38 prompt_store_task
39 .await
40 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
41 .map_err(Arc::new)
42 })
43 .shared();
44 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
45}
46
/// Metadata stored for each prompt.
///
/// Persisted as JSON in the `metadata.v2` database, keyed by [`PromptId`]. The field set is
/// part of that stored format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata belongs to.
    pub id: PromptId,
    /// Display title. Untitled (`None`) prompts are skipped by fuzzy search.
    pub title: Option<SharedString>,
    /// Whether the prompt is a "default" prompt (returned by
    /// `PromptStore::default_prompt_metadata` and sorted first in search results).
    pub default: bool,
    /// When the prompt was last saved. Built-in default metadata uses `DateTime::default()`.
    pub saved_at: DateTime<Utc>,
}
54
55impl PromptMetadata {
56 fn builtin(builtin: BuiltInPrompt) -> Self {
57 Self {
58 id: PromptId::BuiltIn(builtin),
59 title: Some(builtin.title().into()),
60 default: false,
61 saved_at: DateTime::default(),
62 }
63 }
64}
65
/// Built-in prompts that have default content and can be customized by users.
///
/// Variant names are part of the serialized form of [`PromptId`], which is used as a database
/// key. Declaration order is the order produced by `BuiltInPrompt::iter()` (via `EnumIter`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
pub enum BuiltInPrompt {
    /// The prompt used to generate git commit messages.
    CommitMessage,
}
71
72impl BuiltInPrompt {
73 pub fn title(&self) -> &'static str {
74 match self {
75 Self::CommitMessage => "Commit message",
76 }
77 }
78
79 /// Returns the default content for this built-in prompt.
80 pub fn default_content(&self) -> &'static str {
81 match self {
82 Self::CommitMessage => include_str!("../../git_ui/src/commit_message_prompt.txt"),
83 }
84 }
85}
86
87impl std::fmt::Display for BuiltInPrompt {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match self {
90 Self::CommitMessage => write!(f, "Commit message"),
91 }
92 }
93}
94
/// Identifies a prompt: either a user-created prompt or a [`BuiltInPrompt`].
///
/// Serialized as JSON (via `SerdeJson`) as the key of both prompt databases. The
/// `#[serde(tag = "kind")]` attribute records the variant name in a `kind` field, so variant
/// names are part of the stored key format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user.
    User { uuid: UserPromptId },
    /// A built-in prompt (possibly customized by the user).
    BuiltIn(BuiltInPrompt),
}
101
102impl PromptId {
103 pub fn new() -> PromptId {
104 UserPromptId::new().into()
105 }
106
107 pub fn as_user(&self) -> Option<UserPromptId> {
108 match self {
109 Self::User { uuid } => Some(*uuid),
110 Self::BuiltIn { .. } => None,
111 }
112 }
113
114 pub fn as_built_in(&self) -> Option<BuiltInPrompt> {
115 match self {
116 Self::User { .. } => None,
117 Self::BuiltIn(builtin) => Some(*builtin),
118 }
119 }
120
121 pub fn is_built_in(&self) -> bool {
122 matches!(self, Self::BuiltIn { .. })
123 }
124
125 pub fn can_edit(&self) -> bool {
126 match self {
127 Self::User { .. } => true,
128 Self::BuiltIn(builtin) => match builtin {
129 BuiltInPrompt::CommitMessage => true,
130 },
131 }
132 }
133}
134
135impl From<BuiltInPrompt> for PromptId {
136 fn from(builtin: BuiltInPrompt) -> Self {
137 PromptId::BuiltIn(builtin)
138 }
139}
140
141impl From<UserPromptId> for PromptId {
142 fn from(uuid: UserPromptId) -> Self {
143 PromptId::User { uuid }
144 }
145}
146
/// Identifier of a user-created prompt, wrapping a UUID.
///
/// `#[serde(transparent)]` makes it serialize exactly like the inner [`Uuid`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
150
151impl UserPromptId {
152 pub fn new() -> UserPromptId {
153 UserPromptId(Uuid::new_v4())
154 }
155}
156
157impl From<Uuid> for UserPromptId {
158 fn from(uuid: Uuid) -> Self {
159 UserPromptId(uuid)
160 }
161}
162
163impl std::fmt::Display for PromptId {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 match self {
166 PromptId::User { uuid } => write!(f, "{}", uuid.0),
167 PromptId::BuiltIn(builtin) => write!(f, "{}", builtin),
168 }
169 }
170}
171
/// LMDB-backed (via heed) store of prompt metadata and bodies, with an in-memory metadata
/// cache.
pub struct PromptStore {
    /// The heed environment holding the prompt databases.
    env: heed::Env,
    /// In-memory copy of all prompt metadata. `save`, `save_metadata` and `delete` update it
    /// eagerly, before their database writes run on the background executor.
    metadata_cache: RwLock<MetadataCache>,
    /// The `metadata.v2` database: JSON-encoded [`PromptId`] -> JSON-encoded [`PromptMetadata`].
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// The `bodies.v2` database: JSON-encoded [`PromptId`] -> prompt body text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
178
/// Emitted by [`PromptStore`] after a `save`, `save_metadata` or `delete` has been committed
/// to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
182
/// In-memory index of prompt metadata, held by [`PromptStore`] behind an `RwLock`.
#[derive(Default)]
struct MetadataCache {
    /// All metadata, ordered by `MetadataCache::sort`: title ascending (untitled first),
    /// then most recently saved first.
    metadata: Vec<PromptMetadata>,
    /// The same metadata, indexed by prompt id.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
188
189impl MetadataCache {
190 fn from_db(
191 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
192 txn: &RoTxn,
193 ) -> Result<Self> {
194 let mut cache = MetadataCache::default();
195 for result in db.iter(txn)? {
196 let (prompt_id, metadata) = result?;
197 cache.metadata.push(metadata.clone());
198 cache.metadata_by_id.insert(prompt_id, metadata);
199 }
200
201 // Insert all the built-in prompts that were not customized by the user
202 for builtin in BuiltInPrompt::iter() {
203 let builtin_id = PromptId::BuiltIn(builtin);
204 if !cache.metadata_by_id.contains_key(&builtin_id) {
205 let metadata = PromptMetadata::builtin(builtin);
206 cache.metadata.push(metadata.clone());
207 cache.metadata_by_id.insert(builtin_id, metadata);
208 }
209 }
210 cache.sort();
211 Ok(cache)
212 }
213
214 fn insert(&mut self, metadata: PromptMetadata) {
215 self.metadata_by_id.insert(metadata.id, metadata.clone());
216 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
217 *old_metadata = metadata;
218 } else {
219 self.metadata.push(metadata);
220 }
221 self.sort();
222 }
223
224 fn remove(&mut self, id: PromptId) {
225 self.metadata.retain(|metadata| metadata.id != id);
226 self.metadata_by_id.remove(&id);
227 }
228
229 fn sort(&mut self) {
230 self.metadata.sort_unstable_by(|a, b| {
231 a.title
232 .cmp(&b.title)
233 .then_with(|| b.saved_at.cmp(&a.saved_at))
234 });
235 }
236}
237
238impl PromptStore {
239 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
240 let store = GlobalPromptStore::global(cx).0.clone();
241 async move { store.await.map_err(|err| anyhow!(err)) }
242 }
243
244 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
245 cx.background_spawn(async move {
246 std::fs::create_dir_all(&db_path)?;
247
248 let db_env = unsafe {
249 heed::EnvOpenOptions::new()
250 .map_size(1024 * 1024 * 1024) // 1GB
251 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
252 .open(db_path)?
253 };
254
255 let mut txn = db_env.write_txn()?;
256 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
257 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
258 txn.commit()?;
259
260 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
261
262 let txn = db_env.read_txn()?;
263 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
264 txn.commit()?;
265
266 Ok(PromptStore {
267 env: db_env,
268 metadata_cache: RwLock::new(metadata_cache),
269 metadata,
270 bodies,
271 })
272 })
273 }
274
275 fn upgrade_dbs(
276 env: &heed::Env,
277 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
278 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
279 ) -> Result<()> {
280 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
281 pub struct PromptIdV1(Uuid);
282
283 #[derive(Clone, Debug, Serialize, Deserialize)]
284 pub struct PromptMetadataV1 {
285 pub id: PromptIdV1,
286 pub title: Option<SharedString>,
287 pub default: bool,
288 pub saved_at: DateTime<Utc>,
289 }
290
291 let mut txn = env.write_txn()?;
292 let Some(bodies_v1_db) = env
293 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
294 &txn,
295 Some("bodies"),
296 )?
297 else {
298 return Ok(());
299 };
300 let mut bodies_v1 = bodies_v1_db
301 .iter(&txn)?
302 .collect::<heed::Result<HashMap<_, _>>>()?;
303
304 let Some(metadata_v1_db) = env
305 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
306 &txn,
307 Some("metadata"),
308 )?
309 else {
310 return Ok(());
311 };
312 let metadata_v1 = metadata_v1_db
313 .iter(&txn)?
314 .collect::<heed::Result<HashMap<_, _>>>()?;
315
316 for (prompt_id_v1, metadata_v1) in metadata_v1 {
317 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
318 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
319 continue;
320 };
321
322 if metadata_db
323 .get(&txn, &prompt_id_v2)?
324 .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
325 {
326 metadata_db.put(
327 &mut txn,
328 &prompt_id_v2,
329 &PromptMetadata {
330 id: prompt_id_v2,
331 title: metadata_v1.title.clone(),
332 default: metadata_v1.default,
333 saved_at: metadata_v1.saved_at,
334 },
335 )?;
336 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
337 }
338 }
339
340 txn.commit()?;
341
342 Ok(())
343 }
344
345 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
346 let env = self.env.clone();
347 let bodies = self.bodies;
348 cx.background_spawn(async move {
349 let txn = env.read_txn()?;
350 let mut prompt: String = match bodies.get(&txn, &id)? {
351 Some(body) => body.into(),
352 None => {
353 if let Some(built_in) = id.as_built_in() {
354 built_in.default_content().into()
355 } else {
356 anyhow::bail!("prompt not found")
357 }
358 }
359 };
360 LineEnding::normalize(&mut prompt);
361 Ok(prompt)
362 })
363 }
364
365 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
366 self.metadata_cache.read().metadata.clone()
367 }
368
369 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
370 return self
371 .metadata_cache
372 .read()
373 .metadata
374 .iter()
375 .filter(|metadata| metadata.default)
376 .cloned()
377 .collect::<Vec<_>>();
378 }
379
380 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
381 self.metadata_cache.write().remove(id);
382
383 let db_connection = self.env.clone();
384 let bodies = self.bodies;
385 let metadata = self.metadata;
386
387 let task = cx.background_spawn(async move {
388 let mut txn = db_connection.write_txn()?;
389
390 metadata.delete(&mut txn, &id)?;
391 bodies.delete(&mut txn, &id)?;
392
393 txn.commit()?;
394 anyhow::Ok(())
395 });
396
397 cx.spawn(async move |this, cx| {
398 task.await?;
399 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
400 anyhow::Ok(())
401 })
402 }
403
404 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
405 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
406 }
407
408 pub fn first(&self) -> Option<PromptMetadata> {
409 self.metadata_cache.read().metadata.first().cloned()
410 }
411
412 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
413 let metadata_cache = self.metadata_cache.read();
414 let metadata = metadata_cache
415 .metadata
416 .iter()
417 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
418 Some(metadata.id)
419 }
420
421 pub fn search(
422 &self,
423 query: String,
424 cancellation_flag: Arc<AtomicBool>,
425 cx: &App,
426 ) -> Task<Vec<PromptMetadata>> {
427 let cached_metadata = self.metadata_cache.read().metadata.clone();
428 let executor = cx.background_executor().clone();
429 cx.background_spawn(async move {
430 let mut matches = if query.is_empty() {
431 cached_metadata
432 } else {
433 let candidates = cached_metadata
434 .iter()
435 .enumerate()
436 .filter_map(|(ix, metadata)| {
437 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
438 })
439 .collect::<Vec<_>>();
440 let matches = fuzzy::match_strings(
441 &candidates,
442 &query,
443 false,
444 true,
445 100,
446 &cancellation_flag,
447 executor,
448 )
449 .await;
450 matches
451 .into_iter()
452 .map(|mat| cached_metadata[mat.candidate_id].clone())
453 .collect()
454 };
455 matches.sort_by_key(|metadata| Reverse(metadata.default));
456 matches
457 })
458 }
459
460 pub fn save(
461 &self,
462 id: PromptId,
463 title: Option<SharedString>,
464 default: bool,
465 body: Rope,
466 cx: &Context<Self>,
467 ) -> Task<Result<()>> {
468 if !id.can_edit() {
469 return Task::ready(Err(anyhow!("this prompt cannot be edited")));
470 }
471
472 let body = body.to_string();
473 let is_default_content = id
474 .as_built_in()
475 .is_some_and(|builtin| body.trim() == builtin.default_content().trim());
476
477 let metadata = if let Some(builtin) = id.as_built_in() {
478 PromptMetadata::builtin(builtin)
479 } else {
480 PromptMetadata {
481 id,
482 title,
483 default,
484 saved_at: Utc::now(),
485 }
486 };
487
488 self.metadata_cache.write().insert(metadata.clone());
489
490 let db_connection = self.env.clone();
491 let bodies = self.bodies;
492 let metadata_db = self.metadata;
493
494 let task = cx.background_spawn(async move {
495 let mut txn = db_connection.write_txn()?;
496
497 if is_default_content {
498 metadata_db.delete(&mut txn, &id)?;
499 bodies.delete(&mut txn, &id)?;
500 } else {
501 metadata_db.put(&mut txn, &id, &metadata)?;
502 bodies.put(&mut txn, &id, &body)?;
503 }
504
505 txn.commit()?;
506
507 anyhow::Ok(())
508 });
509
510 cx.spawn(async move |this, cx| {
511 task.await?;
512 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
513 anyhow::Ok(())
514 })
515 }
516
517 pub fn save_metadata(
518 &self,
519 id: PromptId,
520 mut title: Option<SharedString>,
521 default: bool,
522 cx: &Context<Self>,
523 ) -> Task<Result<()>> {
524 let mut cache = self.metadata_cache.write();
525
526 if !id.can_edit() {
527 title = cache
528 .metadata_by_id
529 .get(&id)
530 .and_then(|metadata| metadata.title.clone());
531 }
532
533 let prompt_metadata = PromptMetadata {
534 id,
535 title,
536 default,
537 saved_at: Utc::now(),
538 };
539
540 cache.insert(prompt_metadata.clone());
541
542 let db_connection = self.env.clone();
543 let metadata = self.metadata;
544
545 let task = cx.background_spawn(async move {
546 let mut txn = db_connection.write_txn()?;
547 metadata.put(&mut txn, &id, &prompt_metadata)?;
548 txn.commit()?;
549
550 anyhow::Ok(())
551 });
552
553 cx.spawn(async move |this, cx| {
554 task.await?;
555 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
556 anyhow::Ok(())
557 })
558 }
559}
560
/// Wraps a shared future resolving to the prompt store so it can be assigned as a context
/// global.
///
/// The future yields the [`PromptStore`] entity. If opening the store failed, it yields the
/// error instead, wrapped in an `Arc` so the shared output can be cloned to every waiter.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}
565
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Returns the built-in commit message prompt's default content, with line endings
    /// normalized the same way `PromptStore::load` normalizes loaded bodies.
    fn normalized_commit_message_default() -> String {
        let mut content = BuiltInPrompt::CommitMessage.default_content().to_string();
        LineEnding::normalize(&mut content);
        content
    }

    /// Loads the body of prompt `id` from `store`, panicking if loading fails.
    async fn load_body(
        store: &Entity<PromptStore>,
        id: PromptId,
        cx: &mut TestAppContext,
    ) -> String {
        store
            .update(cx, |store, cx| store.load(id, cx))
            .await
            .unwrap()
    }

    /// Saves `body` for prompt `id` (titled "Commit message", not a default prompt),
    /// panicking if saving fails.
    async fn save_body(
        store: &Entity<PromptStore>,
        id: PromptId,
        body: &'static str,
        cx: &mut TestAppContext,
    ) {
        store
            .update(cx, |store, cx| {
                store.save(id, Some("Commit message".into()), false, Rope::from(body), cx)
            })
            .await
            .unwrap();
    }

    #[gpui::test]
    async fn test_built_in_prompt_load_save(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("prompts-db");
        let store = cx.update(|cx| PromptStore::new(db_path, cx)).await.unwrap();
        let store = cx.new(|_cx| store);

        let prompt_id = PromptId::BuiltIn(BuiltInPrompt::CommitMessage);

        // With nothing stored yet, the built-in default content is served.
        let initial_body = load_body(&store, prompt_id, cx).await;
        assert_eq!(
            initial_body.trim(),
            normalized_commit_message_default().trim(),
            "Loading a built-in prompt not in DB should return default content"
        );

        // Built-in prompts are always present in the metadata cache.
        let initial_metadata = store.read_with(cx, |store, _| store.metadata(prompt_id));
        assert!(
            initial_metadata.is_some(),
            "Built-in prompt should always have metadata"
        );
        let is_cached = store.read_with(cx, |store, _| {
            store
                .metadata_cache
                .read()
                .metadata_by_id
                .contains_key(&prompt_id)
        });
        assert!(is_cached, "Built-in prompt should always be in cache");

        // Customizing the body persists it.
        let custom_content = "Custom commit message prompt";
        save_body(&store, prompt_id, custom_content, cx).await;

        let custom_body = load_body(&store, prompt_id, cx).await;
        assert_eq!(
            custom_body.trim(),
            custom_content.trim(),
            "Custom content should be loaded after saving"
        );

        let customized_metadata = store.read_with(cx, |store, _| store.metadata(prompt_id));
        assert!(
            customized_metadata.is_some(),
            "Built-in prompt should have metadata after customization"
        );

        // Saving the default content again resets the customization.
        save_body(
            &store,
            prompt_id,
            BuiltInPrompt::CommitMessage.default_content(),
            cx,
        )
        .await;

        let metadata_after_reset = store.read_with(cx, |store, _| store.metadata(prompt_id));
        assert!(
            metadata_after_reset.is_some(),
            "Built-in prompt should still have metadata after reset"
        );
        assert_eq!(
            metadata_after_reset
                .as_ref()
                .and_then(|m| m.title.as_ref().map(|t| t.as_ref())),
            Some("Commit message"),
            "Built-in prompt should have default title after reset"
        );

        let body_after_reset = load_body(&store, prompt_id, cx).await;
        assert_eq!(
            body_after_reset.trim(),
            normalized_commit_message_default().trim(),
            "After saving default content, load should return default"
        );
    }
}