1mod prompts;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
/// Summary information about a stored prompt.
///
/// Persisted as JSON in the `metadata.v2` database and mirrored in the store's in-memory
/// metadata cache. Field order determines the serialized layout.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Identifier of the prompt this metadata describes.
    pub id: PromptId,
    /// Display title; `None` for untitled prompts (these are skipped by fuzzy title search).
    pub title: Option<SharedString>,
    /// Whether the prompt is marked as a default prompt (defaults sort first in search).
    pub default: bool,
    /// Time of the most recent save, in UTC.
    pub saved_at: DateTime<Utc>,
}
53
/// Identifies a prompt in the store.
///
/// Serialized as an internally tagged enum: the variant name is written to a `"kind"`
/// field alongside the variant's own fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A prompt authored by the user, keyed by a UUID.
    User { uuid: UserPromptId },
    /// Former built-in edit-workflow prompt; `PromptStore::new` deletes any stored entry
    /// for it on startup.
    EditWorkflow,
}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 UserPromptId::new().into()
64 }
65
66 pub fn is_built_in(&self) -> bool {
67 !matches!(self, PromptId::User { .. })
68 }
69}
70
71impl From<UserPromptId> for PromptId {
72 fn from(uuid: UserPromptId) -> Self {
73 PromptId::User { uuid }
74 }
75}
76
/// UUID identifying a user-authored prompt.
///
/// `#[serde(transparent)]` makes it serialize exactly like the wrapped `Uuid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
80
81impl UserPromptId {
82 pub fn new() -> UserPromptId {
83 UserPromptId(Uuid::new_v4())
84 }
85}
86
87impl From<Uuid> for UserPromptId {
88 fn from(uuid: Uuid) -> Self {
89 UserPromptId(uuid)
90 }
91}
92
/// Persistent prompt library backed by a `heed` environment, with an in-memory,
/// sorted cache of every prompt's metadata.
pub struct PromptStore {
    /// Database environment; cloned into background tasks that open transactions.
    env: heed::Env,
    /// In-memory copy of the metadata database, updated eagerly on save/delete.
    metadata_cache: RwLock<MetadataCache>,
    /// Prompt metadata keyed by id (`metadata.v2`).
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// Prompt bodies as plain strings keyed by id (`bodies.v2`).
    bodies: Database<SerdeJson<PromptId>, Str>,
}
99
/// Emitted by [`PromptStore`] after a save, metadata update, or delete has been
/// committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
103
/// In-memory mirror of the metadata database.
///
/// `metadata` is kept sorted by title, then newest-first, for listing. `metadata_by_id`
/// provides lookup by id. Both collections hold the same entries.
#[derive(Default)]
struct MetadataCache {
    metadata: Vec<PromptMetadata>,
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
109
110impl MetadataCache {
111 fn from_db(
112 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
113 txn: &RoTxn,
114 ) -> Result<Self> {
115 let mut cache = MetadataCache::default();
116 for result in db.iter(txn)? {
117 let (prompt_id, metadata) = result?;
118 cache.metadata.push(metadata.clone());
119 cache.metadata_by_id.insert(prompt_id, metadata);
120 }
121 cache.sort();
122 Ok(cache)
123 }
124
125 fn insert(&mut self, metadata: PromptMetadata) {
126 self.metadata_by_id.insert(metadata.id, metadata.clone());
127 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
128 *old_metadata = metadata;
129 } else {
130 self.metadata.push(metadata);
131 }
132 self.sort();
133 }
134
135 fn remove(&mut self, id: PromptId) {
136 self.metadata.retain(|metadata| metadata.id != id);
137 self.metadata_by_id.remove(&id);
138 }
139
140 fn sort(&mut self) {
141 self.metadata.sort_unstable_by(|a, b| {
142 a.title
143 .cmp(&b.title)
144 .then_with(|| b.saved_at.cmp(&a.saved_at))
145 });
146 }
147}
148
149impl PromptStore {
150 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
151 let store = GlobalPromptStore::global(cx).0.clone();
152 async move { store.await.map_err(|err| anyhow!(err)) }
153 }
154
155 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
156 cx.background_spawn(async move {
157 std::fs::create_dir_all(&db_path)?;
158
159 let db_env = unsafe {
160 heed::EnvOpenOptions::new()
161 .map_size(1024 * 1024 * 1024) // 1GB
162 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
163 .open(db_path)?
164 };
165
166 let mut txn = db_env.write_txn()?;
167 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
168 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
169
170 // Remove edit workflow prompt, as we decided to opt into it using
171 // a slash command instead.
172 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
173 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
174
175 txn.commit()?;
176
177 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
178
179 let txn = db_env.read_txn()?;
180 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
181 txn.commit()?;
182
183 Ok(PromptStore {
184 env: db_env,
185 metadata_cache: RwLock::new(metadata_cache),
186 metadata,
187 bodies,
188 })
189 })
190 }
191
192 fn upgrade_dbs(
193 env: &heed::Env,
194 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
195 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
196 ) -> Result<()> {
197 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
198 pub struct PromptIdV1(Uuid);
199
200 #[derive(Clone, Debug, Serialize, Deserialize)]
201 pub struct PromptMetadataV1 {
202 pub id: PromptIdV1,
203 pub title: Option<SharedString>,
204 pub default: bool,
205 pub saved_at: DateTime<Utc>,
206 }
207
208 let mut txn = env.write_txn()?;
209 let Some(bodies_v1_db) = env
210 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
211 &txn,
212 Some("bodies"),
213 )?
214 else {
215 return Ok(());
216 };
217 let mut bodies_v1 = bodies_v1_db
218 .iter(&txn)?
219 .collect::<heed::Result<HashMap<_, _>>>()?;
220
221 let Some(metadata_v1_db) = env
222 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
223 &txn,
224 Some("metadata"),
225 )?
226 else {
227 return Ok(());
228 };
229 let metadata_v1 = metadata_v1_db
230 .iter(&txn)?
231 .collect::<heed::Result<HashMap<_, _>>>()?;
232
233 for (prompt_id_v1, metadata_v1) in metadata_v1 {
234 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
235 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
236 continue;
237 };
238
239 if metadata_db
240 .get(&txn, &prompt_id_v2)?
241 .map_or(true, |metadata_v2| {
242 metadata_v1.saved_at > metadata_v2.saved_at
243 })
244 {
245 metadata_db.put(
246 &mut txn,
247 &prompt_id_v2,
248 &PromptMetadata {
249 id: prompt_id_v2,
250 title: metadata_v1.title.clone(),
251 default: metadata_v1.default,
252 saved_at: metadata_v1.saved_at,
253 },
254 )?;
255 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
256 }
257 }
258
259 txn.commit()?;
260
261 Ok(())
262 }
263
264 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
265 let env = self.env.clone();
266 let bodies = self.bodies;
267 cx.background_spawn(async move {
268 let txn = env.read_txn()?;
269 let mut prompt = bodies
270 .get(&txn, &id)?
271 .ok_or_else(|| anyhow!("prompt not found"))?
272 .into();
273 LineEnding::normalize(&mut prompt);
274 Ok(prompt)
275 })
276 }
277
278 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
279 self.metadata_cache.read().metadata.clone()
280 }
281
282 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
283 return self
284 .metadata_cache
285 .read()
286 .metadata
287 .iter()
288 .filter(|metadata| metadata.default)
289 .cloned()
290 .collect::<Vec<_>>();
291 }
292
293 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
294 self.metadata_cache.write().remove(id);
295
296 let db_connection = self.env.clone();
297 let bodies = self.bodies;
298 let metadata = self.metadata;
299
300 let task = cx.background_spawn(async move {
301 let mut txn = db_connection.write_txn()?;
302
303 metadata.delete(&mut txn, &id)?;
304 bodies.delete(&mut txn, &id)?;
305
306 txn.commit()?;
307 anyhow::Ok(())
308 });
309
310 cx.spawn(async move |this, cx| {
311 task.await?;
312 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
313 anyhow::Ok(())
314 })
315 }
316
317 /// Returns the number of prompts in the store.
318 pub fn prompt_count(&self) -> usize {
319 self.metadata_cache.read().metadata.len()
320 }
321
322 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
323 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
324 }
325
326 pub fn first(&self) -> Option<PromptMetadata> {
327 self.metadata_cache.read().metadata.first().cloned()
328 }
329
330 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
331 let metadata_cache = self.metadata_cache.read();
332 let metadata = metadata_cache
333 .metadata
334 .iter()
335 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
336 Some(metadata.id)
337 }
338
339 pub fn search(
340 &self,
341 query: String,
342 cancellation_flag: Arc<AtomicBool>,
343 cx: &App,
344 ) -> Task<Vec<PromptMetadata>> {
345 let cached_metadata = self.metadata_cache.read().metadata.clone();
346 let executor = cx.background_executor().clone();
347 cx.background_spawn(async move {
348 let mut matches = if query.is_empty() {
349 cached_metadata
350 } else {
351 let candidates = cached_metadata
352 .iter()
353 .enumerate()
354 .filter_map(|(ix, metadata)| {
355 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
356 })
357 .collect::<Vec<_>>();
358 let matches = fuzzy::match_strings(
359 &candidates,
360 &query,
361 false,
362 100,
363 &cancellation_flag,
364 executor,
365 )
366 .await;
367 matches
368 .into_iter()
369 .map(|mat| cached_metadata[mat.candidate_id].clone())
370 .collect()
371 };
372 matches.sort_by_key(|metadata| Reverse(metadata.default));
373 matches
374 })
375 }
376
377 pub fn save(
378 &self,
379 id: PromptId,
380 title: Option<SharedString>,
381 default: bool,
382 body: Rope,
383 cx: &Context<Self>,
384 ) -> Task<Result<()>> {
385 if id.is_built_in() {
386 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
387 }
388
389 let prompt_metadata = PromptMetadata {
390 id,
391 title,
392 default,
393 saved_at: Utc::now(),
394 };
395 self.metadata_cache.write().insert(prompt_metadata.clone());
396
397 let db_connection = self.env.clone();
398 let bodies = self.bodies;
399 let metadata = self.metadata;
400
401 let task = cx.background_spawn(async move {
402 let mut txn = db_connection.write_txn()?;
403
404 metadata.put(&mut txn, &id, &prompt_metadata)?;
405 bodies.put(&mut txn, &id, &body.to_string())?;
406
407 txn.commit()?;
408
409 anyhow::Ok(())
410 });
411
412 cx.spawn(async move |this, cx| {
413 task.await?;
414 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
415 anyhow::Ok(())
416 })
417 }
418
419 pub fn save_metadata(
420 &self,
421 id: PromptId,
422 mut title: Option<SharedString>,
423 default: bool,
424 cx: &Context<Self>,
425 ) -> Task<Result<()>> {
426 let mut cache = self.metadata_cache.write();
427
428 if id.is_built_in() {
429 title = cache
430 .metadata_by_id
431 .get(&id)
432 .and_then(|metadata| metadata.title.clone());
433 }
434
435 let prompt_metadata = PromptMetadata {
436 id,
437 title,
438 default,
439 saved_at: Utc::now(),
440 };
441
442 cache.insert(prompt_metadata.clone());
443
444 let db_connection = self.env.clone();
445 let metadata = self.metadata;
446
447 let task = cx.background_spawn(async move {
448 let mut txn = db_connection.write_txn()?;
449 metadata.put(&mut txn, &id, &prompt_metadata)?;
450 txn.commit()?;
451
452 anyhow::Ok(())
453 });
454
455 cx.spawn(async move |this, cx| {
456 task.await?;
457 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
458 anyhow::Ok(())
459 })
460 }
461}
462
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
///
/// The error is an `Arc<anyhow::Error>` so that the `Shared` future's output can be
/// cloned for every awaiter.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}