1mod prompts;
2
3use anyhow::{Context as _, Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
46#[derive(Clone, Debug, Serialize, Deserialize)]
47pub struct PromptMetadata {
48 pub id: PromptId,
49 pub title: Option<SharedString>,
50 pub default: bool,
51 pub saved_at: DateTime<Utc>,
52}
53
54#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
55#[serde(tag = "kind")]
56pub enum PromptId {
57 User { uuid: UserPromptId },
58 EditWorkflow,
59}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 UserPromptId::new().into()
64 }
65
66 pub fn is_built_in(&self) -> bool {
67 !matches!(self, PromptId::User { .. })
68 }
69}
70
71impl From<UserPromptId> for PromptId {
72 fn from(uuid: UserPromptId) -> Self {
73 PromptId::User { uuid }
74 }
75}
76
77#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
78#[serde(transparent)]
79pub struct UserPromptId(pub Uuid);
80
81impl UserPromptId {
82 pub fn new() -> UserPromptId {
83 UserPromptId(Uuid::new_v4())
84 }
85}
86
87impl From<Uuid> for UserPromptId {
88 fn from(uuid: Uuid) -> Self {
89 UserPromptId(uuid)
90 }
91}
92
93impl std::fmt::Display for PromptId {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 PromptId::User { uuid } => write!(f, "{}", uuid.0),
97 PromptId::EditWorkflow => write!(f, "Edit workflow"),
98 }
99 }
100}
101
102pub struct PromptStore {
103 env: heed::Env,
104 metadata_cache: RwLock<MetadataCache>,
105 metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
106 bodies: Database<SerdeJson<PromptId>, Str>,
107}
108
109pub struct PromptsUpdatedEvent;
110
111impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
112
113#[derive(Default)]
114struct MetadataCache {
115 metadata: Vec<PromptMetadata>,
116 metadata_by_id: HashMap<PromptId, PromptMetadata>,
117}
118
119impl MetadataCache {
120 fn from_db(
121 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
122 txn: &RoTxn,
123 ) -> Result<Self> {
124 let mut cache = MetadataCache::default();
125 for result in db.iter(txn)? {
126 let (prompt_id, metadata) = result?;
127 cache.metadata.push(metadata.clone());
128 cache.metadata_by_id.insert(prompt_id, metadata);
129 }
130 cache.sort();
131 Ok(cache)
132 }
133
134 fn insert(&mut self, metadata: PromptMetadata) {
135 self.metadata_by_id.insert(metadata.id, metadata.clone());
136 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
137 *old_metadata = metadata;
138 } else {
139 self.metadata.push(metadata);
140 }
141 self.sort();
142 }
143
144 fn remove(&mut self, id: PromptId) {
145 self.metadata.retain(|metadata| metadata.id != id);
146 self.metadata_by_id.remove(&id);
147 }
148
149 fn sort(&mut self) {
150 self.metadata.sort_unstable_by(|a, b| {
151 a.title
152 .cmp(&b.title)
153 .then_with(|| b.saved_at.cmp(&a.saved_at))
154 });
155 }
156}
157
158impl PromptStore {
159 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
160 let store = GlobalPromptStore::global(cx).0.clone();
161 async move { store.await.map_err(|err| anyhow!(err)) }
162 }
163
164 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
165 cx.background_spawn(async move {
166 std::fs::create_dir_all(&db_path)?;
167
168 let db_env = unsafe {
169 heed::EnvOpenOptions::new()
170 .map_size(1024 * 1024 * 1024) // 1GB
171 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
172 .open(db_path)?
173 };
174
175 let mut txn = db_env.write_txn()?;
176 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
177 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
178
179 // Remove edit workflow prompt, as we decided to opt into it using
180 // a slash command instead.
181 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
182 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
183
184 txn.commit()?;
185
186 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
187
188 let txn = db_env.read_txn()?;
189 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
190 txn.commit()?;
191
192 Ok(PromptStore {
193 env: db_env,
194 metadata_cache: RwLock::new(metadata_cache),
195 metadata,
196 bodies,
197 })
198 })
199 }
200
201 fn upgrade_dbs(
202 env: &heed::Env,
203 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
204 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
205 ) -> Result<()> {
206 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
207 pub struct PromptIdV1(Uuid);
208
209 #[derive(Clone, Debug, Serialize, Deserialize)]
210 pub struct PromptMetadataV1 {
211 pub id: PromptIdV1,
212 pub title: Option<SharedString>,
213 pub default: bool,
214 pub saved_at: DateTime<Utc>,
215 }
216
217 let mut txn = env.write_txn()?;
218 let Some(bodies_v1_db) = env
219 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
220 &txn,
221 Some("bodies"),
222 )?
223 else {
224 return Ok(());
225 };
226 let mut bodies_v1 = bodies_v1_db
227 .iter(&txn)?
228 .collect::<heed::Result<HashMap<_, _>>>()?;
229
230 let Some(metadata_v1_db) = env
231 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
232 &txn,
233 Some("metadata"),
234 )?
235 else {
236 return Ok(());
237 };
238 let metadata_v1 = metadata_v1_db
239 .iter(&txn)?
240 .collect::<heed::Result<HashMap<_, _>>>()?;
241
242 for (prompt_id_v1, metadata_v1) in metadata_v1 {
243 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
244 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
245 continue;
246 };
247
248 if metadata_db
249 .get(&txn, &prompt_id_v2)?
250 .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at)
251 {
252 metadata_db.put(
253 &mut txn,
254 &prompt_id_v2,
255 &PromptMetadata {
256 id: prompt_id_v2,
257 title: metadata_v1.title.clone(),
258 default: metadata_v1.default,
259 saved_at: metadata_v1.saved_at,
260 },
261 )?;
262 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
263 }
264 }
265
266 txn.commit()?;
267
268 Ok(())
269 }
270
271 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
272 let env = self.env.clone();
273 let bodies = self.bodies;
274 cx.background_spawn(async move {
275 let txn = env.read_txn()?;
276 let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
277 LineEnding::normalize(&mut prompt);
278 Ok(prompt)
279 })
280 }
281
282 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
283 self.metadata_cache.read().metadata.clone()
284 }
285
286 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
287 return self
288 .metadata_cache
289 .read()
290 .metadata
291 .iter()
292 .filter(|metadata| metadata.default)
293 .cloned()
294 .collect::<Vec<_>>();
295 }
296
297 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
298 self.metadata_cache.write().remove(id);
299
300 let db_connection = self.env.clone();
301 let bodies = self.bodies;
302 let metadata = self.metadata;
303
304 let task = cx.background_spawn(async move {
305 let mut txn = db_connection.write_txn()?;
306
307 metadata.delete(&mut txn, &id)?;
308 bodies.delete(&mut txn, &id)?;
309
310 txn.commit()?;
311 anyhow::Ok(())
312 });
313
314 cx.spawn(async move |this, cx| {
315 task.await?;
316 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
317 anyhow::Ok(())
318 })
319 }
320
321 /// Returns the number of prompts in the store.
322 pub fn prompt_count(&self) -> usize {
323 self.metadata_cache.read().metadata.len()
324 }
325
326 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
327 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
328 }
329
330 pub fn first(&self) -> Option<PromptMetadata> {
331 self.metadata_cache.read().metadata.first().cloned()
332 }
333
334 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
335 let metadata_cache = self.metadata_cache.read();
336 let metadata = metadata_cache
337 .metadata
338 .iter()
339 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
340 Some(metadata.id)
341 }
342
343 pub fn search(
344 &self,
345 query: String,
346 cancellation_flag: Arc<AtomicBool>,
347 cx: &App,
348 ) -> Task<Vec<PromptMetadata>> {
349 let cached_metadata = self.metadata_cache.read().metadata.clone();
350 let executor = cx.background_executor().clone();
351 cx.background_spawn(async move {
352 let mut matches = if query.is_empty() {
353 cached_metadata
354 } else {
355 let candidates = cached_metadata
356 .iter()
357 .enumerate()
358 .filter_map(|(ix, metadata)| {
359 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
360 })
361 .collect::<Vec<_>>();
362 let matches = fuzzy::match_strings(
363 &candidates,
364 &query,
365 false,
366 true,
367 100,
368 &cancellation_flag,
369 executor,
370 )
371 .await;
372 matches
373 .into_iter()
374 .map(|mat| cached_metadata[mat.candidate_id].clone())
375 .collect()
376 };
377 matches.sort_by_key(|metadata| Reverse(metadata.default));
378 matches
379 })
380 }
381
382 pub fn save(
383 &self,
384 id: PromptId,
385 title: Option<SharedString>,
386 default: bool,
387 body: Rope,
388 cx: &Context<Self>,
389 ) -> Task<Result<()>> {
390 if id.is_built_in() {
391 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
392 }
393
394 let prompt_metadata = PromptMetadata {
395 id,
396 title,
397 default,
398 saved_at: Utc::now(),
399 };
400 self.metadata_cache.write().insert(prompt_metadata.clone());
401
402 let db_connection = self.env.clone();
403 let bodies = self.bodies;
404 let metadata = self.metadata;
405
406 let task = cx.background_spawn(async move {
407 let mut txn = db_connection.write_txn()?;
408
409 metadata.put(&mut txn, &id, &prompt_metadata)?;
410 bodies.put(&mut txn, &id, &body.to_string())?;
411
412 txn.commit()?;
413
414 anyhow::Ok(())
415 });
416
417 cx.spawn(async move |this, cx| {
418 task.await?;
419 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
420 anyhow::Ok(())
421 })
422 }
423
424 pub fn save_metadata(
425 &self,
426 id: PromptId,
427 mut title: Option<SharedString>,
428 default: bool,
429 cx: &Context<Self>,
430 ) -> Task<Result<()>> {
431 let mut cache = self.metadata_cache.write();
432
433 if id.is_built_in() {
434 title = cache
435 .metadata_by_id
436 .get(&id)
437 .and_then(|metadata| metadata.title.clone());
438 }
439
440 let prompt_metadata = PromptMetadata {
441 id,
442 title,
443 default,
444 saved_at: Utc::now(),
445 };
446
447 cache.insert(prompt_metadata.clone());
448
449 let db_connection = self.env.clone();
450 let metadata = self.metadata;
451
452 let task = cx.background_spawn(async move {
453 let mut txn = db_connection.write_txn()?;
454 metadata.put(&mut txn, &id, &prompt_metadata)?;
455 txn.commit()?;
456
457 anyhow::Ok(())
458 });
459
460 cx.spawn(async move |this, cx| {
461 task.await?;
462 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
463 anyhow::Ok(())
464 })
465 }
466}
467
468/// Wraps a shared future to a prompt store so it can be assigned as a context global.
469pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
470
471impl Global for GlobalPromptStore {}