1mod prompts;
2
3use anyhow::{Context as _, Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
46#[derive(Clone, Debug, Serialize, Deserialize)]
47pub struct PromptMetadata {
48 pub id: PromptId,
49 pub title: Option<SharedString>,
50 pub default: bool,
51 pub saved_at: DateTime<Utc>,
52}
53
54#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
55#[serde(tag = "kind")]
56pub enum PromptId {
57 User { uuid: UserPromptId },
58 CommitMessage,
59}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 UserPromptId::new().into()
64 }
65
66 pub fn user_id(&self) -> Option<UserPromptId> {
67 match self {
68 Self::User { uuid } => Some(*uuid),
69 _ => None,
70 }
71 }
72
73 pub fn is_built_in(&self) -> bool {
74 match self {
75 Self::User { .. } => false,
76 Self::CommitMessage => true,
77 }
78 }
79
80 pub fn can_edit(&self) -> bool {
81 match self {
82 Self::User { .. } | Self::CommitMessage => true,
83 }
84 }
85
86 pub fn default_content(&self) -> Option<&'static str> {
87 match self {
88 Self::User { .. } => None,
89 Self::CommitMessage => Some(include_str!("../../git_ui/src/commit_message_prompt.txt")),
90 }
91 }
92}
93
94impl From<UserPromptId> for PromptId {
95 fn from(uuid: UserPromptId) -> Self {
96 PromptId::User { uuid }
97 }
98}
99
100#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
101#[serde(transparent)]
102pub struct UserPromptId(pub Uuid);
103
104impl UserPromptId {
105 pub fn new() -> UserPromptId {
106 UserPromptId(Uuid::new_v4())
107 }
108}
109
110impl From<Uuid> for UserPromptId {
111 fn from(uuid: Uuid) -> Self {
112 UserPromptId(uuid)
113 }
114}
115
116impl std::fmt::Display for PromptId {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 PromptId::User { uuid } => write!(f, "{}", uuid.0),
120 PromptId::CommitMessage => write!(f, "Commit message"),
121 }
122 }
123}
124
125pub struct PromptStore {
126 env: heed::Env,
127 metadata_cache: RwLock<MetadataCache>,
128 metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
129 bodies: Database<SerdeJson<PromptId>, Str>,
130}
131
132pub struct PromptsUpdatedEvent;
133
134impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
135
136#[derive(Default)]
137struct MetadataCache {
138 metadata: Vec<PromptMetadata>,
139 metadata_by_id: HashMap<PromptId, PromptMetadata>,
140}
141
142impl MetadataCache {
143 fn from_db(
144 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
145 txn: &RoTxn,
146 ) -> Result<Self> {
147 let mut cache = MetadataCache::default();
148 for result in db.iter(txn)? {
149 let (prompt_id, metadata) = result?;
150 cache.metadata.push(metadata.clone());
151 cache.metadata_by_id.insert(prompt_id, metadata);
152 }
153 cache.sort();
154 Ok(cache)
155 }
156
157 fn insert(&mut self, metadata: PromptMetadata) {
158 self.metadata_by_id.insert(metadata.id, metadata.clone());
159 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
160 *old_metadata = metadata;
161 } else {
162 self.metadata.push(metadata);
163 }
164 self.sort();
165 }
166
167 fn remove(&mut self, id: PromptId) {
168 self.metadata.retain(|metadata| metadata.id != id);
169 self.metadata_by_id.remove(&id);
170 }
171
172 fn sort(&mut self) {
173 self.metadata.sort_unstable_by(|a, b| {
174 a.title
175 .cmp(&b.title)
176 .then_with(|| b.saved_at.cmp(&a.saved_at))
177 });
178 }
179}
180
181impl PromptStore {
182 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
183 let store = GlobalPromptStore::global(cx).0.clone();
184 async move { store.await.map_err(|err| anyhow!(err)) }
185 }
186
187 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
188 cx.background_spawn(async move {
189 std::fs::create_dir_all(&db_path)?;
190
191 let db_env = unsafe {
192 heed::EnvOpenOptions::new()
193 .map_size(1024 * 1024 * 1024) // 1GB
194 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
195 .open(db_path)?
196 };
197
198 let mut txn = db_env.write_txn()?;
199 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
200 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
201
202 // Insert default commit message prompt if not present
203 if metadata.get(&txn, &PromptId::CommitMessage)?.is_none() {
204 metadata.put(
205 &mut txn,
206 &PromptId::CommitMessage,
207 &PromptMetadata {
208 id: PromptId::CommitMessage,
209 title: Some("Git Commit Message".into()),
210 default: false,
211 saved_at: Utc::now(),
212 },
213 )?;
214 }
215 if bodies.get(&txn, &PromptId::CommitMessage)?.is_none() {
216 let commit_message_prompt =
217 include_str!("../../git_ui/src/commit_message_prompt.txt");
218 bodies.put(&mut txn, &PromptId::CommitMessage, commit_message_prompt)?;
219 }
220
221 txn.commit()?;
222
223 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
224
225 let txn = db_env.read_txn()?;
226 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
227 txn.commit()?;
228
229 Ok(PromptStore {
230 env: db_env,
231 metadata_cache: RwLock::new(metadata_cache),
232 metadata,
233 bodies,
234 })
235 })
236 }
237
238 fn upgrade_dbs(
239 env: &heed::Env,
240 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
241 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
242 ) -> Result<()> {
243 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
244 pub struct PromptIdV1(Uuid);
245
246 #[derive(Clone, Debug, Serialize, Deserialize)]
247 pub struct PromptMetadataV1 {
248 pub id: PromptIdV1,
249 pub title: Option<SharedString>,
250 pub default: bool,
251 pub saved_at: DateTime<Utc>,
252 }
253
254 let mut txn = env.write_txn()?;
255 let Some(bodies_v1_db) = env
256 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
257 &txn,
258 Some("bodies"),
259 )?
260 else {
261 return Ok(());
262 };
263 let mut bodies_v1 = bodies_v1_db
264 .iter(&txn)?
265 .collect::<heed::Result<HashMap<_, _>>>()?;
266
267 let Some(metadata_v1_db) = env
268 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
269 &txn,
270 Some("metadata"),
271 )?
272 else {
273 return Ok(());
274 };
275 let metadata_v1 = metadata_v1_db
276 .iter(&txn)?
277 .collect::<heed::Result<HashMap<_, _>>>()?;
278
279 for (prompt_id_v1, metadata_v1) in metadata_v1 {
280 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
281 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
282 continue;
283 };
284
285 if metadata_db
286 .get(&txn, &prompt_id_v2)?
287 .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
288 {
289 metadata_db.put(
290 &mut txn,
291 &prompt_id_v2,
292 &PromptMetadata {
293 id: prompt_id_v2,
294 title: metadata_v1.title.clone(),
295 default: metadata_v1.default,
296 saved_at: metadata_v1.saved_at,
297 },
298 )?;
299 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
300 }
301 }
302
303 txn.commit()?;
304
305 Ok(())
306 }
307
308 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
309 let env = self.env.clone();
310 let bodies = self.bodies;
311 cx.background_spawn(async move {
312 let txn = env.read_txn()?;
313 let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
314 LineEnding::normalize(&mut prompt);
315 Ok(prompt)
316 })
317 }
318
319 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
320 self.metadata_cache.read().metadata.clone()
321 }
322
323 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
324 return self
325 .metadata_cache
326 .read()
327 .metadata
328 .iter()
329 .filter(|metadata| metadata.default)
330 .cloned()
331 .collect::<Vec<_>>();
332 }
333
334 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
335 self.metadata_cache.write().remove(id);
336
337 let db_connection = self.env.clone();
338 let bodies = self.bodies;
339 let metadata = self.metadata;
340
341 let task = cx.background_spawn(async move {
342 let mut txn = db_connection.write_txn()?;
343
344 metadata.delete(&mut txn, &id)?;
345 bodies.delete(&mut txn, &id)?;
346
347 txn.commit()?;
348 anyhow::Ok(())
349 });
350
351 cx.spawn(async move |this, cx| {
352 task.await?;
353 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
354 anyhow::Ok(())
355 })
356 }
357
358 /// Returns the number of prompts in the store.
359 pub fn prompt_count(&self) -> usize {
360 self.metadata_cache.read().metadata.len()
361 }
362
363 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
364 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
365 }
366
367 pub fn first(&self) -> Option<PromptMetadata> {
368 self.metadata_cache.read().metadata.first().cloned()
369 }
370
371 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
372 let metadata_cache = self.metadata_cache.read();
373 let metadata = metadata_cache
374 .metadata
375 .iter()
376 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
377 Some(metadata.id)
378 }
379
380 pub fn search(
381 &self,
382 query: String,
383 cancellation_flag: Arc<AtomicBool>,
384 cx: &App,
385 ) -> Task<Vec<PromptMetadata>> {
386 let cached_metadata = self.metadata_cache.read().metadata.clone();
387 let executor = cx.background_executor().clone();
388 cx.background_spawn(async move {
389 let mut matches = if query.is_empty() {
390 cached_metadata
391 } else {
392 let candidates = cached_metadata
393 .iter()
394 .enumerate()
395 .filter_map(|(ix, metadata)| {
396 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
397 })
398 .collect::<Vec<_>>();
399 let matches = fuzzy::match_strings(
400 &candidates,
401 &query,
402 false,
403 true,
404 100,
405 &cancellation_flag,
406 executor,
407 )
408 .await;
409 matches
410 .into_iter()
411 .map(|mat| cached_metadata[mat.candidate_id].clone())
412 .collect()
413 };
414 matches.sort_by_key(|metadata| Reverse(metadata.default));
415 matches
416 })
417 }
418
419 pub fn save(
420 &self,
421 id: PromptId,
422 title: Option<SharedString>,
423 default: bool,
424 body: Rope,
425 cx: &Context<Self>,
426 ) -> Task<Result<()>> {
427 if !id.can_edit() {
428 return Task::ready(Err(anyhow!("this prompt cannot be edited")));
429 }
430
431 let prompt_metadata = PromptMetadata {
432 id,
433 title,
434 default,
435 saved_at: Utc::now(),
436 };
437 self.metadata_cache.write().insert(prompt_metadata.clone());
438
439 let db_connection = self.env.clone();
440 let bodies = self.bodies;
441 let metadata = self.metadata;
442
443 let task = cx.background_spawn(async move {
444 let mut txn = db_connection.write_txn()?;
445
446 metadata.put(&mut txn, &id, &prompt_metadata)?;
447 bodies.put(&mut txn, &id, &body.to_string())?;
448
449 txn.commit()?;
450
451 anyhow::Ok(())
452 });
453
454 cx.spawn(async move |this, cx| {
455 task.await?;
456 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
457 anyhow::Ok(())
458 })
459 }
460
461 pub fn save_metadata(
462 &self,
463 id: PromptId,
464 mut title: Option<SharedString>,
465 default: bool,
466 cx: &Context<Self>,
467 ) -> Task<Result<()>> {
468 let mut cache = self.metadata_cache.write();
469
470 if !id.can_edit() {
471 title = cache
472 .metadata_by_id
473 .get(&id)
474 .and_then(|metadata| metadata.title.clone());
475 }
476
477 let prompt_metadata = PromptMetadata {
478 id,
479 title,
480 default,
481 saved_at: Utc::now(),
482 };
483
484 cache.insert(prompt_metadata.clone());
485
486 let db_connection = self.env.clone();
487 let metadata = self.metadata;
488
489 let task = cx.background_spawn(async move {
490 let mut txn = db_connection.write_txn()?;
491 metadata.put(&mut txn, &id, &prompt_metadata)?;
492 txn.commit()?;
493
494 anyhow::Ok(())
495 });
496
497 cx.spawn(async move |this, cx| {
498 task.await?;
499 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
500 anyhow::Ok(())
501 })
502 }
503}
504
505/// Wraps a shared future to a prompt store so it can be assigned as a context global.
506pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
507
508impl Global for GlobalPromptStore {}