1mod prompts;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use collections::HashMap;
6use futures::FutureExt as _;
7use futures::future::Shared;
8use fuzzy::StringMatchCandidate;
9use gpui::{
10 App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
11};
12use heed::{
13 Database, RoTxn,
14 types::{SerdeBincode, SerdeJson, Str},
15};
16use parking_lot::RwLock;
17pub use prompts::*;
18use rope::Rope;
19use serde::{Deserialize, Serialize};
20use std::{
21 cmp::Reverse,
22 future::Future,
23 path::PathBuf,
24 sync::{Arc, atomic::AtomicBool},
25};
26use text::LineEnding;
27use util::ResultExt;
28use uuid::Uuid;
29
30/// Init starts loading the PromptStore in the background and assigns
31/// a shared future to a global.
32pub fn init(cx: &mut App) {
33 let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
34 let prompt_store_task = PromptStore::new(db_path, cx);
35 let prompt_store_entity_task = cx
36 .spawn(async move |cx| {
37 prompt_store_task
38 .await
39 .and_then(|prompt_store| cx.new(|_cx| prompt_store))
40 .map_err(Arc::new)
41 })
42 .shared();
43 cx.set_global(GlobalPromptStore(prompt_store_entity_task))
44}
45
/// Descriptive information about a stored prompt.
///
/// Persisted as JSON in the `metadata.v2` database and mirrored in the
/// [`PromptStore`]'s in-memory metadata cache.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PromptMetadata {
    /// Key under which this prompt's metadata and body are stored.
    pub id: PromptId,
    /// Display title; `None` for untitled prompts, which `PromptStore::search`
    /// cannot fuzzy-match.
    pub title: Option<SharedString>,
    /// Whether the prompt is flagged as a default prompt (see
    /// `PromptStore::default_prompt_metadata`).
    pub default: bool,
    /// UTC time of the most recent save; among equal titles, newer entries sort first.
    pub saved_at: DateTime<Utc>,
}
53
/// Identifies a prompt in the store.
///
/// Serialized with an internal `kind` tag naming the variant, e.g.
/// `{"kind":"EditWorkflow"}`. This JSON form is also the database key, so
/// renaming variants or fields would orphan previously stored prompts.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
    /// A user-authored prompt, identified by a UUID.
    User { uuid: Uuid },
    /// Retired edit-workflow prompt; `PromptStore::new` deletes any stored copy.
    EditWorkflow,
}
60
61impl PromptId {
62 pub fn new() -> PromptId {
63 PromptId::User {
64 uuid: Uuid::new_v4(),
65 }
66 }
67
68 pub fn is_built_in(&self) -> bool {
69 !matches!(self, PromptId::User { .. })
70 }
71}
72
/// Persistent prompt library backed by an LMDB environment (via `heed`), with
/// an in-memory cache of every prompt's metadata.
pub struct PromptStore {
    /// The open database environment; cloned into background write/read tasks.
    env: heed::Env,
    /// Sorted copy of all metadata. Mutating methods update it before their
    /// background database write runs.
    metadata_cache: RwLock<MetadataCache>,
    /// `metadata.v2`: prompt id -> metadata, both JSON-encoded.
    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
    /// `bodies.v2`: prompt id (JSON-encoded) -> prompt text.
    bodies: Database<SerdeJson<PromptId>, Str>,
}
79
/// Emitted by [`PromptStore`] after a save, metadata save, or delete has been
/// committed to the database.
pub struct PromptsUpdatedEvent;

impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
83
/// In-memory view of all prompt metadata, kept both as an ordered list and as
/// an id-keyed map.
#[derive(Default)]
struct MetadataCache {
    /// All entries, ordered by title (ascending, untitled first) then by
    /// `saved_at` (newest first).
    metadata: Vec<PromptMetadata>,
    /// The same entries keyed by prompt id for direct lookup.
    metadata_by_id: HashMap<PromptId, PromptMetadata>,
}
89
90impl MetadataCache {
91 fn from_db(
92 db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
93 txn: &RoTxn,
94 ) -> Result<Self> {
95 let mut cache = MetadataCache::default();
96 for result in db.iter(txn)? {
97 let (prompt_id, metadata) = result?;
98 cache.metadata.push(metadata.clone());
99 cache.metadata_by_id.insert(prompt_id, metadata);
100 }
101 cache.sort();
102 Ok(cache)
103 }
104
105 fn insert(&mut self, metadata: PromptMetadata) {
106 self.metadata_by_id.insert(metadata.id, metadata.clone());
107 if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) {
108 *old_metadata = metadata;
109 } else {
110 self.metadata.push(metadata);
111 }
112 self.sort();
113 }
114
115 fn remove(&mut self, id: PromptId) {
116 self.metadata.retain(|metadata| metadata.id != id);
117 self.metadata_by_id.remove(&id);
118 }
119
120 fn sort(&mut self) {
121 self.metadata.sort_unstable_by(|a, b| {
122 a.title
123 .cmp(&b.title)
124 .then_with(|| b.saved_at.cmp(&a.saved_at))
125 });
126 }
127}
128
129impl PromptStore {
130 pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
131 let store = GlobalPromptStore::global(cx).0.clone();
132 async move { store.await.map_err(|err| anyhow!(err)) }
133 }
134
135 pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
136 cx.background_spawn(async move {
137 std::fs::create_dir_all(&db_path)?;
138
139 let db_env = unsafe {
140 heed::EnvOpenOptions::new()
141 .map_size(1024 * 1024 * 1024) // 1GB
142 .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
143 .open(db_path)?
144 };
145
146 let mut txn = db_env.write_txn()?;
147 let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
148 let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
149
150 // Remove edit workflow prompt, as we decided to opt into it using
151 // a slash command instead.
152 metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
153 bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
154
155 txn.commit()?;
156
157 Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
158
159 let txn = db_env.read_txn()?;
160 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
161 txn.commit()?;
162
163 Ok(PromptStore {
164 env: db_env,
165 metadata_cache: RwLock::new(metadata_cache),
166 metadata,
167 bodies,
168 })
169 })
170 }
171
172 fn upgrade_dbs(
173 env: &heed::Env,
174 metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
175 bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
176 ) -> Result<()> {
177 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
178 pub struct PromptIdV1(Uuid);
179
180 #[derive(Clone, Debug, Serialize, Deserialize)]
181 pub struct PromptMetadataV1 {
182 pub id: PromptIdV1,
183 pub title: Option<SharedString>,
184 pub default: bool,
185 pub saved_at: DateTime<Utc>,
186 }
187
188 let mut txn = env.write_txn()?;
189 let Some(bodies_v1_db) = env
190 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
191 &txn,
192 Some("bodies"),
193 )?
194 else {
195 return Ok(());
196 };
197 let mut bodies_v1 = bodies_v1_db
198 .iter(&txn)?
199 .collect::<heed::Result<HashMap<_, _>>>()?;
200
201 let Some(metadata_v1_db) = env
202 .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
203 &txn,
204 Some("metadata"),
205 )?
206 else {
207 return Ok(());
208 };
209 let metadata_v1 = metadata_v1_db
210 .iter(&txn)?
211 .collect::<heed::Result<HashMap<_, _>>>()?;
212
213 for (prompt_id_v1, metadata_v1) in metadata_v1 {
214 let prompt_id_v2 = PromptId::User {
215 uuid: prompt_id_v1.0,
216 };
217 let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
218 continue;
219 };
220
221 if metadata_db
222 .get(&txn, &prompt_id_v2)?
223 .map_or(true, |metadata_v2| {
224 metadata_v1.saved_at > metadata_v2.saved_at
225 })
226 {
227 metadata_db.put(
228 &mut txn,
229 &prompt_id_v2,
230 &PromptMetadata {
231 id: prompt_id_v2,
232 title: metadata_v1.title.clone(),
233 default: metadata_v1.default,
234 saved_at: metadata_v1.saved_at,
235 },
236 )?;
237 bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
238 }
239 }
240
241 txn.commit()?;
242
243 Ok(())
244 }
245
246 pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
247 let env = self.env.clone();
248 let bodies = self.bodies;
249 cx.background_spawn(async move {
250 let txn = env.read_txn()?;
251 let mut prompt = bodies
252 .get(&txn, &id)?
253 .ok_or_else(|| anyhow!("prompt not found"))?
254 .into();
255 LineEnding::normalize(&mut prompt);
256 Ok(prompt)
257 })
258 }
259
260 pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
261 return self
262 .metadata_cache
263 .read()
264 .metadata
265 .iter()
266 .filter(|metadata| metadata.default)
267 .cloned()
268 .collect::<Vec<_>>();
269 }
270
271 pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
272 self.metadata_cache.write().remove(id);
273
274 let db_connection = self.env.clone();
275 let bodies = self.bodies;
276 let metadata = self.metadata;
277
278 let task = cx.background_spawn(async move {
279 let mut txn = db_connection.write_txn()?;
280
281 metadata.delete(&mut txn, &id)?;
282 bodies.delete(&mut txn, &id)?;
283
284 txn.commit()?;
285 anyhow::Ok(())
286 });
287
288 cx.spawn(async move |this, cx| {
289 task.await?;
290 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
291 anyhow::Ok(())
292 })
293 }
294
295 /// Returns the number of prompts in the store.
296 pub fn prompt_count(&self) -> usize {
297 self.metadata_cache.read().metadata.len()
298 }
299
300 pub fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
301 self.metadata_cache.read().metadata_by_id.get(&id).cloned()
302 }
303
304 pub fn first(&self) -> Option<PromptMetadata> {
305 self.metadata_cache.read().metadata.first().cloned()
306 }
307
308 pub fn id_for_title(&self, title: &str) -> Option<PromptId> {
309 let metadata_cache = self.metadata_cache.read();
310 let metadata = metadata_cache
311 .metadata
312 .iter()
313 .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?;
314 Some(metadata.id)
315 }
316
317 pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> {
318 let cached_metadata = self.metadata_cache.read().metadata.clone();
319 let executor = cx.background_executor().clone();
320 cx.background_spawn(async move {
321 let mut matches = if query.is_empty() {
322 cached_metadata
323 } else {
324 let candidates = cached_metadata
325 .iter()
326 .enumerate()
327 .filter_map(|(ix, metadata)| {
328 Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?))
329 })
330 .collect::<Vec<_>>();
331 let matches = fuzzy::match_strings(
332 &candidates,
333 &query,
334 false,
335 100,
336 &AtomicBool::default(),
337 executor,
338 )
339 .await;
340 matches
341 .into_iter()
342 .map(|mat| cached_metadata[mat.candidate_id].clone())
343 .collect()
344 };
345 matches.sort_by_key(|metadata| Reverse(metadata.default));
346 matches
347 })
348 }
349
350 pub fn save(
351 &self,
352 id: PromptId,
353 title: Option<SharedString>,
354 default: bool,
355 body: Rope,
356 cx: &Context<Self>,
357 ) -> Task<Result<()>> {
358 if id.is_built_in() {
359 return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
360 }
361
362 let prompt_metadata = PromptMetadata {
363 id,
364 title,
365 default,
366 saved_at: Utc::now(),
367 };
368 self.metadata_cache.write().insert(prompt_metadata.clone());
369
370 let db_connection = self.env.clone();
371 let bodies = self.bodies;
372 let metadata = self.metadata;
373
374 let task = cx.background_spawn(async move {
375 let mut txn = db_connection.write_txn()?;
376
377 metadata.put(&mut txn, &id, &prompt_metadata)?;
378 bodies.put(&mut txn, &id, &body.to_string())?;
379
380 txn.commit()?;
381
382 anyhow::Ok(())
383 });
384
385 cx.spawn(async move |this, cx| {
386 task.await?;
387 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
388 anyhow::Ok(())
389 })
390 }
391
392 pub fn save_metadata(
393 &self,
394 id: PromptId,
395 mut title: Option<SharedString>,
396 default: bool,
397 cx: &Context<Self>,
398 ) -> Task<Result<()>> {
399 let mut cache = self.metadata_cache.write();
400
401 if id.is_built_in() {
402 title = cache
403 .metadata_by_id
404 .get(&id)
405 .and_then(|metadata| metadata.title.clone());
406 }
407
408 let prompt_metadata = PromptMetadata {
409 id,
410 title,
411 default,
412 saved_at: Utc::now(),
413 };
414
415 cache.insert(prompt_metadata.clone());
416
417 let db_connection = self.env.clone();
418 let metadata = self.metadata;
419
420 let task = cx.background_spawn(async move {
421 let mut txn = db_connection.write_txn()?;
422 metadata.put(&mut txn, &id, &prompt_metadata)?;
423 txn.commit()?;
424
425 anyhow::Ok(())
426 });
427
428 cx.spawn(async move |this, cx| {
429 task.await?;
430 this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
431 anyhow::Ok(())
432 })
433 }
434}
435
/// Global handle to the shared, possibly still-running task that loads the
/// prompt store, so the task can be registered as a gpui context global.
///
/// The task resolves to the store entity or to the load error. The error is
/// wrapped in an `Arc` because the `Shared` future gives a clone of its output
/// to every awaiter.
pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);

impl Global for GlobalPromptStore {}