agent: Use default prompts from prompt library in system prompt (#28915)

Michael Sloan and Danilo Leal created

Related to #28490.

- Default prompts from the prompt library are now included as "user
rules" in the system prompt.
- Presence of these user rules is shown at the beginning of the thread
in the UI.
_ Now uses an `Entity<PromptStore>` instead of an `Arc<PromptStore>`.
Motivation for this is emitting a `PromptsUpdatedEvent`.
- Now disallows concurrent reloading of the system prompt. Before this
change it was possible for reloads to race.

Release Notes:

- agent: Added support for including default prompts from the Prompt
Library as "user rules" in the system prompt.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>

Change summary

assets/prompts/assistant_system_prompt.hbs             |  15 +
crates/agent/src/active_thread.rs                      | 104 ++++++--
crates/agent/src/agent_diff.rs                         |   4 
crates/agent/src/assistant_panel.rs                    |   2 
crates/agent/src/thread.rs                             |  51 ++-
crates/agent/src/thread_store.rs                       | 141 ++++++++++-
crates/assistant_slash_commands/src/default_command.rs |   4 
crates/assistant_slash_commands/src/prompt_command.rs  |  25 +
crates/eval/src/example.rs                             |   2 
crates/prompt_library/src/prompt_library.rs            |  74 +++--
crates/prompt_store/src/prompt_store.rs                | 132 ++++++----
crates/prompt_store/src/prompts.rs                     |  44 +++
12 files changed, 433 insertions(+), 165 deletions(-)

Detailed changes

assets/prompts/assistant_system_prompt.hbs 🔗

@@ -144,6 +144,19 @@ In Markdown, hash marks signify headings. For example:
 This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
 </style>
 
+{{#if has_default_user_rules}}
+The user has specified the following rules that should be applied:
+{{#each default_user_rules}}
+
+{{#if title}}
+Rules title: {{title}}
+{{/if}}
+``````
+{{contents}}
+``````
+{{/each}}
+
+{{/if}}
 The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files:
 
 {{#each worktrees}}
@@ -151,7 +164,7 @@ The user has opened a project that contains the following root directories/files
 {{/each}}
 {{#if has_rules}}
 
-There are rules that apply to these root directories:
+There are project rules that apply to these root directories:
 {{#each worktrees}}
 {{#if rules_file}}
 

crates/agent/src/active_thread.rs 🔗

@@ -42,6 +42,7 @@ use ui::{
 };
 use util::ResultExt as _;
 use workspace::{OpenOptions, Workspace};
+use zed_actions::assistant::OpenPromptLibrary;
 
 use crate::context_store::ContextStore;
 
@@ -2948,53 +2949,106 @@ impl ActiveThread {
             return div().into_any();
         };
 
+        let default_user_rules_text = if project_context.default_user_rules.is_empty() {
+            None
+        } else if project_context.default_user_rules.len() == 1 {
+            let user_rules = &project_context.default_user_rules[0];
+
+            match user_rules.title.as_ref() {
+                Some(title) => Some(format!("Using \"{title}\" user rule")),
+                None => Some("Using user rule".into()),
+            }
+        } else {
+            Some(format!(
+                "Using {} user rules",
+                project_context.default_user_rules.len()
+            ))
+        };
+
         let rules_files = project_context
             .worktrees
             .iter()
             .filter_map(|worktree| worktree.rules_file.as_ref())
             .collect::<Vec<_>>();
 
-        let label_text = match rules_files.as_slice() {
-            &[] => return div().into_any(),
-            &[rules_file] => {
-                format!("Using {:?} file", rules_file.path_in_worktree)
-            }
-            rules_files => {
-                format!("Using {} rules files", rules_files.len())
-            }
+        let rules_file_text = match rules_files.as_slice() {
+            &[] => None,
+            &[rules_file] => Some(format!(
+                "Using project {:?} file",
+                rules_file.path_in_worktree
+            )),
+            rules_files => Some(format!("Using {} project rules files", rules_files.len())),
         };
 
-        div()
+        if default_user_rules_text.is_none() && rules_file_text.is_none() {
+            return div().into_any();
+        }
+
+        v_flex()
             .pt_2()
             .px_2p5()
-            .child(
-                h_flex()
-                    .w_full()
-                    .gap_0p5()
-                    .child(
+            .gap_1()
+            .when_some(
+                default_user_rules_text,
+                |parent, default_user_rules_text| {
+                    parent.child(
                         h_flex()
-                            .gap_1p5()
+                            .w_full()
                             .child(
                                 Icon::new(IconName::File)
                                     .size(IconSize::XSmall)
                                     .color(Color::Disabled),
                             )
                             .child(
-                                Label::new(label_text)
+                                Label::new(default_user_rules_text)
                                     .size(LabelSize::XSmall)
                                     .color(Color::Muted)
-                                    .buffer_font(cx),
+                                    .truncate()
+                                    .buffer_font(cx)
+                                    .ml_1p5()
+                                    .mr_0p5(),
+                            )
+                            .child(
+                                IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
+                                    .shape(ui::IconButtonShape::Square)
+                                    .icon_size(IconSize::XSmall)
+                                    .icon_color(Color::Ignored)
+                                    // TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary`  keybinding
+                                    .tooltip(Tooltip::text("View User Rules"))
+                                    .on_click(|_event, window, cx| {
+                                        window.dispatch_action(Box::new(OpenPromptLibrary), cx)
+                                    }),
                             ),
                     )
-                    .child(
-                        IconButton::new("open-rule", IconName::ArrowUpRightAlt)
-                            .shape(ui::IconButtonShape::Square)
-                            .icon_size(IconSize::XSmall)
-                            .icon_color(Color::Ignored)
-                            .on_click(cx.listener(Self::handle_open_rules))
-                            .tooltip(Tooltip::text("View Rules")),
-                    ),
+                },
             )
+            .when_some(rules_file_text, |parent, rules_file_text| {
+                parent.child(
+                    h_flex()
+                        .w_full()
+                        .child(
+                            Icon::new(IconName::File)
+                                .size(IconSize::XSmall)
+                                .color(Color::Disabled),
+                        )
+                        .child(
+                            Label::new(rules_file_text)
+                                .size(LabelSize::XSmall)
+                                .color(Color::Muted)
+                                .buffer_font(cx)
+                                .ml_1p5()
+                                .mr_0p5(),
+                        )
+                        .child(
+                            IconButton::new("open-rule", IconName::ArrowUpRightAlt)
+                                .shape(ui::IconButtonShape::Square)
+                                .icon_size(IconSize::XSmall)
+                                .icon_color(Color::Ignored)
+                                .on_click(cx.listener(Self::handle_open_rules))
+                                .tooltip(Tooltip::text("View Rules")),
+                        ),
+                )
+            })
             .into_any()
     }
 

crates/agent/src/agent_diff.rs 🔗

@@ -922,6 +922,7 @@ mod tests {
             language::init(cx);
             Project::init_settings(cx);
             AssistantSettings::register(cx);
+            prompt_store::init(cx);
             thread_store::init(cx);
             workspace::init_settings(cx);
             ThemeSettings::register(cx);
@@ -951,7 +952,8 @@ mod tests {
                     cx,
                 )
             })
-            .await;
+            .await
+            .unwrap();
         let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
         let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
 

crates/agent/src/assistant_panel.rs 🔗

@@ -213,7 +213,7 @@ impl AssistantPanel {
                     let project = workspace.project().clone();
                     ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
                 })?
-                .await;
+                .await?;
 
             let slash_commands = Arc::new(SlashCommandWorkingSet::default());
             let context_store = workspace

crates/agent/src/thread.rs 🔗

@@ -4,7 +4,7 @@ use std::ops::Range;
 use std::sync::Arc;
 use std::time::Instant;
 
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
 use assistant_settings::AssistantSettings;
 use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
 use chrono::{DateTime, Utc};
@@ -939,7 +939,7 @@ impl Thread {
     pub fn to_completion_request(
         &self,
         request_kind: RequestKind,
-        cx: &App,
+        cx: &mut Context<Self>,
     ) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             messages: vec![],
@@ -949,20 +949,33 @@ impl Thread {
         };
 
         if let Some(project_context) = self.project_context.borrow().as_ref() {
-            if let Some(system_prompt) = self
+            match self
                 .prompt_builder
                 .generate_assistant_system_prompt(project_context)
-                .context("failed to generate assistant system prompt")
-                .log_err()
             {
-                request.messages.push(LanguageModelRequestMessage {
-                    role: Role::System,
-                    content: vec![MessageContent::Text(system_prompt)],
-                    cache: true,
-                });
+                Err(err) => {
+                    let message = format!("{err:?}").into();
+                    log::error!("{message}");
+                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+                        header: "Error generating system prompt".into(),
+                        message,
+                    }));
+                }
+                Ok(system_prompt) => {
+                    request.messages.push(LanguageModelRequestMessage {
+                        role: Role::System,
+                        content: vec![MessageContent::Text(system_prompt)],
+                        cache: true,
+                    });
+                }
             }
         } else {
-            log::error!("project_context not set.")
+            let message = "Context for system prompt unexpectedly not ready.".into();
+            log::error!("{message}");
+            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+                header: "Error generating system prompt".into(),
+                message,
+            }));
         }
 
         for message in &self.messages {
@@ -2163,7 +2176,7 @@ fn main() {{
         assert_eq!(message.context, expected_context);
 
         // Check message in request
-        let request = thread.read_with(cx, |thread, cx| {
+        let request = thread.update(cx, |thread, cx| {
             thread.to_completion_request(RequestKind::Chat, cx)
         });
 
@@ -2255,7 +2268,7 @@ fn main() {{
         assert!(message3.context.contains("file3.rs"));
 
         // Check entire request to make sure all contexts are properly included
-        let request = thread.read_with(cx, |thread, cx| {
+        let request = thread.update(cx, |thread, cx| {
             thread.to_completion_request(RequestKind::Chat, cx)
         });
 
@@ -2307,7 +2320,7 @@ fn main() {{
         assert_eq!(message.context, "");
 
         // Check message in request
-        let request = thread.read_with(cx, |thread, cx| {
+        let request = thread.update(cx, |thread, cx| {
             thread.to_completion_request(RequestKind::Chat, cx)
         });
 
@@ -2327,7 +2340,7 @@ fn main() {{
         assert_eq!(message2.context, "");
 
         // Check that both messages appear in the request
-        let request = thread.read_with(cx, |thread, cx| {
+        let request = thread.update(cx, |thread, cx| {
             thread.to_completion_request(RequestKind::Chat, cx)
         });
 
@@ -2369,7 +2382,7 @@ fn main() {{
         });
 
         // Create a request and check that it doesn't have a stale buffer warning yet
-        let initial_request = thread.read_with(cx, |thread, cx| {
+        let initial_request = thread.update(cx, |thread, cx| {
             thread.to_completion_request(RequestKind::Chat, cx)
         });
 
@@ -2399,7 +2412,7 @@ fn main() {{
         });
 
         // Create a new request and check for the stale buffer warning
-        let new_request = thread.read_with(cx, |thread, cx| {
+        let new_request = thread.update(cx, |thread, cx| {
             thread.to_completion_request(RequestKind::Chat, cx)
         });
 
@@ -2428,6 +2441,7 @@ fn main() {{
             language::init(cx);
             Project::init_settings(cx);
             AssistantSettings::register(cx);
+            prompt_store::init(cx);
             thread_store::init(cx);
             workspace::init_settings(cx);
             ThemeSettings::register(cx);
@@ -2467,7 +2481,8 @@ fn main() {{
                     cx,
                 )
             })
-            .await;
+            .await
+            .unwrap();
 
         let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
         let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

crates/agent/src/thread_store.rs 🔗

@@ -12,8 +12,9 @@ use collections::HashMap;
 use context_server::manager::ContextServerManager;
 use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 use fs::Fs;
-use futures::FutureExt as _;
+use futures::channel::{mpsc, oneshot};
 use futures::future::{self, BoxFuture, Shared};
+use futures::{FutureExt as _, StreamExt as _};
 use gpui::{
     App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
     Subscription, Task, prelude::*,
@@ -22,7 +23,10 @@ use heed::Database;
 use heed::types::SerdeBincode;
 use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 use project::{Project, Worktree};
-use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
+use prompt_store::{
+    DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent,
+    RulesFileContext, WorktreeContext,
+};
 use serde::{Deserialize, Serialize};
 use settings::{Settings as _, SettingsStore};
 use util::ResultExt as _;
@@ -62,6 +66,8 @@ pub struct ThreadStore {
     context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
     threads: Vec<SerializedThreadMetadata>,
     project_context: SharedProjectContext,
+    reload_system_prompt_tx: mpsc::Sender<()>,
+    _reload_system_prompt_task: Task<()>,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -77,12 +83,22 @@ impl ThreadStore {
         tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
         cx: &mut App,
-    ) -> Task<Entity<Self>> {
-        let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
-        let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
-        cx.foreground_executor().spawn(async move {
-            reload.await;
-            thread_store
+    ) -> Task<Result<Entity<Self>>> {
+        let prompt_store = PromptStore::global(cx);
+        cx.spawn(async move |cx| {
+            let prompt_store = prompt_store.await.ok();
+            let (thread_store, ready_rx) = cx.update(|cx| {
+                let mut option_ready_rx = None;
+                let thread_store = cx.new(|cx| {
+                    let (thread_store, ready_rx) =
+                        Self::new(project, tools, prompt_builder, prompt_store, cx);
+                    option_ready_rx = Some(ready_rx);
+                    thread_store
+                });
+                (thread_store, option_ready_rx.take().unwrap())
+            })?;
+            ready_rx.await?;
+            Ok(thread_store)
         })
     }
 
@@ -90,17 +106,53 @@ impl ThreadStore {
         project: Entity<Project>,
         tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
+        prompt_store: Option<Entity<PromptStore>>,
         cx: &mut Context<Self>,
-    ) -> Self {
+    ) -> (Self, oneshot::Receiver<()>) {
         let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
         let context_server_manager = cx.new(|cx| {
             ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
         });
-        let settings_subscription =
+
+        let mut subscriptions = vec![
             cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
                 this.load_default_profile(cx);
-            });
-        let project_subscription = cx.subscribe(&project, Self::handle_project_event);
+            }),
+            cx.subscribe(&project, Self::handle_project_event),
+        ];
+
+        if let Some(prompt_store) = prompt_store.as_ref() {
+            subscriptions.push(cx.subscribe(
+                prompt_store,
+                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
+                    this.enqueue_system_prompt_reload();
+                },
+            ))
+        }
+
+        // This channel and task prevent concurrent and redundant loading of the system prompt.
+        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
+        let (ready_tx, ready_rx) = oneshot::channel();
+        let mut ready_tx = Some(ready_tx);
+        let reload_system_prompt_task = cx.spawn({
+            async move |thread_store, cx| {
+                loop {
+                    let Some(reload_task) = thread_store
+                        .update(cx, |thread_store, cx| {
+                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
+                        })
+                        .ok()
+                    else {
+                        return;
+                    };
+                    reload_task.await;
+                    if let Some(ready_tx) = ready_tx.take() {
+                        ready_tx.send(()).ok();
+                    }
+                    reload_system_prompt_rx.next().await;
+                }
+            }
+        });
 
         let this = Self {
             project,
@@ -110,23 +162,25 @@ impl ThreadStore {
             context_server_tool_ids: HashMap::default(),
             threads: Vec::new(),
             project_context: SharedProjectContext::default(),
-            _subscriptions: vec![settings_subscription, project_subscription],
+            reload_system_prompt_tx,
+            _reload_system_prompt_task: reload_system_prompt_task,
+            _subscriptions: subscriptions,
         };
         this.load_default_profile(cx);
         this.register_context_server_handlers(cx);
         this.reload(cx).detach_and_log_err(cx);
-        this
+        (this, ready_rx)
     }
 
     fn handle_project_event(
         &mut self,
         _project: Entity<Project>,
         event: &project::Event,
-        cx: &mut Context<Self>,
+        _cx: &mut Context<Self>,
     ) {
         match event {
             project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
-                self.reload_system_prompt(cx).detach();
+                self.enqueue_system_prompt_reload();
             }
             project::Event::WorktreeUpdatedEntries(_, items) => {
                 if items.iter().any(|(path, _, _)| {
@@ -134,16 +188,25 @@ impl ThreadStore {
                         .iter()
                         .any(|name| path.as_ref() == Path::new(name))
                 }) {
-                    self.reload_system_prompt(cx).detach();
+                    self.enqueue_system_prompt_reload();
                 }
             }
             _ => {}
         }
     }
 
-    pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
+    fn enqueue_system_prompt_reload(&mut self) {
+        self.reload_system_prompt_tx.try_send(()).ok();
+    }
+
+    // Note that this should only be called from `reload_system_prompt_task`.
+    fn reload_system_prompt(
+        &self,
+        prompt_store: Option<Entity<PromptStore>>,
+        cx: &mut Context<Self>,
+    ) -> Task<()> {
         let project = self.project.read(cx);
-        let tasks = project
+        let worktree_tasks = project
             .visible_worktrees(cx)
             .map(|worktree| {
                 Self::load_worktree_info_for_system_prompt(
@@ -153,10 +216,23 @@ impl ThreadStore {
                 )
             })
             .collect::<Vec<_>>();
+        let default_user_rules_task = match prompt_store {
+            None => Task::ready(vec![]),
+            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
+                let prompts = prompt_store.default_prompt_metadata();
+                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
+                    let contents = prompt_store.load(prompt_metadata.id, cx);
+                    async move { (contents.await, prompt_metadata) }
+                });
+                cx.background_spawn(future::join_all(load_tasks))
+            }),
+        };
 
         cx.spawn(async move |this, cx| {
-            let results = futures::future::join_all(tasks).await;
-            let worktrees = results
+            let (worktrees, default_user_rules) =
+                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
+
+            let worktrees = worktrees
                 .into_iter()
                 .map(|(worktree, rules_error)| {
                     if let Some(rules_error) = rules_error {
@@ -165,8 +241,29 @@ impl ThreadStore {
                     worktree
                 })
                 .collect::<Vec<_>>();
+
+            let default_user_rules = default_user_rules
+                .into_iter()
+                .flat_map(|(contents, prompt_metadata)| match contents {
+                    Ok(contents) => Some(DefaultUserRulesContext {
+                        title: prompt_metadata.title.map(|title| title.to_string()),
+                        contents,
+                    }),
+                    Err(err) => {
+                        this.update(cx, |_, cx| {
+                            cx.emit(RulesLoadingError {
+                                message: format!("{err:?}").into(),
+                            });
+                        })
+                        .ok();
+                        None
+                    }
+                })
+                .collect::<Vec<_>>();
+
             this.update(cx, |this, _cx| {
-                *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
+                *this.project_context.0.borrow_mut() =
+                    Some(ProjectContext::new(worktrees, default_user_rules));
             })
             .ok();
         })

crates/assistant_slash_commands/src/default_command.rs 🔗

@@ -54,9 +54,9 @@ impl SlashCommand for DefaultSlashCommand {
         cx: &mut App,
     ) -> Task<SlashCommandResult> {
         let store = PromptStore::global(cx);
-        cx.background_spawn(async move {
+        cx.spawn(async move |cx| {
             let store = store.await?;
-            let prompts = store.default_prompt_metadata();
+            let prompts = store.read_with(cx, |store, _cx| store.default_prompt_metadata())?;
 
             let mut text = String::new();
             text.push('\n');

crates/assistant_slash_commands/src/prompt_command.rs 🔗

@@ -5,7 +5,7 @@ use assistant_slash_command::{
 };
 use gpui::{Task, WeakEntity};
 use language::{BufferSnapshot, LspAdapterDelegate};
-use prompt_store::PromptStore;
+use prompt_store::{PromptMetadata, PromptStore};
 use std::sync::{Arc, atomic::AtomicBool};
 use ui::prelude::*;
 use workspace::Workspace;
@@ -43,8 +43,11 @@ impl SlashCommand for PromptSlashCommand {
     ) -> Task<Result<Vec<ArgumentCompletion>>> {
         let store = PromptStore::global(cx);
         let query = arguments.to_owned().join(" ");
-        cx.background_spawn(async move {
-            let prompts = store.await?.search(query).await;
+        cx.spawn(async move |cx| {
+            let prompts: Vec<PromptMetadata> = store
+                .await?
+                .read_with(cx, |store, cx| store.search(query, cx))?
+                .await;
             Ok(prompts
                 .into_iter()
                 .filter_map(|prompt| {
@@ -77,14 +80,18 @@ impl SlashCommand for PromptSlashCommand {
 
         let store = PromptStore::global(cx);
         let title = SharedString::from(title.clone());
-        let prompt = cx.background_spawn({
+        let prompt = cx.spawn({
             let title = title.clone();
-            async move {
+            async move |cx| {
                 let store = store.await?;
-                let prompt_id = store
-                    .id_for_title(&title)
-                    .with_context(|| format!("no prompt found with title {:?}", title))?;
-                let body = store.load(prompt_id).await?;
+                let body = store
+                    .read_with(cx, |store, cx| {
+                        let prompt_id = store
+                            .id_for_title(&title)
+                            .with_context(|| format!("no prompt found with title {:?}", title))?;
+                        anyhow::Ok(store.load(prompt_id, cx))
+                    })??
+                    .await?;
                 anyhow::Ok(body)
             }
         });

crates/eval/src/example.rs 🔗

@@ -309,7 +309,7 @@ impl Example {
                 return Err(anyhow!("Setup only mode"));
             }
 
-            let thread_store = thread_store.await;
+            let thread_store = thread_store.await?;
             let thread =
                 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
 

crates/prompt_library/src/prompt_library.rs 🔗

@@ -136,7 +136,7 @@ pub fn open_prompt_library(
 }
 
 pub struct PromptLibrary {
-    store: Arc<PromptStore>,
+    store: Entity<PromptStore>,
     language_registry: Arc<LanguageRegistry>,
     prompt_editors: HashMap<PromptId, PromptEditor>,
     active_prompt_id: Option<PromptId>,
@@ -158,7 +158,7 @@ struct PromptEditor {
 }
 
 struct PromptPickerDelegate {
-    store: Arc<PromptStore>,
+    store: Entity<PromptStore>,
     selected_index: usize,
     matches: Vec<PromptMetadata>,
 }
@@ -179,8 +179,8 @@ impl PickerDelegate for PromptPickerDelegate {
         self.matches.len()
     }
 
-    fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option<SharedString> {
-        let text = if self.store.prompt_count() == 0 {
+    fn no_matches_text(&self, _window: &mut Window, cx: &mut App) -> Option<SharedString> {
+        let text = if self.store.read(cx).prompt_count() == 0 {
             "No prompts.".into()
         } else {
             "No prompts found matching your search.".into()
@@ -211,7 +211,7 @@ impl PickerDelegate for PromptPickerDelegate {
         window: &mut Window,
         cx: &mut Context<Picker<Self>>,
     ) -> Task<()> {
-        let search = self.store.search(query);
+        let search = self.store.read(cx).search(query, cx);
         let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id);
         cx.spawn_in(window, async move |this, cx| {
             let (matches, selected_index) = cx
@@ -339,7 +339,7 @@ impl PickerDelegate for PromptPickerDelegate {
 
 impl PromptLibrary {
     fn new(
-        store: Arc<PromptStore>,
+        store: Entity<PromptStore>,
         language_registry: Arc<LanguageRegistry>,
         inline_assist_delegate: Box<dyn InlineAssistDelegate>,
         make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
@@ -398,7 +398,7 @@ impl PromptLibrary {
     pub fn new_prompt(&mut self, window: &mut Window, cx: &mut Context<Self>) {
         // If we already have an untitled prompt, use that instead
         // of creating a new one.
-        if let Some(metadata) = self.store.first() {
+        if let Some(metadata) = self.store.read(cx).first() {
             if metadata.title.is_none() {
                 self.load_prompt(metadata.id, true, window, cx);
                 return;
@@ -406,7 +406,9 @@ impl PromptLibrary {
         }
 
         let prompt_id = PromptId::new();
-        let save = self.store.save(prompt_id, None, false, "".into());
+        let save = self.store.update(cx, |store, cx| {
+            store.save(prompt_id, None, false, "".into(), cx)
+        });
         self.picker
             .update(cx, |picker, cx| picker.refresh(window, cx));
         cx.spawn_in(window, async move |this, cx| {
@@ -430,7 +432,7 @@ impl PromptLibrary {
             return;
         }
 
-        let prompt_metadata = self.store.metadata(prompt_id).unwrap();
+        let prompt_metadata = self.store.read(cx).metadata(prompt_id).unwrap();
         let prompt_editor = self.prompt_editors.get_mut(&prompt_id).unwrap();
         let title = prompt_editor.title_editor.read(cx).text(cx);
         let body = prompt_editor.body_editor.update(cx, |editor, cx| {
@@ -465,10 +467,13 @@ impl PromptLibrary {
                             } else {
                                 Some(SharedString::from(title))
                             };
-                            store
-                                .save(prompt_id, title, prompt_metadata.default, body)
-                                .await
-                                .log_err();
+                            cx.update(|_window, cx| {
+                                store.update(cx, |store, cx| {
+                                    store.save(prompt_id, title, prompt_metadata.default, body, cx)
+                                })
+                            })?
+                            .await
+                            .log_err();
                             this.update_in(cx, |this, window, cx| {
                                 this.picker
                                     .update(cx, |picker, cx| picker.refresh(window, cx));
@@ -521,14 +526,21 @@ impl PromptLibrary {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        if let Some(prompt_metadata) = self.store.metadata(prompt_id) {
-            self.store
-                .save_metadata(prompt_id, prompt_metadata.title, !prompt_metadata.default)
-                .detach_and_log_err(cx);
-            self.picker
-                .update(cx, |picker, cx| picker.refresh(window, cx));
-            cx.notify();
-        }
+        self.store.update(cx, move |store, cx| {
+            if let Some(prompt_metadata) = store.metadata(prompt_id) {
+                store
+                    .save_metadata(
+                        prompt_id,
+                        prompt_metadata.title,
+                        !prompt_metadata.default,
+                        cx,
+                    )
+                    .detach_and_log_err(cx);
+            }
+        });
+        self.picker
+            .update(cx, |picker, cx| picker.refresh(window, cx));
+        cx.notify();
     }
 
     pub fn load_prompt(
@@ -545,9 +557,9 @@ impl PromptLibrary {
                     .update(cx, |editor, cx| window.focus(&editor.focus_handle(cx)));
             }
             self.set_active_prompt(Some(prompt_id), window, cx);
-        } else if let Some(prompt_metadata) = self.store.metadata(prompt_id) {
+        } else if let Some(prompt_metadata) = self.store.read(cx).metadata(prompt_id) {
             let language_registry = self.language_registry.clone();
-            let prompt = self.store.load(prompt_id);
+            let prompt = self.store.read(cx).load(prompt_id, cx);
             let make_completion_provider = self.make_completion_provider.clone();
             self.pending_load = cx.spawn_in(window, async move |this, cx| {
                 let prompt = prompt.await;
@@ -673,7 +685,7 @@ impl PromptLibrary {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        if let Some(metadata) = self.store.metadata(prompt_id) {
+        if let Some(metadata) = self.store.read(cx).metadata(prompt_id) {
             let confirmation = window.prompt(
                 PromptLevel::Warning,
                 &format!(
@@ -692,7 +704,9 @@ impl PromptLibrary {
                             this.set_active_prompt(None, window, cx);
                         }
                         this.prompt_editors.remove(&prompt_id);
-                        this.store.delete(prompt_id).detach_and_log_err(cx);
+                        this.store
+                            .update(cx, |store, cx| store.delete(prompt_id, cx))
+                            .detach_and_log_err(cx);
                         this.picker
                             .update(cx, |picker, cx| picker.refresh(window, cx));
                         cx.notify();
@@ -736,9 +750,9 @@ impl PromptLibrary {
 
             let new_id = PromptId::new();
             let body = prompt.body_editor.read(cx).text(cx);
-            let save = self
-                .store
-                .save(new_id, Some(title.into()), false, body.into());
+            let save = self.store.update(cx, |store, cx| {
+                store.save(new_id, Some(title.into()), false, body.into(), cx)
+            });
             self.picker
                 .update(cx, |picker, cx| picker.refresh(window, cx));
             cx.spawn_in(window, async move |this, cx| {
@@ -968,7 +982,7 @@ impl PromptLibrary {
             .flex_none()
             .min_w_64()
             .children(self.active_prompt_id.and_then(|prompt_id| {
-                let prompt_metadata = self.store.metadata(prompt_id)?;
+                let prompt_metadata = self.store.read(cx).metadata(prompt_id)?;
                 let prompt_editor = &self.prompt_editors[&prompt_id];
                 let focus_handle = prompt_editor.body_editor.focus_handle(cx);
                 let model = LanguageModelRegistry::read_global(cx)
@@ -1238,7 +1252,7 @@ impl Render for PromptLibrary {
             .text_color(theme.colors().text)
             .child(self.render_prompt_list(cx))
             .map(|el| {
-                if self.store.prompt_count() == 0 {
+                if self.store.read(cx).prompt_count() == 0 {
                     el.child(
                         v_flex()
                             .w_2_3()

crates/prompt_store/src/prompt_store.rs 🔗

@@ -4,9 +4,11 @@ use anyhow::{Result, anyhow};
 use chrono::{DateTime, Utc};
 use collections::HashMap;
 use futures::FutureExt as _;
-use futures::future::{self, BoxFuture, Shared};
+use futures::future::Shared;
 use fuzzy::StringMatchCandidate;
-use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task};
+use gpui::{
+    App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
+};
 use heed::{
     Database, RoTxn,
     types::{SerdeBincode, SerdeJson, Str},
@@ -29,11 +31,16 @@ use uuid::Uuid;
 /// a shared future to a global.
 pub fn init(cx: &mut App) {
     let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
-    let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone())
-        .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
-        .boxed()
+    let prompt_store_task = PromptStore::new(db_path, cx);
+    let prompt_store_entity_task = cx
+        .spawn(async move |cx| {
+            prompt_store_task
+                .await
+                .and_then(|prompt_store| cx.new(|_cx| prompt_store))
+                .map_err(Arc::new)
+        })
         .shared();
-    cx.set_global(GlobalPromptStore(prompt_store_future))
+    cx.set_global(GlobalPromptStore(prompt_store_entity_task))
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -64,13 +71,16 @@ impl PromptId {
 }
 
 pub struct PromptStore {
-    executor: BackgroundExecutor,
     env: heed::Env,
     metadata_cache: RwLock<MetadataCache>,
     metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
     bodies: Database<SerdeJson<PromptId>, Str>,
 }
 
+pub struct PromptsUpdatedEvent;
+
+impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
+
 #[derive(Default)]
 struct MetadataCache {
     metadata: Vec<PromptMetadata>,
@@ -117,49 +127,45 @@ impl MetadataCache {
 }
 
 impl PromptStore {
-    pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> + use<> {
+    pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
         let store = GlobalPromptStore::global(cx).0.clone();
         async move { store.await.map_err(|err| anyhow!(err)) }
     }
 
-    pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> {
-        executor.spawn({
-            let executor = executor.clone();
-            async move {
-                std::fs::create_dir_all(&db_path)?;
+    pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
+        cx.background_spawn(async move {
+            std::fs::create_dir_all(&db_path)?;
 
-                let db_env = unsafe {
-                    heed::EnvOpenOptions::new()
-                        .map_size(1024 * 1024 * 1024) // 1GB
-                        .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
-                        .open(db_path)?
-                };
+            let db_env = unsafe {
+                heed::EnvOpenOptions::new()
+                    .map_size(1024 * 1024 * 1024) // 1GB
+                    .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
+                    .open(db_path)?
+            };
 
-                let mut txn = db_env.write_txn()?;
-                let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
-                let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
+            let mut txn = db_env.write_txn()?;
+            let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
+            let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
 
-                // Remove edit workflow prompt, as we decided to opt into it using
-                // a slash command instead.
-                metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
-                bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
+            // Remove edit workflow prompt, as we decided to opt into it using
+            // a slash command instead.
+            metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
+            bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
 
-                txn.commit()?;
+            txn.commit()?;
 
-                Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
+            Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
 
-                let txn = db_env.read_txn()?;
-                let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
-                txn.commit()?;
+            let txn = db_env.read_txn()?;
+            let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
+            txn.commit()?;
 
-                Ok(PromptStore {
-                    executor,
-                    env: db_env,
-                    metadata_cache: RwLock::new(metadata_cache),
-                    metadata,
-                    bodies,
-                })
-            }
+            Ok(PromptStore {
+                env: db_env,
+                metadata_cache: RwLock::new(metadata_cache),
+                metadata,
+                bodies,
+            })
         })
     }
 
@@ -237,10 +243,10 @@ impl PromptStore {
         Ok(())
     }
 
-    pub fn load(&self, id: PromptId) -> Task<Result<String>> {
+    pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
         let env = self.env.clone();
         let bodies = self.bodies;
-        self.executor.spawn(async move {
+        cx.background_spawn(async move {
             let txn = env.read_txn()?;
             let mut prompt = bodies
                 .get(&txn, &id)?
@@ -262,21 +268,27 @@ impl PromptStore {
             .collect::<Vec<_>>();
     }
 
-    pub fn delete(&self, id: PromptId) -> Task<Result<()>> {
+    pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
         self.metadata_cache.write().remove(id);
 
         let db_connection = self.env.clone();
         let bodies = self.bodies;
         let metadata = self.metadata;
 
-        self.executor.spawn(async move {
+        let task = cx.background_spawn(async move {
             let mut txn = db_connection.write_txn()?;
 
             metadata.delete(&mut txn, &id)?;
             bodies.delete(&mut txn, &id)?;
 
             txn.commit()?;
-            Ok(())
+            anyhow::Ok(())
+        });
+
+        cx.spawn(async move |this, cx| {
+            task.await?;
+            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
+            anyhow::Ok(())
         })
     }
 
@@ -302,10 +314,10 @@ impl PromptStore {
         Some(metadata.id)
     }
 
-    pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> {
+    pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> {
         let cached_metadata = self.metadata_cache.read().metadata.clone();
-        let executor = self.executor.clone();
-        self.executor.spawn(async move {
+        let executor = cx.background_executor().clone();
+        cx.background_spawn(async move {
             let mut matches = if query.is_empty() {
                 cached_metadata
             } else {
@@ -341,6 +353,7 @@ impl PromptStore {
         title: Option<SharedString>,
         default: bool,
         body: Rope,
+        cx: &Context<Self>,
     ) -> Task<Result<()>> {
         if id.is_built_in() {
             return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
@@ -358,7 +371,7 @@ impl PromptStore {
         let bodies = self.bodies;
         let metadata = self.metadata;
 
-        self.executor.spawn(async move {
+        let task = cx.background_spawn(async move {
             let mut txn = db_connection.write_txn()?;
 
             metadata.put(&mut txn, &id, &prompt_metadata)?;
@@ -366,7 +379,13 @@ impl PromptStore {
 
             txn.commit()?;
 
-            Ok(())
+            anyhow::Ok(())
+        });
+
+        cx.spawn(async move |this, cx| {
+            task.await?;
+            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
+            anyhow::Ok(())
         })
     }
 
@@ -375,6 +394,7 @@ impl PromptStore {
         id: PromptId,
         mut title: Option<SharedString>,
         default: bool,
+        cx: &Context<Self>,
     ) -> Task<Result<()>> {
         let mut cache = self.metadata_cache.write();
 
@@ -397,19 +417,23 @@ impl PromptStore {
         let db_connection = self.env.clone();
         let metadata = self.metadata;
 
-        self.executor.spawn(async move {
+        let task = cx.background_spawn(async move {
             let mut txn = db_connection.write_txn()?;
             metadata.put(&mut txn, &id, &prompt_metadata)?;
             txn.commit()?;
 
-            Ok(())
+            anyhow::Ok(())
+        });
+
+        cx.spawn(async move |this, cx| {
+            task.await?;
+            this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
+            anyhow::Ok(())
         })
     }
 }
 
 /// Wraps a shared future to a prompt store so it can be assigned as a context global.
-pub struct GlobalPromptStore(
-    Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
-);
+pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
 
 impl Global for GlobalPromptStore {}

crates/prompt_store/src/prompts.rs 🔗

@@ -19,20 +19,29 @@ use util::{ResultExt, get_system_shell};
 #[derive(Debug, Clone, Serialize)]
 pub struct ProjectContext {
     pub worktrees: Vec<WorktreeContext>,
+    /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this.
     pub has_rules: bool,
+    pub default_user_rules: Vec<DefaultUserRulesContext>,
+    /// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this.
+    pub has_default_user_rules: bool,
     pub os: String,
     pub arch: String,
     pub shell: String,
 }
 
 impl ProjectContext {
-    pub fn new(worktrees: Vec<WorktreeContext>) -> Self {
+    pub fn new(
+        worktrees: Vec<WorktreeContext>,
+        default_user_rules: Vec<DefaultUserRulesContext>,
+    ) -> Self {
         let has_rules = worktrees
             .iter()
             .any(|worktree| worktree.rules_file.is_some());
         Self {
             worktrees,
             has_rules,
+            has_default_user_rules: !default_user_rules.is_empty(),
+            default_user_rules,
             os: std::env::consts::OS.to_string(),
             arch: std::env::consts::ARCH.to_string(),
             shell: get_system_shell(),
@@ -40,6 +49,12 @@ impl ProjectContext {
     }
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct DefaultUserRulesContext {
+    pub title: Option<String>,
+    pub contents: String,
+}
+
 #[derive(Debug, Clone, Serialize)]
 pub struct WorktreeContext {
     pub root_name: String,
@@ -377,3 +392,30 @@ impl PromptBuilder {
         self.handlebars.lock().render("suggest_edits", &())
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_assistant_system_prompt_renders() {
+        let worktrees = vec![WorktreeContext {
+            root_name: "path".into(),
+            abs_path: Path::new("/some/path").into(),
+            rules_file: Some(RulesFileContext {
+                path_in_worktree: Path::new(".rules").into(),
+                abs_path: Path::new("/some/path/.rules").into(),
+                text: "".into(),
+            }),
+        }];
+        let default_user_rules = vec![DefaultUserRulesContext {
+            title: Some("Rules title".into()),
+            contents: "Rules contents".into(),
+        }];
+        let project_context = ProjectContext::new(worktrees, default_user_rules);
+        PromptBuilder::new(None)
+            .unwrap()
+            .generate_assistant_system_prompt(&project_context)
+            .unwrap();
+    }
+}