1mod prompts;
2
3use anyhow::{Context as _, Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
/// Metadata describing one stored prompt.
///
/// Persisted as JSON in the `metadata.v2` database (keyed by [`PromptId`]) and
/// mirrored in the store's in-memory metadata cache.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata belongs to.
    pub id: PromptId,
    /// Display title. Untitled prompts are excluded from non-empty fuzzy searches.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt
    /// (see [`PromptStore::default_prompt_metadata`]).
    pub default: bool,
    /// Time of the most recent save. Among prompts with equal titles, newer ones sort first.
    pub saved_at: DateTime<Utc>,
}
53
/// Identifies a prompt in the store: either a user-authored prompt or a built-in one.
///
/// Serialized with an internal `kind` tag and used as the JSON key of both the
/// metadata and body databases, so renaming variants changes the stored keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// The built-in prompt used for generating git commit messages.
    CommitMessage,
}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 UserPromptId::new().into()
64 }
65
66 pub fn user_id(&self) -> Option<UserPromptId> {
67 match self {
68 Self::User { uuid } => Some(*uuid),
69 _ => None,
70 }
71 }
72
73 pub fn is_built_in(&self) -> bool {
74 match self {
75 Self::User { .. } => false,
76 Self::CommitMessage => true,
77 }
78 }
79
80 pub fn can_edit(&self) -> bool {
81 match self {
82 Self::User { .. } | Self::CommitMessage => true,
83 }
84 }
85
86 pub fn default_content(&self) -> Option<&'static str> {
87 match self {
88 Self::User { .. } => None,
89 Self::CommitMessage => Some(include_str!("../../git_ui/src/commit_message_prompt.txt")),
90 }
91 }
92}
93
94impl From<UserPromptId> for PromptId {
95 fn from(uuid: UserPromptId) -> Self {
96 PromptId::User { uuid }
97 }
98}
99
/// Identifier of a user-authored prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the wrapped [`Uuid`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
103
104impl UserPromptId {
105 pub fn new() -> UserPromptId {
106 UserPromptId(Uuid::new_v4())
107 }
108}
109
110impl From<Uuid> for UserPromptId {
111 fn from(uuid: Uuid) -> Self {
112 UserPromptId(uuid)
113 }
114}
115
116impl std::fmt::Display for PromptId {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 PromptId::User { uuid } => write!(f, "{}", uuid.0),
120 PromptId::CommitMessage => write!(f, "Commit message"),
121 }
122 }
123}
124
/// Persistent prompt library backed by a heed environment.
///
/// Prompt metadata and bodies live in separate databases. Metadata is also held in an
/// in-memory cache so listing and searching never touch the database.
pub struct PromptStore {
    /// The heed environment that owns both databases.
    env: heed::Env,
    /// In-memory mirror of `metadata`, kept sorted for display.
    metadata_cache: RwLock<MetadataCache>,
    /// The `metadata.v2` database: JSON-encoded [`PromptMetadata`] keyed by [`PromptId`].
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// The `bodies.v2` database: prompt text keyed by [`PromptId`].
    bodies: Database<SerdeJson<PromptId>, Str>,
}
131
/// Emitted by [`PromptStore`] after a save, metadata save, or delete has been
/// committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
135
/// In-memory mirror of the prompt metadata database.
#[derive(Default)]
struct MetadataCache {
    /// All metadata, kept sorted by title (untitled first), then by most recent save.
    metadata: Vec<PromptMetadata>,
    /// The same metadata indexed by prompt id for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
141
142impl MetadataCache {
143 fn from_db(
144 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
145 txn: &RoTxn,
146 ) -> Result<Self> {
147 let mut cache = MetadataCache::default();
148 for result in db.iter(txn)? {
149 let (prompt_id, metadata) = result?;
150 cache.metadata.push(metadata.clone());
151 cache.metadata_by_id.insert(prompt_id, metadata);
152 }
153 cache.sort();
154 Ok(cache)
155 }
156
157 fn insert(&mut self, metadata: PromptMetadata) {
158 self.metadata_by_id.insert(metadata.id, metadata.clone());
159 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
160 *old_metadata = metadata;
161 } else {
162 self.metadata.push(metadata);
163 }
164 self.sort();
165 }
166
167 fn remove(&mut self, id: PromptId) {
168 self.metadata.retain(|metadata| metadata.id != id);
169 self.metadata_by_id.remove(&id);
170 }
171
172 fn sort(&mut self) {
173 self.metadata.sort_unstable_by(|a, b| {
174 a.title
175 .cmp(&b.title)
176 .then_with(|| b.saved_at.cmp(&a.saved_at))
177 });
178 }
179}
180
181impl PromptStore {
182 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
183 let store = GlobalPromptStore::global(cx).0.clone();
184 async move { store.await.map_err(|err| anyhow!(err)) }
185 }
186
187 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
188 cx.background_spawn(async move {
189 std::fs::create_dir_all(&db_path)?;
190
191 let db_env = unsafe {
192 heed::EnvOpenOptions::new()
193 .map_size(1024 * 1024 * 1024) // 1GB
194 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
195 .open(db_path)?
196 };
197
198 let mut txn = db_env.write_txn()?;
199 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
200 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
201
202 metadata.delete(&mut txn, &PromptId::CommitMessage)?;
203 bodies.delete(&mut txn, &PromptId::CommitMessage)?;
204
205 txn.commit()?;
206
207 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
208
209 let txn = db_env.read_txn()?;
210 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
211 txn.commit()?;
212
213 Ok(PromptStore {
214 env: db_env,
215 metadata_cache: RwLock::new(metadata_cache),
216 metadata,
217 bodies,
218 })
219 })
220 }
221
222 fn upgrade_dbs(
223 env: &heed::Env,
224 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
225 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
226 ) -> Result<()> {
227 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
228 pub struct PromptIdV1(Uuid);
229
230 #[derive(Clone, Debug, Serialize, Deserialize)]
231 pub struct PromptMetadataV1 {
232 pub id: PromptIdV1,
233 pub title: Option<SharedString>,
234 pub default: bool,
235 pub saved_at: DateTime<Utc>,
236 }
237
238 let mut txn = env.write_txn()?;
239 let Some(bodies_v1_db) = env
240 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
241 &txn,
242 Some("bodies"),
243 )?
244 else {
245 return Ok(());
246 };
247 let mut bodies_v1 = bodies_v1_db
248 .iter(&txn)?
249 .collect::<heed::Result<HashMap<_, _>>>()?;
250
251 let Some(metadata_v1_db) = env
252 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
253 &txn,
254 Some("metadata"),
255 )?
256 else {
257 return Ok(());
258 };
259 let metadata_v1 = metadata_v1_db
260 .iter(&txn)?
261 .collect::<heed::Result<HashMap<_, _>>>()?;
262
263 for (prompt_id_v1, metadata_v1) in metadata_v1 {
264 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
265 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
266 continue;
267 };
268
269 if metadata_db
270 .get(&txn, &prompt_id_v2)?
271 .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
272 {
273 metadata_db.put(
274 &mut txn,
275 &prompt_id_v2,
276 &PromptMetadata {
277 id: prompt_id_v2,
278 title: metadata_v1.title.clone(),
279 default: metadata_v1.default,
280 saved_at: metadata_v1.saved_at,
281 },
282 )?;
283 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
284 }
285 }
286
287 txn.commit()?;
288
289 Ok(())
290 }
291
292 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
293 let env = self.env.clone();
294 let bodies = self.bodies;
295 cx.background_spawn(async move {
296 let txn = env.read_txn()?;
297 let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
298 LineEnding::normalize(&mut prompt);
299 Ok(prompt)
300 })
301 }
302
303 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
304 self.metadata_cache.read().metadata.clone()
305 }
306
307 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
308 return self
309 .metadata_cache
310 .read()
311 .metadata
312 .iter()
313 .filter(|metadata| metadata.default)
314 .cloned()
315 .collect::<Vec<_>>();
316 }
317
318 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
319 self.metadata_cache.write().remove(id);
320
321 let db_connection = self.env.clone();
322 let bodies = self.bodies;
323 let metadata = self.metadata;
324
325 let task = cx.background_spawn(async move {
326 let mut txn = db_connection.write_txn()?;
327
328 metadata.delete(&mut txn, &id)?;
329 bodies.delete(&mut txn, &id)?;
330
331 txn.commit()?;
332 anyhow::Ok(())
333 });
334
335 cx.spawn(async move |this, cx| {
336 task.await?;
337 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
338 anyhow::Ok(())
339 })
340 }
341
342 /// Returns the number of prompts in the store.
343 pub fn prompt_count(&self) -> usize {
344 self.metadata_cache.read().metadata.len()
345 }
346
347 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
348 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
349 }
350
351 pub fn first(&self) -> Option<PromptMetadata> {
352 self.metadata_cache.read().metadata.first().cloned()
353 }
354
355 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
356 let metadata_cache = self.metadata_cache.read();
357 let metadata = metadata_cache
358 .metadata
359 .iter()
360 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
361 Some(metadata.id)
362 }
363
364 pub fn search(
365 &self,
366 query: String,
367 cancellation_flag: Arc<AtomicBool>,
368 cx: &App,
369 ) -> Task<Vec<PromptMetadata>> {
370 let cached_metadata = self.metadata_cache.read().metadata.clone();
371 let executor = cx.background_executor().clone();
372 cx.background_spawn(async move {
373 let mut matches = if query.is_empty() {
374 cached_metadata
375 } else {
376 let candidates = cached_metadata
377 .iter()
378 .enumerate()
379 .filter_map(|(ix, metadata)| {
380 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
381 })
382 .collect::<Vec<_>>();
383 let matches = fuzzy::match_strings(
384 &candidates,
385 &query,
386 false,
387 true,
388 100,
389 &cancellation_flag,
390 executor,
391 )
392 .await;
393 matches
394 .into_iter()
395 .map(|mat| cached_metadata[mat.candidate_id].clone())
396 .collect()
397 };
398 matches.sort_by_key(|metadata| Reverse(metadata.default));
399 matches
400 })
401 }
402
403 pub fn save(
404 &self,
405 id: PromptId,
406 title: Option<SharedString>,
407 default: bool,
408 body: Rope,
409 cx: &Context<Self>,
410 ) -> Task<Result<()>> {
411 if !id.can_edit() {
412 return Task::ready(Err(anyhow!("this prompt cannot be edited")));
413 }
414
415 let prompt_metadata = PromptMetadata {
416 id,
417 title,
418 default,
419 saved_at: Utc::now(),
420 };
421 self.metadata_cache.write().insert(prompt_metadata.clone());
422
423 let db_connection = self.env.clone();
424 let bodies = self.bodies;
425 let metadata = self.metadata;
426
427 let task = cx.background_spawn(async move {
428 let mut txn = db_connection.write_txn()?;
429
430 metadata.put(&mut txn, &id, &prompt_metadata)?;
431 bodies.put(&mut txn, &id, &body.to_string())?;
432
433 txn.commit()?;
434
435 anyhow::Ok(())
436 });
437
438 cx.spawn(async move |this, cx| {
439 task.await?;
440 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
441 anyhow::Ok(())
442 })
443 }
444
445 pub fn save_metadata(
446 &self,
447 id: PromptId,
448 mut title: Option<SharedString>,
449 default: bool,
450 cx: &Context<Self>,
451 ) -> Task<Result<()>> {
452 let mut cache = self.metadata_cache.write();
453
454 if !id.can_edit() {
455 title = cache
456 .metadata_by_id
457 .get(&id)
458 .and_then(|metadata| metadata.title.clone());
459 }
460
461 let prompt_metadata = PromptMetadata {
462 id,
463 title,
464 default,
465 saved_at: Utc::now(),
466 };
467
468 cache.insert(prompt_metadata.clone());
469
470 let db_connection = self.env.clone();
471 let metadata = self.metadata;
472
473 let task = cx.background_spawn(async move {
474 let mut txn = db_connection.write_txn()?;
475 metadata.put(&mut txn, &id, &prompt_metadata)?;
476 txn.commit()?;
477
478 anyhow::Ok(())
479 });
480
481 cx.spawn(async move |this, cx| {
482 task.await?;
483 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
484 anyhow::Ok(())
485 })
486 }
487}
488
/// App-wide global holding a shared future of the prompt store entity, so the store
/// can be reached through the app context while it is still loading.
///
/// The error is wrapped in an `Arc` because `Shared` requires a `Clone` output.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}