1mod prompts;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
/// Metadata persisted for each prompt, stored JSON-encoded in the
/// `metadata.v2` database alongside the prompt body.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt; also used as the database key.
    pub id: PromptId,
    /// Optional display title. Prompts without a title are skipped by
    /// non-empty searches in `PromptStore::search`.
    pub title: Option<SharedString>,
    /// Whether this prompt is a default prompt. Default prompts are returned by
    /// `PromptStore::default_prompt_metadata` and sorted first in search results.
    pub default: bool,
    /// Time of the last save (set to `Utc::now()` by the store's save methods).
    /// Used to order prompts with equal titles, newest first.
    pub saved_at: DateTime<Utc>,
}
53
/// Identifies a prompt in the store.
///
/// Serialized with serde's internal tagging: the variant name is written to a
/// `"kind"` field alongside the variant's own fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt created by the user, identified by a UUID.
    User { uuid: UserPromptId },
    /// A built-in prompt id. `PromptStore::new` deletes any persisted entries
    /// under this id when the store is opened.
    EditWorkflow,
}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 PromptId::User {
64 uuid: UserPromptId::new(),
65 }
66 }
67
68 pub fn is_built_in(&self) -> bool {
69 !matches!(self, PromptId::User { .. })
70 }
71}
72
/// UUID identifying a user-created prompt.
///
/// `serde(transparent)` makes it serialize exactly like the wrapped `Uuid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
76
77impl UserPromptId {
78 pub fn new() -> UserPromptId {
79 UserPromptId(Uuid::new_v4())
80 }
81}
82
83impl From<Uuid> for UserPromptId {
84 fn from(uuid: Uuid) -> Self {
85 UserPromptId(uuid)
86 }
87}
88
/// Persistent store of prompts backed by a heed (LMDB) environment, with an
/// in-memory cache of all prompt metadata.
pub struct PromptStore {
    /// The opened database environment; cloned into background tasks.
    env: heed::Env,
    /// In-memory mirror of the `metadata` database, updated eagerly by the
    /// save/delete methods before their database writes run.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2` database: JSON-encoded `PromptId` -> `PromptMetadata`.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2` database: JSON-encoded `PromptId` -> prompt body text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
95
/// Emitted by [`PromptStore`] after a save or delete has been written to the
/// database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
99
/// In-memory copy of all prompt metadata.
#[derive(Default)]
struct MetadataCache {
    /// All entries, kept ordered by `MetadataCache::sort`.
    metadata: Vec<PromptMetadata>,
    /// The same entries indexed by prompt id for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
105
106impl MetadataCache {
107 fn from_db(
108 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
109 txn: &RoTxn,
110 ) -> Result<Self> {
111 let mut cache = MetadataCache::default();
112 for result in db.iter(txn)? {
113 let (prompt_id, metadata) = result?;
114 cache.metadata.push(metadata.clone());
115 cache.metadata_by_id.insert(prompt_id, metadata);
116 }
117 cache.sort();
118 Ok(cache)
119 }
120
121 fn insert(&mut self, metadata: PromptMetadata) {
122 self.metadata_by_id.insert(metadata.id, metadata.clone());
123 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
124 *old_metadata = metadata;
125 } else {
126 self.metadata.push(metadata);
127 }
128 self.sort();
129 }
130
131 fn remove(&mut self, id: PromptId) {
132 self.metadata.retain(|metadata| metadata.id != id);
133 self.metadata_by_id.remove(&id);
134 }
135
136 fn sort(&mut self) {
137 self.metadata.sort_unstable_by(|a, b| {
138 a.title
139 .cmp(&b.title)
140 .then_with(|| b.saved_at.cmp(&a.saved_at))
141 });
142 }
143}
144
145impl PromptStore {
146 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
147 let store = GlobalPromptStore::global(cx).0.clone();
148 async move { store.await.map_err(|err| anyhow!(err)) }
149 }
150
151 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
152 cx.background_spawn(async move {
153 std::fs::create_dir_all(&db_path)?;
154
155 let db_env = unsafe {
156 heed::EnvOpenOptions::new()
157 .map_size(1024 * 1024 * 1024) // 1GB
158 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
159 .open(db_path)?
160 };
161
162 let mut txn = db_env.write_txn()?;
163 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
164 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
165
166 // Remove edit workflow prompt, as we decided to opt into it using
167 // a slash command instead.
168 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
169 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
170
171 txn.commit()?;
172
173 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
174
175 let txn = db_env.read_txn()?;
176 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
177 txn.commit()?;
178
179 Ok(PromptStore {
180 env: db_env,
181 metadata_cache: RwLock::new(metadata_cache),
182 metadata,
183 bodies,
184 })
185 })
186 }
187
188 fn upgrade_dbs(
189 env: &heed::Env,
190 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
191 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
192 ) -> Result<()> {
193 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
194 pub struct PromptIdV1(Uuid);
195
196 #[derive(Clone, Debug, Serialize, Deserialize)]
197 pub struct PromptMetadataV1 {
198 pub id: PromptIdV1,
199 pub title: Option<SharedString>,
200 pub default: bool,
201 pub saved_at: DateTime<Utc>,
202 }
203
204 let mut txn = env.write_txn()?;
205 let Some(bodies_v1_db) = env
206 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
207 &txn,
208 Some("bodies"),
209 )?
210 else {
211 return Ok(());
212 };
213 let mut bodies_v1 = bodies_v1_db
214 .iter(&txn)?
215 .collect::<heed::Result<HashMap<_, _>>>()?;
216
217 let Some(metadata_v1_db) = env
218 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
219 &txn,
220 Some("metadata"),
221 )?
222 else {
223 return Ok(());
224 };
225 let metadata_v1 = metadata_v1_db
226 .iter(&txn)?
227 .collect::<heed::Result<HashMap<_, _>>>()?;
228
229 for (prompt_id_v1, metadata_v1) in metadata_v1 {
230 let prompt_id_v2 = PromptId::User {
231 uuid: UserPromptId(prompt_id_v1.0),
232 };
233 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
234 continue;
235 };
236
237 if metadata_db
238 .get(&txn, &prompt_id_v2)?
239 .map_or(true, |metadata_v2| {
240 metadata_v1.saved_at > metadata_v2.saved_at
241 })
242 {
243 metadata_db.put(
244 &mut txn,
245 &prompt_id_v2,
246 &PromptMetadata {
247 id: prompt_id_v2,
248 title: metadata_v1.title.clone(),
249 default: metadata_v1.default,
250 saved_at: metadata_v1.saved_at,
251 },
252 )?;
253 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
254 }
255 }
256
257 txn.commit()?;
258
259 Ok(())
260 }
261
262 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
263 let env = self.env.clone();
264 let bodies = self.bodies;
265 cx.background_spawn(async move {
266 let txn = env.read_txn()?;
267 let mut prompt = bodies
268 .get(&txn, &id)?
269 .ok_or_else(|| anyhow!("prompt not found"))?
270 .into();
271 LineEnding::normalize(&mut prompt);
272 Ok(prompt)
273 })
274 }
275
276 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
277 self.metadata_cache.read().metadata.clone()
278 }
279
280 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
281 return self
282 .metadata_cache
283 .read()
284 .metadata
285 .iter()
286 .filter(|metadata| metadata.default)
287 .cloned()
288 .collect::<Vec<_>>();
289 }
290
291 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
292 self.metadata_cache.write().remove(id);
293
294 let db_connection = self.env.clone();
295 let bodies = self.bodies;
296 let metadata = self.metadata;
297
298 let task = cx.background_spawn(async move {
299 let mut txn = db_connection.write_txn()?;
300
301 metadata.delete(&mut txn, &id)?;
302 bodies.delete(&mut txn, &id)?;
303
304 txn.commit()?;
305 anyhow::Ok(())
306 });
307
308 cx.spawn(async move |this, cx| {
309 task.await?;
310 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
311 anyhow::Ok(())
312 })
313 }
314
315 /// Returns the number of prompts in the store.
316 pub fn prompt_count(&self) -> usize {
317 self.metadata_cache.read().metadata.len()
318 }
319
320 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
321 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
322 }
323
324 pub fn first(&self) -> Option<PromptMetadata> {
325 self.metadata_cache.read().metadata.first().cloned()
326 }
327
328 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
329 let metadata_cache = self.metadata_cache.read();
330 let metadata = metadata_cache
331 .metadata
332 .iter()
333 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
334 Some(metadata.id)
335 }
336
337 pub fn search(
338 &self,
339 query: String,
340 cancellation_flag: Arc<AtomicBool>,
341 cx: &App,
342 ) -> Task<Vec<PromptMetadata>> {
343 let cached_metadata = self.metadata_cache.read().metadata.clone();
344 let executor = cx.background_executor().clone();
345 cx.background_spawn(async move {
346 let mut matches = if query.is_empty() {
347 cached_metadata
348 } else {
349 let candidates = cached_metadata
350 .iter()
351 .enumerate()
352 .filter_map(|(ix, metadata)| {
353 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
354 })
355 .collect::<Vec<_>>();
356 let matches = fuzzy::match_strings(
357 &candidates,
358 &query,
359 false,
360 100,
361 &cancellation_flag,
362 executor,
363 )
364 .await;
365 matches
366 .into_iter()
367 .map(|mat| cached_metadata[mat.candidate_id].clone())
368 .collect()
369 };
370 matches.sort_by_key(|metadata| Reverse(metadata.default));
371 matches
372 })
373 }
374
375 pub fn save(
376 &self,
377 id: PromptId,
378 title: Option<SharedString>,
379 default: bool,
380 body: Rope,
381 cx: &Context<Self>,
382 ) -> Task<Result<()>> {
383 if id.is_built_in() {
384 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
385 }
386
387 let prompt_metadata = PromptMetadata {
388 id,
389 title,
390 default,
391 saved_at: Utc::now(),
392 };
393 self.metadata_cache.write().insert(prompt_metadata.clone());
394
395 let db_connection = self.env.clone();
396 let bodies = self.bodies;
397 let metadata = self.metadata;
398
399 let task = cx.background_spawn(async move {
400 let mut txn = db_connection.write_txn()?;
401
402 metadata.put(&mut txn, &id, &prompt_metadata)?;
403 bodies.put(&mut txn, &id, &body.to_string())?;
404
405 txn.commit()?;
406
407 anyhow::Ok(())
408 });
409
410 cx.spawn(async move |this, cx| {
411 task.await?;
412 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
413 anyhow::Ok(())
414 })
415 }
416
417 pub fn save_metadata(
418 &self,
419 id: PromptId,
420 mut title: Option<SharedString>,
421 default: bool,
422 cx: &Context<Self>,
423 ) -> Task<Result<()>> {
424 let mut cache = self.metadata_cache.write();
425
426 if id.is_built_in() {
427 title = cache
428 .metadata_by_id
429 .get(&id)
430 .and_then(|metadata| metadata.title.clone());
431 }
432
433 let prompt_metadata = PromptMetadata {
434 id,
435 title,
436 default,
437 saved_at: Utc::now(),
438 };
439
440 cache.insert(prompt_metadata.clone());
441
442 let db_connection = self.env.clone();
443 let metadata = self.metadata;
444
445 let task = cx.background_spawn(async move {
446 let mut txn = db_connection.write_txn()?;
447 metadata.put(&mut txn, &id, &prompt_metadata)?;
448 txn.commit()?;
449
450 anyhow::Ok(())
451 });
452
453 cx.spawn(async move |this, cx| {
454 task.await?;
455 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
456 anyhow::Ok(())
457 })
458 }
459}
460
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// Set by [`init`] and read by [`PromptStore::global`]. The error is stored as
/// an `Arc<anyhow::Error>` so the `Shared` future's output can be cloned for
/// each awaiting caller.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}