1mod prompts;
2
3use anyhow::{Context as _, Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
/// Metadata stored for each prompt (JSON-encoded in the `metadata.v2` database).
///
/// The serialized field layout is persisted on disk, so field names and order
/// should be treated as part of the storage format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Id of the prompt this metadata describes (also used as the database key).
    pub id: PromptId,
    /// Display title; untitled prompts are not candidates in non-empty searches.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt
    /// (see `PromptStore::default_prompt_metadata`).
    pub default: bool,
    /// When the prompt was last saved; among equal titles, newer saves sort first.
    pub saved_at: DateTime<Utc>,
}
53
/// Identifies a prompt in the store.
///
/// Serialized as JSON with an internal `"kind"` tag and used as the key of both
/// the metadata and bodies databases, so its serde representation must stay
/// stable or existing entries will no longer be found.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user.
    User { uuid: UserPromptId },
    /// Retired built-in prompt; `PromptStore::new` deletes any stored copy.
    EditWorkflow,
    /// Built-in "Git Commit Message" prompt, seeded by `PromptStore::new`.
    CommitMessage,
}
61
62impl PromptId {
63 pub fn new() -> PromptId {
64 UserPromptId::new().into()
65 }
66
67 pub fn user_id(&self) -> Option<UserPromptId> {
68 match self {
69 Self::User { uuid } => Some(*uuid),
70 _ => None,
71 }
72 }
73
74 pub fn is_built_in(&self) -> bool {
75 match self {
76 Self::User { .. } => false,
77 Self::EditWorkflow | Self::CommitMessage => true,
78 }
79 }
80
81 pub fn can_edit(&self) -> bool {
82 match self {
83 Self::User { .. } | Self::CommitMessage => true,
84 Self::EditWorkflow => false,
85 }
86 }
87
88 pub fn default_content(&self) -> Option<&'static str> {
89 match self {
90 Self::User { .. } | Self::EditWorkflow => None,
91 Self::CommitMessage => Some(include_str!("../../git_ui/src/commit_message_prompt.txt")),
92 }
93 }
94}
95
96impl From<UserPromptId> for PromptId {
97 fn from(uuid: UserPromptId) -> Self {
98 PromptId::User { uuid }
99 }
100}
101
/// Id of a user-created prompt.
///
/// `serde(transparent)` makes it serialize exactly like the inner UUID; this is
/// part of the persisted key format of [`PromptId::User`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
105
106impl UserPromptId {
107 pub fn new() -> UserPromptId {
108 UserPromptId(Uuid::new_v4())
109 }
110}
111
112impl From<Uuid> for UserPromptId {
113 fn from(uuid: Uuid) -> Self {
114 UserPromptId(uuid)
115 }
116}
117
118impl std::fmt::Display for PromptId {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 match self {
121 PromptId::User { uuid } => write!(f, "{}", uuid.0),
122 PromptId::EditWorkflow => write!(f, "Edit workflow"),
123 PromptId::CommitMessage => write!(f, "Commit message"),
124 }
125 }
126}
127
128pub struct PromptStore {
129 env: heed::Env,
130 metadata_cache: RwLock<MetadataCache>,
131 metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
132 bodies: Database<SerdeJson<PromptId>, Str>,
133}
134
/// Event emitted by [`PromptStore`] after a save, metadata update, or delete has
/// been committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
138
139#[derive(Default)]
140struct MetadataCache {
141 metadata: Vec<PromptMetadata>,
142 metadata_by_id: HashMap<PromptId, PromptMetadata>,
143}
144
145impl MetadataCache {
146 fn from_db(
147 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
148 txn: &RoTxn,
149 ) -> Result<Self> {
150 let mut cache = MetadataCache::default();
151 for result in db.iter(txn)? {
152 let (prompt_id, metadata) = result?;
153 cache.metadata.push(metadata.clone());
154 cache.metadata_by_id.insert(prompt_id, metadata);
155 }
156 cache.sort();
157 Ok(cache)
158 }
159
160 fn insert(&mut self, metadata: PromptMetadata) {
161 self.metadata_by_id.insert(metadata.id, metadata.clone());
162 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
163 *old_metadata = metadata;
164 } else {
165 self.metadata.push(metadata);
166 }
167 self.sort();
168 }
169
170 fn remove(&mut self, id: PromptId) {
171 self.metadata.retain(|metadata| metadata.id != id);
172 self.metadata_by_id.remove(&id);
173 }
174
175 fn sort(&mut self) {
176 self.metadata.sort_unstable_by(|a, b| {
177 a.title
178 .cmp(&b.title)
179 .then_with(|| b.saved_at.cmp(&a.saved_at))
180 });
181 }
182}
183
184impl PromptStore {
185 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
186 let store = GlobalPromptStore::global(cx).0.clone();
187 async move { store.await.map_err(|err| anyhow!(err)) }
188 }
189
190 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
191 cx.background_spawn(async move {
192 std::fs::create_dir_all(&db_path)?;
193
194 let db_env = unsafe {
195 heed::EnvOpenOptions::new()
196 .map_size(1024 * 1024 * 1024) // 1GB
197 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
198 .open(db_path)?
199 };
200
201 let mut txn = db_env.write_txn()?;
202 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
203 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
204
205 // Remove edit workflow prompt, as we decided to opt into it using
206 // a slash command instead.
207 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
208 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
209
210 // Insert default commit message prompt if not present
211 if metadata.get(&txn, &PromptId::CommitMessage)?.is_none() {
212 metadata.put(
213 &mut txn,
214 &PromptId::CommitMessage,
215 &PromptMetadata {
216 id: PromptId::CommitMessage,
217 title: Some("Git Commit Message".into()),
218 default: false,
219 saved_at: Utc::now(),
220 },
221 )?;
222 }
223 if bodies.get(&txn, &PromptId::CommitMessage)?.is_none() {
224 let commit_message_prompt =
225 include_str!("../../git_ui/src/commit_message_prompt.txt");
226 bodies.put(&mut txn, &PromptId::CommitMessage, commit_message_prompt)?;
227 }
228
229 txn.commit()?;
230
231 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
232
233 let txn = db_env.read_txn()?;
234 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
235 txn.commit()?;
236
237 Ok(PromptStore {
238 env: db_env,
239 metadata_cache: RwLock::new(metadata_cache),
240 metadata,
241 bodies,
242 })
243 })
244 }
245
246 fn upgrade_dbs(
247 env: &heed::Env,
248 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
249 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
250 ) -> Result<()> {
251 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
252 pub struct PromptIdV1(Uuid);
253
254 #[derive(Clone, Debug, Serialize, Deserialize)]
255 pub struct PromptMetadataV1 {
256 pub id: PromptIdV1,
257 pub title: Option<SharedString>,
258 pub default: bool,
259 pub saved_at: DateTime<Utc>,
260 }
261
262 let mut txn = env.write_txn()?;
263 let Some(bodies_v1_db) = env
264 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
265 &txn,
266 Some("bodies"),
267 )?
268 else {
269 return Ok(());
270 };
271 let mut bodies_v1 = bodies_v1_db
272 .iter(&txn)?
273 .collect::<heed::Result<HashMap<_, _>>>()?;
274
275 let Some(metadata_v1_db) = env
276 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
277 &txn,
278 Some("metadata"),
279 )?
280 else {
281 return Ok(());
282 };
283 let metadata_v1 = metadata_v1_db
284 .iter(&txn)?
285 .collect::<heed::Result<HashMap<_, _>>>()?;
286
287 for (prompt_id_v1, metadata_v1) in metadata_v1 {
288 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
289 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
290 continue;
291 };
292
293 if metadata_db
294 .get(&txn, &prompt_id_v2)?
295 .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
296 {
297 metadata_db.put(
298 &mut txn,
299 &prompt_id_v2,
300 &PromptMetadata {
301 id: prompt_id_v2,
302 title: metadata_v1.title.clone(),
303 default: metadata_v1.default,
304 saved_at: metadata_v1.saved_at,
305 },
306 )?;
307 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
308 }
309 }
310
311 txn.commit()?;
312
313 Ok(())
314 }
315
316 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
317 let env = self.env.clone();
318 let bodies = self.bodies;
319 cx.background_spawn(async move {
320 let txn = env.read_txn()?;
321 let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
322 LineEnding::normalize(&mut prompt);
323 Ok(prompt)
324 })
325 }
326
327 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
328 self.metadata_cache.read().metadata.clone()
329 }
330
331 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
332 return self
333 .metadata_cache
334 .read()
335 .metadata
336 .iter()
337 .filter(|metadata| metadata.default)
338 .cloned()
339 .collect::<Vec<_>>();
340 }
341
342 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
343 self.metadata_cache.write().remove(id);
344
345 let db_connection = self.env.clone();
346 let bodies = self.bodies;
347 let metadata = self.metadata;
348
349 let task = cx.background_spawn(async move {
350 let mut txn = db_connection.write_txn()?;
351
352 metadata.delete(&mut txn, &id)?;
353 bodies.delete(&mut txn, &id)?;
354
355 txn.commit()?;
356 anyhow::Ok(())
357 });
358
359 cx.spawn(async move |this, cx| {
360 task.await?;
361 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
362 anyhow::Ok(())
363 })
364 }
365
366 /// Returns the number of prompts in the store.
367 pub fn prompt_count(&self) -> usize {
368 self.metadata_cache.read().metadata.len()
369 }
370
371 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
372 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
373 }
374
375 pub fn first(&self) -> Option<PromptMetadata> {
376 self.metadata_cache.read().metadata.first().cloned()
377 }
378
379 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
380 let metadata_cache = self.metadata_cache.read();
381 let metadata = metadata_cache
382 .metadata
383 .iter()
384 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
385 Some(metadata.id)
386 }
387
388 pub fn search(
389 &self,
390 query: String,
391 cancellation_flag: Arc<AtomicBool>,
392 cx: &App,
393 ) -> Task<Vec<PromptMetadata>> {
394 let cached_metadata = self.metadata_cache.read().metadata.clone();
395 let executor = cx.background_executor().clone();
396 cx.background_spawn(async move {
397 let mut matches = if query.is_empty() {
398 cached_metadata
399 } else {
400 let candidates = cached_metadata
401 .iter()
402 .enumerate()
403 .filter_map(|(ix, metadata)| {
404 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
405 })
406 .collect::<Vec<_>>();
407 let matches = fuzzy::match_strings(
408 &candidates,
409 &query,
410 false,
411 true,
412 100,
413 &cancellation_flag,
414 executor,
415 )
416 .await;
417 matches
418 .into_iter()
419 .map(|mat| cached_metadata[mat.candidate_id].clone())
420 .collect()
421 };
422 matches.sort_by_key(|metadata| Reverse(metadata.default));
423 matches
424 })
425 }
426
427 pub fn save(
428 &self,
429 id: PromptId,
430 title: Option<SharedString>,
431 default: bool,
432 body: Rope,
433 cx: &Context<Self>,
434 ) -> Task<Result<()>> {
435 if !id.can_edit() {
436 return Task::ready(Err(anyhow!("this prompt cannot be edited")));
437 }
438
439 let prompt_metadata = PromptMetadata {
440 id,
441 title,
442 default,
443 saved_at: Utc::now(),
444 };
445 self.metadata_cache.write().insert(prompt_metadata.clone());
446
447 let db_connection = self.env.clone();
448 let bodies = self.bodies;
449 let metadata = self.metadata;
450
451 let task = cx.background_spawn(async move {
452 let mut txn = db_connection.write_txn()?;
453
454 metadata.put(&mut txn, &id, &prompt_metadata)?;
455 bodies.put(&mut txn, &id, &body.to_string())?;
456
457 txn.commit()?;
458
459 anyhow::Ok(())
460 });
461
462 cx.spawn(async move |this, cx| {
463 task.await?;
464 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
465 anyhow::Ok(())
466 })
467 }
468
469 pub fn save_metadata(
470 &self,
471 id: PromptId,
472 mut title: Option<SharedString>,
473 default: bool,
474 cx: &Context<Self>,
475 ) -> Task<Result<()>> {
476 let mut cache = self.metadata_cache.write();
477
478 if !id.can_edit() {
479 title = cache
480 .metadata_by_id
481 .get(&id)
482 .and_then(|metadata| metadata.title.clone());
483 }
484
485 let prompt_metadata = PromptMetadata {
486 id,
487 title,
488 default,
489 saved_at: Utc::now(),
490 };
491
492 cache.insert(prompt_metadata.clone());
493
494 let db_connection = self.env.clone();
495 let metadata = self.metadata;
496
497 let task = cx.background_spawn(async move {
498 let mut txn = db_connection.write_txn()?;
499 metadata.put(&mut txn, &id, &prompt_metadata)?;
500 txn.commit()?;
501
502 anyhow::Ok(())
503 });
504
505 cx.spawn(async move |this, cx| {
506 task.await?;
507 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
508 anyhow::Ok(())
509 })
510 }
511}
512
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// The future is created by [`init`] and resolves to the store entity. On
/// failure it resolves to the initialization error, wrapped in an `Arc` so the
/// shared output can be cloned for every awaiter (see [`PromptStore::global`]).
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}