1mod prompts;
2
3use anyhow::{Context as _, Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
46#[derive(Clone, Debug, Serialize, Deserialize)]
47pub struct PromptMetadata {
48 pub id: PromptId,
49 pub title: Option<SharedString>,
50 pub default: bool,
51 pub saved_at: DateTime<Utc>,
52}
53
54#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
55#[serde(tag = "kind")]
56pub enum PromptId {
57 User { uuid: UserPromptId },
58 EditWorkflow,
59}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 UserPromptId::new().into()
64 }
65
66 pub fn is_built_in(&self) -> bool {
67 !matches!(self, PromptId::User { .. })
68 }
69}
70
71impl From<UserPromptId> for PromptId {
72 fn from(uuid: UserPromptId) -> Self {
73 PromptId::User { uuid }
74 }
75}
76
77#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
78#[serde(transparent)]
79pub struct UserPromptId(pub Uuid);
80
81impl UserPromptId {
82 pub fn new() -> UserPromptId {
83 UserPromptId(Uuid::new_v4())
84 }
85}
86
87impl From<Uuid> for UserPromptId {
88 fn from(uuid: Uuid) -> Self {
89 UserPromptId(uuid)
90 }
91}
92
93pub struct PromptStore {
94 env: heed::Env,
95 metadata_cache: RwLock<MetadataCache>,
96 metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
97 bodies: Database<SerdeJson<PromptId>, Str>,
98}
99
100pub struct PromptsUpdatedEvent;
101
102impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
103
104#[derive(Default)]
105struct MetadataCache {
106 metadata: Vec<PromptMetadata>,
107 metadata_by_id: HashMap<PromptId, PromptMetadata>,
108}
109
110impl MetadataCache {
111 fn from_db(
112 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
113 txn: &RoTxn,
114 ) -> Result<Self> {
115 let mut cache = MetadataCache::default();
116 for result in db.iter(txn)? {
117 let (prompt_id, metadata) = result?;
118 cache.metadata.push(metadata.clone());
119 cache.metadata_by_id.insert(prompt_id, metadata);
120 }
121 cache.sort();
122 Ok(cache)
123 }
124
125 fn insert(&mut self, metadata: PromptMetadata) {
126 self.metadata_by_id.insert(metadata.id, metadata.clone());
127 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
128 *old_metadata = metadata;
129 } else {
130 self.metadata.push(metadata);
131 }
132 self.sort();
133 }
134
135 fn remove(&mut self, id: PromptId) {
136 self.metadata.retain(|metadata| metadata.id != id);
137 self.metadata_by_id.remove(&id);
138 }
139
140 fn sort(&mut self) {
141 self.metadata.sort_unstable_by(|a, b| {
142 a.title
143 .cmp(&b.title)
144 .then_with(|| b.saved_at.cmp(&a.saved_at))
145 });
146 }
147}
148
149impl PromptStore {
150 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
151 let store = GlobalPromptStore::global(cx).0.clone();
152 async move { store.await.map_err(|err| anyhow!(err)) }
153 }
154
155 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
156 cx.background_spawn(async move {
157 std::fs::create_dir_all(&db_path)?;
158
159 let db_env = unsafe {
160 heed::EnvOpenOptions::new()
161 .map_size(1024 * 1024 * 1024) // 1GB
162 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
163 .open(db_path)?
164 };
165
166 let mut txn = db_env.write_txn()?;
167 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
168 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
169
170 // Remove edit workflow prompt, as we decided to opt into it using
171 // a slash command instead.
172 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
173 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
174
175 txn.commit()?;
176
177 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
178
179 let txn = db_env.read_txn()?;
180 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
181 txn.commit()?;
182
183 Ok(PromptStore {
184 env: db_env,
185 metadata_cache: RwLock::new(metadata_cache),
186 metadata,
187 bodies,
188 })
189 })
190 }
191
192 fn upgrade_dbs(
193 env: &heed::Env,
194 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
195 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
196 ) -> Result<()> {
197 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
198 pub struct PromptIdV1(Uuid);
199
200 #[derive(Clone, Debug, Serialize, Deserialize)]
201 pub struct PromptMetadataV1 {
202 pub id: PromptIdV1,
203 pub title: Option<SharedString>,
204 pub default: bool,
205 pub saved_at: DateTime<Utc>,
206 }
207
208 let mut txn = env.write_txn()?;
209 let Some(bodies_v1_db) = env
210 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
211 &txn,
212 Some("bodies"),
213 )?
214 else {
215 return Ok(());
216 };
217 let mut bodies_v1 = bodies_v1_db
218 .iter(&txn)?
219 .collect::<heed::Result<HashMap<_, _>>>()?;
220
221 let Some(metadata_v1_db) = env
222 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
223 &txn,
224 Some("metadata"),
225 )?
226 else {
227 return Ok(());
228 };
229 let metadata_v1 = metadata_v1_db
230 .iter(&txn)?
231 .collect::<heed::Result<HashMap<_, _>>>()?;
232
233 for (prompt_id_v1, metadata_v1) in metadata_v1 {
234 let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
235 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
236 continue;
237 };
238
239 if metadata_db
240 .get(&txn, &prompt_id_v2)?
241 .map_or(true, |metadata_v2| {
242 metadata_v1.saved_at > metadata_v2.saved_at
243 })
244 {
245 metadata_db.put(
246 &mut txn,
247 &prompt_id_v2,
248 &PromptMetadata {
249 id: prompt_id_v2,
250 title: metadata_v1.title.clone(),
251 default: metadata_v1.default,
252 saved_at: metadata_v1.saved_at,
253 },
254 )?;
255 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
256 }
257 }
258
259 txn.commit()?;
260
261 Ok(())
262 }
263
264 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
265 let env = self.env.clone();
266 let bodies = self.bodies;
267 cx.background_spawn(async move {
268 let txn = env.read_txn()?;
269 let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into();
270 LineEnding::normalize(&mut prompt);
271 Ok(prompt)
272 })
273 }
274
275 pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
276 self.metadata_cache.read().metadata.clone()
277 }
278
279 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
280 return self
281 .metadata_cache
282 .read()
283 .metadata
284 .iter()
285 .filter(|metadata| metadata.default)
286 .cloned()
287 .collect::<Vec<_>>();
288 }
289
290 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
291 self.metadata_cache.write().remove(id);
292
293 let db_connection = self.env.clone();
294 let bodies = self.bodies;
295 let metadata = self.metadata;
296
297 let task = cx.background_spawn(async move {
298 let mut txn = db_connection.write_txn()?;
299
300 metadata.delete(&mut txn, &id)?;
301 bodies.delete(&mut txn, &id)?;
302
303 txn.commit()?;
304 anyhow::Ok(())
305 });
306
307 cx.spawn(async move |this, cx| {
308 task.await?;
309 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
310 anyhow::Ok(())
311 })
312 }
313
314 /// Returns the number of prompts in the store.
315 pub fn prompt_count(&self) -> usize {
316 self.metadata_cache.read().metadata.len()
317 }
318
319 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
320 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
321 }
322
323 pub fn first(&self) -> Option<PromptMetadata> {
324 self.metadata_cache.read().metadata.first().cloned()
325 }
326
327 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
328 let metadata_cache = self.metadata_cache.read();
329 let metadata = metadata_cache
330 .metadata
331 .iter()
332 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
333 Some(metadata.id)
334 }
335
336 pub fn search(
337 &self,
338 query: String,
339 cancellation_flag: Arc<AtomicBool>,
340 cx: &App,
341 ) -> Task<Vec<PromptMetadata>> {
342 let cached_metadata = self.metadata_cache.read().metadata.clone();
343 let executor = cx.background_executor().clone();
344 cx.background_spawn(async move {
345 let mut matches = if query.is_empty() {
346 cached_metadata
347 } else {
348 let candidates = cached_metadata
349 .iter()
350 .enumerate()
351 .filter_map(|(ix, metadata)| {
352 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
353 })
354 .collect::<Vec<_>>();
355 let matches = fuzzy::match_strings(
356 &candidates,
357 &query,
358 false,
359 true,
360 100,
361 &cancellation_flag,
362 executor,
363 )
364 .await;
365 matches
366 .into_iter()
367 .map(|mat| cached_metadata[mat.candidate_id].clone())
368 .collect()
369 };
370 matches.sort_by_key(|metadata| Reverse(metadata.default));
371 matches
372 })
373 }
374
375 pub fn save(
376 &self,
377 id: PromptId,
378 title: Option<SharedString>,
379 default: bool,
380 body: Rope,
381 cx: &Context<Self>,
382 ) -> Task<Result<()>> {
383 if id.is_built_in() {
384 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
385 }
386
387 let prompt_metadata = PromptMetadata {
388 id,
389 title,
390 default,
391 saved_at: Utc::now(),
392 };
393 self.metadata_cache.write().insert(prompt_metadata.clone());
394
395 let db_connection = self.env.clone();
396 let bodies = self.bodies;
397 let metadata = self.metadata;
398
399 let task = cx.background_spawn(async move {
400 let mut txn = db_connection.write_txn()?;
401
402 metadata.put(&mut txn, &id, &prompt_metadata)?;
403 bodies.put(&mut txn, &id, &body.to_string())?;
404
405 txn.commit()?;
406
407 anyhow::Ok(())
408 });
409
410 cx.spawn(async move |this, cx| {
411 task.await?;
412 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
413 anyhow::Ok(())
414 })
415 }
416
417 pub fn save_metadata(
418 &self,
419 id: PromptId,
420 mut title: Option<SharedString>,
421 default: bool,
422 cx: &Context<Self>,
423 ) -> Task<Result<()>> {
424 let mut cache = self.metadata_cache.write();
425
426 if id.is_built_in() {
427 title = cache
428 .metadata_by_id
429 .get(&id)
430 .and_then(|metadata| metadata.title.clone());
431 }
432
433 let prompt_metadata = PromptMetadata {
434 id,
435 title,
436 default,
437 saved_at: Utc::now(),
438 };
439
440 cache.insert(prompt_metadata.clone());
441
442 let db_connection = self.env.clone();
443 let metadata = self.metadata;
444
445 let task = cx.background_spawn(async move {
446 let mut txn = db_connection.write_txn()?;
447 metadata.put(&mut txn, &id, &prompt_metadata)?;
448 txn.commit()?;
449
450 anyhow::Ok(())
451 });
452
453 cx.spawn(async move |this, cx| {
454 task.await?;
455 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
456 anyhow::Ok(())
457 })
458 }
459}
460
461/// Wraps a shared future to a prompt store so it can be assigned as a context global.
462pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
463
464impl Global for GlobalPromptStore {}