prompt_library.rs

  1use fs::Fs;
  2use futures::StreamExt;
  3use gpui::{AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Render};
  4use parking_lot::RwLock;
  5use serde::{Deserialize, Serialize};
  6use std::collections::HashMap;
  7use std::sync::Arc;
  8use ui::{prelude::*, Checkbox, ModalHeader};
  9use util::{paths::PROMPTS_DIR, ResultExt};
 10use workspace::ModalView;
 11
 12pub struct PromptLibraryState {
 13    /// The default prompt all assistant contexts will start with
 14    _system_prompt: String,
 15    /// All [UserPrompt]s loaded into the library
 16    prompts: HashMap<String, UserPrompt>,
 17    /// Prompts included in the default prompt
 18    default_prompts: Vec<String>,
 19    /// Prompts that have a pending update that hasn't been applied yet
 20    _updateable_prompts: Vec<String>,
 21    /// Prompts that have been changed since they were loaded
 22    /// and can be reverted to their original state
 23    _revertable_prompts: Vec<String>,
 24    version: usize,
 25}
 26
 27pub struct PromptLibrary {
 28    state: RwLock<PromptLibraryState>,
 29}
 30
 31impl Default for PromptLibrary {
 32    fn default() -> Self {
 33        Self::new()
 34    }
 35}
 36
 37impl PromptLibrary {
 38    fn new() -> Self {
 39        Self {
 40            state: RwLock::new(PromptLibraryState {
 41                _system_prompt: String::new(),
 42                prompts: HashMap::new(),
 43                default_prompts: Vec::new(),
 44                _updateable_prompts: Vec::new(),
 45                _revertable_prompts: Vec::new(),
 46                version: 0,
 47            }),
 48        }
 49    }
 50
 51    pub async fn init(fs: Arc<dyn Fs>) -> anyhow::Result<Self> {
 52        let prompt_library = PromptLibrary::new();
 53        prompt_library.load_prompts(fs)?;
 54        Ok(prompt_library)
 55    }
 56
 57    fn load_prompts(&self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
 58        let prompts = futures::executor::block_on(UserPrompt::list(fs))?;
 59        let prompts_with_ids = prompts
 60            .clone()
 61            .into_iter()
 62            .map(|prompt| {
 63                let id = uuid::Uuid::new_v4().to_string();
 64                (id, prompt)
 65            })
 66            .collect::<Vec<_>>();
 67        let mut state = self.state.write();
 68        state.prompts.extend(prompts_with_ids);
 69        state.version += 1;
 70
 71        Ok(())
 72    }
 73
 74    pub fn default_prompt(&self) -> Option<String> {
 75        let state = self.state.read();
 76
 77        if state.default_prompts.is_empty() {
 78            None
 79        } else {
 80            Some(self.join_default_prompts())
 81        }
 82    }
 83
 84    pub fn add_prompt_to_default(&self, prompt_id: String) -> anyhow::Result<()> {
 85        let mut state = self.state.write();
 86
 87        if !state.default_prompts.contains(&prompt_id) && state.prompts.contains_key(&prompt_id) {
 88            state.default_prompts.push(prompt_id);
 89            state.version += 1;
 90        }
 91
 92        Ok(())
 93    }
 94
 95    pub fn remove_prompt_from_default(&self, prompt_id: String) -> anyhow::Result<()> {
 96        let mut state = self.state.write();
 97
 98        state.default_prompts.retain(|id| id != &prompt_id);
 99        state.version += 1;
100        Ok(())
101    }
102
103    fn join_default_prompts(&self) -> String {
104        let state = self.state.read();
105        let active_prompt_ids = state.default_prompts.to_vec();
106
107        active_prompt_ids
108            .iter()
109            .filter_map(|id| state.prompts.get(id).map(|p| p.prompt.clone()))
110            .collect::<Vec<_>>()
111            .join("\n\n---\n\n")
112    }
113
114    #[allow(unused)]
115    pub fn prompts(&self) -> Vec<UserPrompt> {
116        let state = self.state.read();
117        state.prompts.values().cloned().collect()
118    }
119
120    pub fn prompts_with_ids(&self) -> Vec<(String, UserPrompt)> {
121        let state = self.state.read();
122        state
123            .prompts
124            .iter()
125            .map(|(id, prompt)| (id.clone(), prompt.clone()))
126            .collect()
127    }
128
129    pub fn _default_prompts(&self) -> Vec<UserPrompt> {
130        let state = self.state.read();
131        state
132            .default_prompts
133            .iter()
134            .filter_map(|id| state.prompts.get(id).cloned())
135            .collect()
136    }
137
138    pub fn default_prompt_ids(&self) -> Vec<String> {
139        let state = self.state.read();
140        state.default_prompts.clone()
141    }
142}
143
144/// A custom prompt that can be loaded into the prompt library
145///
146/// Example:
147///
148/// ```json
149/// {
150///   "title": "Foo",
151///   "version": "1.0",
152///   "author": "Jane Kim <jane@kim.com>",
153///   "languages": ["*"], // or ["rust", "python", "javascript"] etc...
154///   "prompt": "bar"
155/// }
156#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
157pub struct UserPrompt {
158    version: String,
159    title: String,
160    author: String,
161    languages: Vec<String>,
162    prompt: String,
163}
164
165impl UserPrompt {
166    async fn list(fs: Arc<dyn Fs>) -> anyhow::Result<Vec<Self>> {
167        fs.create_dir(&PROMPTS_DIR).await?;
168
169        let mut paths = fs.read_dir(&PROMPTS_DIR).await?;
170        let mut prompts = Vec::new();
171
172        while let Some(path_result) = paths.next().await {
173            let path = match path_result {
174                Ok(p) => p,
175                Err(e) => {
176                    eprintln!("Error reading path: {:?}", e);
177                    continue;
178                }
179            };
180
181            if path.extension() == Some(std::ffi::OsStr::new("json")) {
182                match fs.load(&path).await {
183                    Ok(content) => {
184                        let user_prompt: UserPrompt =
185                            serde_json::from_str(&content).map_err(|e| {
186                                anyhow::anyhow!("Failed to deserialize UserPrompt: {}", e)
187                            })?;
188
189                        prompts.push(user_prompt);
190                    }
191                    Err(e) => eprintln!("Failed to load file {}: {}", path.display(), e),
192                }
193            }
194        }
195
196        Ok(prompts)
197    }
198}
199
200pub struct PromptManager {
201    focus_handle: FocusHandle,
202    prompt_library: Arc<PromptLibrary>,
203    active_prompt: Option<String>,
204}
205
206impl PromptManager {
207    pub fn new(prompt_library: Arc<PromptLibrary>, cx: &mut WindowContext) -> Self {
208        let focus_handle = cx.focus_handle();
209        Self {
210            focus_handle,
211            prompt_library,
212            active_prompt: None,
213        }
214    }
215
216    pub fn set_active_prompt(&mut self, prompt_id: Option<String>) {
217        self.active_prompt = prompt_id;
218    }
219
220    fn dismiss(&mut self, _: &menu::Cancel, cx: &mut ViewContext<Self>) {
221        cx.emit(DismissEvent);
222    }
223}
224
225impl Render for PromptManager {
226    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
227        let prompt_library = self.prompt_library.clone();
228        let prompts = prompt_library
229            .clone()
230            .prompts_with_ids()
231            .clone()
232            .into_iter()
233            .collect::<Vec<_>>();
234
235        let active_prompt = self.active_prompt.as_ref().and_then(|id| {
236            prompt_library
237                .prompts_with_ids()
238                .iter()
239                .find(|(prompt_id, _)| prompt_id == id)
240                .map(|(_, prompt)| prompt.clone())
241        });
242
243        v_flex()
244            .key_context("PromptManager")
245            .track_focus(&self.focus_handle)
246            .on_action(cx.listener(Self::dismiss))
247            .elevation_3(cx)
248            .size_full()
249            .flex_none()
250            .w(rems(54.))
251            .h(rems(40.))
252            .overflow_hidden()
253            .child(
254                ModalHeader::new()
255                    .headline("Prompt Library")
256                    .show_dismiss_button(true),
257            )
258            .child(
259                h_flex()
260                    .flex_grow()
261                    .overflow_hidden()
262                    .border_t_1()
263                    .border_color(cx.theme().colors().border)
264                    .child(
265                        div()
266                            .id("prompt-preview")
267                            .overflow_y_scroll()
268                            .h_full()
269                            .min_w_64()
270                            .max_w_1_2()
271                            .child(
272                                v_flex()
273                                    .justify_start()
274                                    .py(Spacing::Medium.rems(cx))
275                                    .px(Spacing::Large.rems(cx))
276                                    .bg(cx.theme().colors().surface_background)
277                                    .when_else(
278                                        !prompts.is_empty(),
279                                        |with_items| {
280                                            with_items.children(prompts.into_iter().map(
281                                                |(id, prompt)| {
282                                                    let prompt_library = prompt_library.clone();
283                                                    let prompt = prompt.clone();
284                                                    let prompt_id = id.clone();
285                                                    let shared_string_id: SharedString =
286                                                        id.clone().into();
287
288                                                    let default_prompt_ids =
289                                                        prompt_library.clone().default_prompt_ids();
290                                                    let is_default =
291                                                        default_prompt_ids.contains(&id);
292                                                    // We'll use this for conditionally enabled prompts
293                                                    // like those loaded only for certain languages
294                                                    let is_conditional = false;
295                                                    let selection =
296                                                        match (is_default, is_conditional) {
297                                                            (_, true) => Selection::Indeterminate,
298                                                            (true, _) => Selection::Selected,
299                                                            (false, _) => Selection::Unselected,
300                                                        };
301
302                                                    v_flex()
303                                                    .id(ElementId::Name(
304                                                        format!("prompt-{}", shared_string_id)
305                                                            .into(),
306                                                    ))
307                                                    .p(Spacing::Small.rems(cx))
308
309                                                    .on_click(cx.listener({
310                                                        let prompt_id = prompt_id.clone();
311                                                        move |this, _event, _cx| {
312                                                            this.set_active_prompt(Some(
313                                                                prompt_id.clone(),
314                                                            ));
315                                                        }
316                                                    }))
317                                                    .child(
318                                                        h_flex()
319                                                            .justify_between()
320                                                            .child(
321                                                                h_flex()
322                                                                    .gap(Spacing::Large.rems(cx))
323                                                                    .child(
324                                                                        Checkbox::new(
325                                                                            shared_string_id,
326                                                                            selection,
327                                                                        )
328                                                                        .on_click(move |_, _cx| {
329                                                                            if is_default {
330                                                                                prompt_library
331                                                                        .clone()
332                                                                        .remove_prompt_from_default(
333                                                                            prompt_id.clone(),
334                                                                        )
335                                                                        .log_err();
336                                                                            } else {
337                                                                                prompt_library
338                                                                            .clone()
339                                                                            .add_prompt_to_default(
340                                                                                prompt_id.clone(),
341                                                                            )
342                                                                            .log_err();
343                                                                            }
344                                                                        }),
345                                                                    )
346                                                                    .child(Label::new(
347                                                                        prompt.title,
348                                                                    )),
349                                                            )
350                                                            .child(div()),
351                                                    )
352                                                },
353                                            ))
354                                        },
355                                        |no_items| {
356                                            no_items.child(
357                                                Label::new("No prompts").color(Color::Placeholder),
358                                            )
359                                        },
360                                    ),
361                            ),
362                    )
363                    .child(
364                        div()
365                            .id("prompt-preview")
366                            .overflow_y_scroll()
367                            .border_l_1()
368                            .border_color(cx.theme().colors().border)
369                            .size_full()
370                            .flex_none()
371                            .child(
372                                v_flex()
373                                    .justify_start()
374                                    .py(Spacing::Medium.rems(cx))
375                                    .px(Spacing::Large.rems(cx))
376                                    .gap(Spacing::Large.rems(cx))
377                                    .when_else(
378                                        active_prompt.is_some(),
379                                        |with_prompt| {
380                                            let active_prompt = active_prompt.as_ref().unwrap();
381                                            with_prompt
382                                                .child(
383                                                    v_flex()
384                                                        .gap_0p5()
385                                                        .child(
386                                                            Headline::new(
387                                                                active_prompt.title.clone(),
388                                                            )
389                                                            .size(HeadlineSize::XSmall),
390                                                        )
391                                                        .child(
392                                                            h_flex()
393                                                                .child(
394                                                                    Label::new(
395                                                                        active_prompt
396                                                                            .author
397                                                                            .clone(),
398                                                                    )
399                                                                    .size(LabelSize::XSmall)
400                                                                    .color(Color::Muted),
401                                                                )
402                                                                .child(
403                                                                    Label::new(
404                                                                        if active_prompt
405                                                                            .languages
406                                                                            .is_empty()
407                                                                            || active_prompt
408                                                                                .languages[0]
409                                                                                == "*"
410                                                                        {
411                                                                            " · Global".to_string()
412                                                                        } else {
413                                                                            format!(
414                                                                                " · {}",
415                                                                                active_prompt
416                                                                                    .languages
417                                                                                    .join(", ")
418                                                                            )
419                                                                        },
420                                                                    )
421                                                                    .size(LabelSize::XSmall)
422                                                                    .color(Color::Muted),
423                                                                ),
424                                                        ),
425                                                )
426                                                .child(
427                                                    div()
428                                                        .w_full()
429                                                        .max_w(rems(30.))
430                                                        .text_ui(cx)
431                                                        .child(active_prompt.prompt.clone()),
432                                                )
433                                        },
434                                        |without_prompt| {
435                                            without_prompt.justify_center().items_center().child(
436                                                Label::new("Select a prompt to view details.")
437                                                    .color(Color::Placeholder),
438                                            )
439                                        },
440                                    ),
441                            ),
442                    ),
443            )
444    }
445}
446
447impl EventEmitter<DismissEvent> for PromptManager {}
448impl ModalView for PromptManager {}
449
450impl FocusableView for PromptManager {
451    fn focus_handle(&self, _cx: &AppContext) -> gpui::FocusHandle {
452        self.focus_handle.clone()
453    }
454}