1mod prompts;
2
3use anyhow::{Context as _, Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
/// Metadata describing a single stored prompt.
///
/// Persisted as JSON (`SerdeJson`) in the `metadata.v2` database and mirrored
/// in memory by `MetadataCache`. Field names are part of the stored encoding.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata belongs to.
    pub id: PromptId,
    /// Display title; `None` when the prompt has no title (such prompts are
    /// skipped by non-empty title searches).
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt
    /// (see `PromptStore::default_prompt_metadata`).
    pub default: bool,
    /// UTC timestamp of the most recent save.
    pub saved_at: DateTime<Utc>,
}
53
/// Identifies a prompt in the store.
///
/// Serialized as an internally tagged object (`#[serde(tag = "kind")]`), e.g.
/// `{"kind": "User", "uuid": <uuid>}` or `{"kind": "EditWorkflow"}`. This
/// encoding is used as the database key (`SerdeJson<PromptId>`), so changing it
/// makes existing entries unreachable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt authored by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// Built-in edit workflow prompt; `PromptStore::new` deletes any stored
    /// entry for it on startup.
    EditWorkflow,
}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 UserPromptId::new().into()
64 }
65
66 pub fn is_built_in(&self) -> bool {
67 !matches!(self, PromptId::User { .. })
68 }
69}
70
71impl From<UserPromptId> for PromptId {
72 fn from(uuid: UserPromptId) -> Self {
73 PromptId::User { uuid }
74 }
75}
76
/// UUID identifying a user-authored prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the inner `Uuid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
80
81impl UserPromptId {
82 pub fn new() -> UserPromptId {
83 UserPromptId(Uuid::new_v4())
84 }
85}
86
87impl From<Uuid> for UserPromptId {
88 fn from(uuid: Uuid) -> Self {
89 UserPromptId(uuid)
90 }
91}
92
93impl std::fmt::Display for PromptId {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 PromptId::User { uuid } => write!(f, "{}", uuid.0),
97 PromptId::EditWorkflow => write!(f, "Edit workflow"),
98 }
99 }
100}
101
/// Persistent prompt library backed by a `heed` (LMDB) environment.
///
/// All prompt metadata is mirrored in memory (`metadata_cache`), so listing,
/// lookup, and search need no database transaction. Prompt bodies are read
/// from disk on demand via `PromptStore::load`.
pub struct PromptStore {
    /// Open database environment holding both databases below.
    env: heed::Env,
    /// In-memory mirror of the `metadata` database.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: prompt id (JSON) -> metadata (JSON).
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: prompt id (JSON) -> prompt text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
108
/// Emitted by [`PromptStore`] after a save, metadata update, or deletion has
/// been committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
112
/// In-memory copy of all prompt metadata.
///
/// `metadata` is kept in sorted order (see `MetadataCache::sort`), while
/// `metadata_by_id` provides lookup by id; both hold the same entries.
#[derive(Default)]
struct MetadataCache {
    /// All entries, ordered by title and then most recently saved first.
    metadata: Vec<PromptMetadata>,
    /// The same entries keyed by prompt id.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
118
119impl MetadataCache {
120 fn from_db(
121 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
122 txn: &RoTxn,
123 ) -> Result<Self> {
124 let mut cache = MetadataCache::default();
125 for result in db.iter(txn)? {
126 let (prompt_id, metadata) = result?;
127 cache.metadata.push(metadata.clone());
128 cache.metadata_by_id.insert(prompt_id, metadata);
129 }
130 cache.sort();
131 Ok(cache)
132 }
133
134 fn insert(&mut self, metadata: PromptMetadata) {
135 self.metadata_by_id.insert(metadata.id, metadata.clone());
136 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
137 *old_metadata = metadata;
138 } else {
139 self.metadata.push(metadata);
140 }
141 self.sort();
142 }
143
144 fn remove(&mut self, id: PromptId) {
145 self.metadata.retain(|metadata| metadata.id != id);
146 self.metadata_by_id.remove(&id);
147 }
148
149 fn sort(&mut self) {
150 self.metadata.sort_unstable_by(|a, b| {
151 a.title
152 .cmp(&b.title)
153 .then_with(|| b.saved_at.cmp(&a.saved_at))
154 });
155 }
156}
157
158impl PromptStore {
159 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
160 let store = GlobalPromptStore::global(cx).0.clone();
161 async move { store.await.map_err(|err| anyhow!(err)) }
162 }
163
164 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
165 cx.background_spawn(async move {
166 std::fs::create_dir_all(&db_path)?;
167
168 let db_env = unsafe {
169 heed::EnvOpenOptions::new()
170 .map_size(1024 * 1024 * 1024) // 1GB
171 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
172 .open(db_path)?
173 };
174
175 let mut txn = db_env.write_txn()?;
176 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
177 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
178
179 // Remove edit workflow prompt, as we decided to opt into it using
180 // a slash command instead.
181 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
182 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
183
184 txn.commit()?;
185
186 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
187
188 let txn = db_env.read_txn()?;
189 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
190 txn.commit()?;
191
192 Ok(PromptStore {
193 env: db_env,
194 metadata_cache: RwLock::new(metadata_cache),
195 metadata,
196 bodies,
197 })
198 })
199 }
200
201 fn upgrade_dbs(
202 env: &heed::Env,
203 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
204 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
205 ) -> Result<()> {
206 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
207 pub struct PromptIdV1(Uuid);
208
209 #[derive(Clone, Debug, Serialize, Deserialize)]
210 pub struct PromptMetadataV1 {
211 pub id: PromptIdV1,
212 pub title: Option<SharedString>,
213 pub default: bool,
214 pub saved_at: DateTime<Utc>,
215 }
216
217 let mut txn = env.write_txn()?;
218 let Some(bodies_v1_db) = env
219 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
220 &txn,
221 Some("bodies"),
222 )?
223 else {
224 return Ok(());
225 };
226 let mut bodies_v1 = bodies_v1_db
227 .iter(&txn)?
228 .collect::<heed::Result<HashMap<_, _>>>()?;
229
230 let Some(metadata_v1_db) = env
231 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
232 &txn,
233 Some("metadata"),
234 )?
235 else {
236 return Ok(());
237 };
238 let metadata_v1 = metadata_v1_db
239 .iter(&txn)?
240 .collect::<heed::Result<HashMap<_, _>>>()?;
241
242 for (prompt_id_v1, metadata_v1) in metadata_v1 {
243 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
244 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
245 continue;
246 };
247
248 if metadata_db
249 .get(&txn, &prompt_id_v2)?
250 .map_or(true, |metadata_v2| {
251 metadata_v1.saved_at > metadata_v2.saved_at
252 })
253 {
254 metadata_db.put(
255 &mut txn,
256 &prompt_id_v2,
257 &PromptMetadata {
258 id: prompt_id_v2,
259 title: metadata_v1.title.clone(),
260 default: metadata_v1.default,
261 saved_at: metadata_v1.saved_at,
262 },
263 )?;
264 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
265 }
266 }
267
268 txn.commit()?;
269
270 Ok(())
271 }
272
273 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
274 let env = self.env.clone();
275 let bodies = self.bodies;
276 cx.background_spawn(async move {
277 let txn = env.read_txn()?;
278 let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
279 LineEnding::normalize(&mut prompt);
280 Ok(prompt)
281 })
282 }
283
284 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
285 self.metadata_cache.read().metadata.clone()
286 }
287
288 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
289 return self
290 .metadata_cache
291 .read()
292 .metadata
293 .iter()
294 .filter(|metadata| metadata.default)
295 .cloned()
296 .collect::<Vec<_>>();
297 }
298
299 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
300 self.metadata_cache.write().remove(id);
301
302 let db_connection = self.env.clone();
303 let bodies = self.bodies;
304 let metadata = self.metadata;
305
306 let task = cx.background_spawn(async move {
307 let mut txn = db_connection.write_txn()?;
308
309 metadata.delete(&mut txn, &id)?;
310 bodies.delete(&mut txn, &id)?;
311
312 txn.commit()?;
313 anyhow::Ok(())
314 });
315
316 cx.spawn(async move |this, cx| {
317 task.await?;
318 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
319 anyhow::Ok(())
320 })
321 }
322
323 /// Returns the number of prompts in the store.
324 pub fn prompt_count(&self) -> usize {
325 self.metadata_cache.read().metadata.len()
326 }
327
328 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
329 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
330 }
331
332 pub fn first(&self) -> Option<PromptMetadata> {
333 self.metadata_cache.read().metadata.first().cloned()
334 }
335
336 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
337 let metadata_cache = self.metadata_cache.read();
338 let metadata = metadata_cache
339 .metadata
340 .iter()
341 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
342 Some(metadata.id)
343 }
344
345 pub fn search(
346 &self,
347 query: String,
348 cancellation_flag: Arc<AtomicBool>,
349 cx: &App,
350 ) -> Task<Vec<PromptMetadata>> {
351 let cached_metadata = self.metadata_cache.read().metadata.clone();
352 let executor = cx.background_executor().clone();
353 cx.background_spawn(async move {
354 let mut matches = if query.is_empty() {
355 cached_metadata
356 } else {
357 let candidates = cached_metadata
358 .iter()
359 .enumerate()
360 .filter_map(|(ix, metadata)| {
361 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
362 })
363 .collect::<Vec<_>>();
364 let matches = fuzzy::match_strings(
365 &candidates,
366 &query,
367 false,
368 true,
369 100,
370 &cancellation_flag,
371 executor,
372 )
373 .await;
374 matches
375 .into_iter()
376 .map(|mat| cached_metadata[mat.candidate_id].clone())
377 .collect()
378 };
379 matches.sort_by_key(|metadata| Reverse(metadata.default));
380 matches
381 })
382 }
383
384 pub fn save(
385 &self,
386 id: PromptId,
387 title: Option<SharedString>,
388 default: bool,
389 body: Rope,
390 cx: &Context<Self>,
391 ) -> Task<Result<()>> {
392 if id.is_built_in() {
393 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
394 }
395
396 let prompt_metadata = PromptMetadata {
397 id,
398 title,
399 default,
400 saved_at: Utc::now(),
401 };
402 self.metadata_cache.write().insert(prompt_metadata.clone());
403
404 let db_connection = self.env.clone();
405 let bodies = self.bodies;
406 let metadata = self.metadata;
407
408 let task = cx.background_spawn(async move {
409 let mut txn = db_connection.write_txn()?;
410
411 metadata.put(&mut txn, &id, &prompt_metadata)?;
412 bodies.put(&mut txn, &id, &body.to_string())?;
413
414 txn.commit()?;
415
416 anyhow::Ok(())
417 });
418
419 cx.spawn(async move |this, cx| {
420 task.await?;
421 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
422 anyhow::Ok(())
423 })
424 }
425
426 pub fn save_metadata(
427 &self,
428 id: PromptId,
429 mut title: Option<SharedString>,
430 default: bool,
431 cx: &Context<Self>,
432 ) -> Task<Result<()>> {
433 let mut cache = self.metadata_cache.write();
434
435 if id.is_built_in() {
436 title = cache
437 .metadata_by_id
438 .get(&id)
439 .and_then(|metadata| metadata.title.clone());
440 }
441
442 let prompt_metadata = PromptMetadata {
443 id,
444 title,
445 default,
446 saved_at: Utc::now(),
447 };
448
449 cache.insert(prompt_metadata.clone());
450
451 let db_connection = self.env.clone();
452 let metadata = self.metadata;
453
454 let task = cx.background_spawn(async move {
455 let mut txn = db_connection.write_txn()?;
456 metadata.put(&mut txn, &id, &prompt_metadata)?;
457 txn.commit()?;
458
459 anyhow::Ok(())
460 });
461
462 cx.spawn(async move |this, cx| {
463 task.await?;
464 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
465 anyhow::Ok(())
466 })
467 }
468}
469
/// Wraps the shared future that resolves to the prompt store, so it can be
/// installed as a gpui global.
///
/// Installed by [`init`]. The future resolves to the loaded store entity, or
/// to the load error wrapped in an `Arc` so that the `Shared` output is `Clone`.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}