Detailed changes
@@ -144,6 +144,19 @@ In Markdown, hash marks signify headings. For example:
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
</style>
+{{#if has_default_user_rules}}
+The user has specified the following rules that should be applied:
+{{#each default_user_rules}}
+
+{{#if title}}
+Rules title: {{title}}
+{{/if}}
+``````
+{{contents}}
+``````
+{{/each}}
+
+{{/if}}
The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files:
{{#each worktrees}}
@@ -151,7 +164,7 @@ The user has opened a project that contains the following root directories/files
{{/each}}
{{#if has_rules}}
-There are rules that apply to these root directories:
+There are project rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
@@ -42,6 +42,7 @@ use ui::{
};
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
+use zed_actions::assistant::OpenPromptLibrary;
use crate::context_store::ContextStore;
@@ -2948,53 +2949,106 @@ impl ActiveThread {
return div().into_any();
};
+ let default_user_rules_text = if project_context.default_user_rules.is_empty() {
+ None
+ } else if project_context.default_user_rules.len() == 1 {
+ let user_rules = &project_context.default_user_rules[0];
+
+ match user_rules.title.as_ref() {
+ Some(title) => Some(format!("Using \"{title}\" user rule")),
+ None => Some("Using user rule".into()),
+ }
+ } else {
+ Some(format!(
+ "Using {} user rules",
+ project_context.default_user_rules.len()
+ ))
+ };
+
let rules_files = project_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
.collect::<Vec<_>>();
- let label_text = match rules_files.as_slice() {
- &[] => return div().into_any(),
- &[rules_file] => {
- format!("Using {:?} file", rules_file.path_in_worktree)
- }
- rules_files => {
- format!("Using {} rules files", rules_files.len())
- }
+ let rules_file_text = match rules_files.as_slice() {
+ &[] => None,
+ &[rules_file] => Some(format!(
+ "Using project {:?} file",
+ rules_file.path_in_worktree
+ )),
+ rules_files => Some(format!("Using {} project rules files", rules_files.len())),
};
- div()
+ if default_user_rules_text.is_none() && rules_file_text.is_none() {
+ return div().into_any();
+ }
+
+ v_flex()
.pt_2()
.px_2p5()
- .child(
- h_flex()
- .w_full()
- .gap_0p5()
- .child(
+ .gap_1()
+ .when_some(
+ default_user_rules_text,
+ |parent, default_user_rules_text| {
+ parent.child(
h_flex()
- .gap_1p5()
+ .w_full()
.child(
Icon::new(IconName::File)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
- Label::new(label_text)
+ Label::new(default_user_rules_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
- .buffer_font(cx),
+ .truncate()
+ .buffer_font(cx)
+ .ml_1p5()
+ .mr_0p5(),
+ )
+ .child(
+ IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
+ .shape(ui::IconButtonShape::Square)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Ignored)
+ // TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
+ .tooltip(Tooltip::text("View User Rules"))
+ .on_click(|_event, window, cx| {
+ window.dispatch_action(Box::new(OpenPromptLibrary), cx)
+ }),
),
)
- .child(
- IconButton::new("open-rule", IconName::ArrowUpRightAlt)
- .shape(ui::IconButtonShape::Square)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Ignored)
- .on_click(cx.listener(Self::handle_open_rules))
- .tooltip(Tooltip::text("View Rules")),
- ),
+ },
)
+ .when_some(rules_file_text, |parent, rules_file_text| {
+ parent.child(
+ h_flex()
+ .w_full()
+ .child(
+ Icon::new(IconName::File)
+ .size(IconSize::XSmall)
+ .color(Color::Disabled),
+ )
+ .child(
+ Label::new(rules_file_text)
+ .size(LabelSize::XSmall)
+ .color(Color::Muted)
+ .buffer_font(cx)
+ .ml_1p5()
+ .mr_0p5(),
+ )
+ .child(
+ IconButton::new("open-rule", IconName::ArrowUpRightAlt)
+ .shape(ui::IconButtonShape::Square)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Ignored)
+ .on_click(cx.listener(Self::handle_open_rules))
+ .tooltip(Tooltip::text("View Rules")),
+ ),
+ )
+ })
.into_any()
}
@@ -922,6 +922,7 @@ mod tests {
language::init(cx);
Project::init_settings(cx);
AssistantSettings::register(cx);
+ prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
@@ -951,7 +952,8 @@ mod tests {
cx,
)
})
- .await;
+ .await
+ .unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
@@ -213,7 +213,7 @@ impl AssistantPanel {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
})?
- .await;
+ .await?;
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace
@@ -4,7 +4,7 @@ use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
@@ -939,7 +939,7 @@ impl Thread {
pub fn to_completion_request(
&self,
request_kind: RequestKind,
- cx: &App,
+ cx: &mut Context<Self>,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
messages: vec![],
@@ -949,20 +949,33 @@ impl Thread {
};
if let Some(project_context) = self.project_context.borrow().as_ref() {
- if let Some(system_prompt) = self
+ match self
.prompt_builder
.generate_assistant_system_prompt(project_context)
- .context("failed to generate assistant system prompt")
- .log_err()
{
- request.messages.push(LanguageModelRequestMessage {
- role: Role::System,
- content: vec![MessageContent::Text(system_prompt)],
- cache: true,
- });
+ Err(err) => {
+ let message = format!("{err:?}").into();
+ log::error!("{message}");
+ cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+ header: "Error generating system prompt".into(),
+ message,
+ }));
+ }
+ Ok(system_prompt) => {
+ request.messages.push(LanguageModelRequestMessage {
+ role: Role::System,
+ content: vec![MessageContent::Text(system_prompt)],
+ cache: true,
+ });
+ }
}
} else {
- log::error!("project_context not set.")
+ let message = "Context for system prompt unexpectedly not ready.".into();
+ log::error!("{message}");
+ cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+ header: "Error generating system prompt".into(),
+ message,
+ }));
}
for message in &self.messages {
@@ -2163,7 +2176,7 @@ fn main() {{
assert_eq!(message.context, expected_context);
// Check message in request
- let request = thread.read_with(cx, |thread, cx| {
+ let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
@@ -2255,7 +2268,7 @@ fn main() {{
assert!(message3.context.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
- let request = thread.read_with(cx, |thread, cx| {
+ let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
@@ -2307,7 +2320,7 @@ fn main() {{
assert_eq!(message.context, "");
// Check message in request
- let request = thread.read_with(cx, |thread, cx| {
+ let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
@@ -2327,7 +2340,7 @@ fn main() {{
assert_eq!(message2.context, "");
// Check that both messages appear in the request
- let request = thread.read_with(cx, |thread, cx| {
+ let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
@@ -2369,7 +2382,7 @@ fn main() {{
});
// Create a request and check that it doesn't have a stale buffer warning yet
- let initial_request = thread.read_with(cx, |thread, cx| {
+ let initial_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
@@ -2399,7 +2412,7 @@ fn main() {{
});
// Create a new request and check for the stale buffer warning
- let new_request = thread.read_with(cx, |thread, cx| {
+ let new_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
@@ -2428,6 +2441,7 @@ fn main() {{
language::init(cx);
Project::init_settings(cx);
AssistantSettings::register(cx);
+ prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
@@ -2467,7 +2481,8 @@ fn main() {{
cx,
)
})
- .await;
+ .await
+ .unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
@@ -12,8 +12,9 @@ use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use fs::Fs;
-use futures::FutureExt as _;
+use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared};
+use futures::{FutureExt as _, StreamExt as _};
use gpui::{
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*,
@@ -22,7 +23,10 @@ use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
-use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
+use prompt_store::{
+ DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent,
+ RulesFileContext, WorktreeContext,
+};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
@@ -62,6 +66,8 @@ pub struct ThreadStore {
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext,
+ reload_system_prompt_tx: mpsc::Sender<()>,
+ _reload_system_prompt_task: Task<()>,
_subscriptions: Vec<Subscription>,
}
@@ -77,12 +83,22 @@ impl ThreadStore {
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
- ) -> Task<Entity<Self>> {
- let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
- let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
- cx.foreground_executor().spawn(async move {
- reload.await;
- thread_store
+ ) -> Task<Result<Entity<Self>>> {
+ let prompt_store = PromptStore::global(cx);
+ cx.spawn(async move |cx| {
+ let prompt_store = prompt_store.await.ok();
+ let (thread_store, ready_rx) = cx.update(|cx| {
+ let mut option_ready_rx = None;
+ let thread_store = cx.new(|cx| {
+ let (thread_store, ready_rx) =
+ Self::new(project, tools, prompt_builder, prompt_store, cx);
+ option_ready_rx = Some(ready_rx);
+ thread_store
+ });
+ (thread_store, option_ready_rx.take().unwrap())
+ })?;
+ ready_rx.await?;
+ Ok(thread_store)
})
}
@@ -90,17 +106,53 @@ impl ThreadStore {
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
+ prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
- ) -> Self {
+ ) -> (Self, oneshot::Receiver<()>) {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
- let settings_subscription =
+
+ let mut subscriptions = vec![
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
- });
- let project_subscription = cx.subscribe(&project, Self::handle_project_event);
+ }),
+ cx.subscribe(&project, Self::handle_project_event),
+ ];
+
+ if let Some(prompt_store) = prompt_store.as_ref() {
+ subscriptions.push(cx.subscribe(
+ prompt_store,
+ |this, _prompt_store, PromptsUpdatedEvent, _cx| {
+ this.enqueue_system_prompt_reload();
+ },
+ ))
+ }
+
+ // This channel and task prevent concurrent and redundant loading of the system prompt.
+ let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
+ let (ready_tx, ready_rx) = oneshot::channel();
+ let mut ready_tx = Some(ready_tx);
+ let reload_system_prompt_task = cx.spawn({
+ async move |thread_store, cx| {
+ loop {
+ let Some(reload_task) = thread_store
+ .update(cx, |thread_store, cx| {
+ thread_store.reload_system_prompt(prompt_store.clone(), cx)
+ })
+ .ok()
+ else {
+ return;
+ };
+ reload_task.await;
+ if let Some(ready_tx) = ready_tx.take() {
+ ready_tx.send(()).ok();
+ }
+ reload_system_prompt_rx.next().await;
+ }
+ }
+ });
let this = Self {
project,
@@ -110,23 +162,25 @@ impl ThreadStore {
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
- _subscriptions: vec![settings_subscription, project_subscription],
+ reload_system_prompt_tx,
+ _reload_system_prompt_task: reload_system_prompt_task,
+ _subscriptions: subscriptions,
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
- this
+ (this, ready_rx)
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
- cx: &mut Context<Self>,
+ _cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
- self.reload_system_prompt(cx).detach();
+ self.enqueue_system_prompt_reload();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
@@ -134,16 +188,25 @@ impl ThreadStore {
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
- self.reload_system_prompt(cx).detach();
+ self.enqueue_system_prompt_reload();
}
}
_ => {}
}
}
- pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
+ fn enqueue_system_prompt_reload(&mut self) {
+ self.reload_system_prompt_tx.try_send(()).ok();
+ }
+
+ // Note that this should only be called from `reload_system_prompt_task`.
+ fn reload_system_prompt(
+ &self,
+ prompt_store: Option<Entity<PromptStore>>,
+ cx: &mut Context<Self>,
+ ) -> Task<()> {
let project = self.project.read(cx);
- let tasks = project
+ let worktree_tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
@@ -153,10 +216,23 @@ impl ThreadStore {
)
})
.collect::<Vec<_>>();
+ let default_user_rules_task = match prompt_store {
+ None => Task::ready(vec![]),
+ Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
+ let prompts = prompt_store.default_prompt_metadata();
+ let load_tasks = prompts.into_iter().map(|prompt_metadata| {
+ let contents = prompt_store.load(prompt_metadata.id, cx);
+ async move { (contents.await, prompt_metadata) }
+ });
+ cx.background_spawn(future::join_all(load_tasks))
+ }),
+ };
cx.spawn(async move |this, cx| {
- let results = futures::future::join_all(tasks).await;
- let worktrees = results
+ let (worktrees, default_user_rules) =
+ future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
+
+ let worktrees = worktrees
.into_iter()
.map(|(worktree, rules_error)| {
if let Some(rules_error) = rules_error {
@@ -165,8 +241,29 @@ impl ThreadStore {
worktree
})
.collect::<Vec<_>>();
+
+ let default_user_rules = default_user_rules
+ .into_iter()
+ .flat_map(|(contents, prompt_metadata)| match contents {
+ Ok(contents) => Some(DefaultUserRulesContext {
+ title: prompt_metadata.title.map(|title| title.to_string()),
+ contents,
+ }),
+ Err(err) => {
+ this.update(cx, |_, cx| {
+ cx.emit(RulesLoadingError {
+ message: format!("{err:?}").into(),
+ });
+ })
+ .ok();
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
this.update(cx, |this, _cx| {
- *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
+ *this.project_context.0.borrow_mut() =
+ Some(ProjectContext::new(worktrees, default_user_rules));
})
.ok();
})
@@ -54,9 +54,9 @@ impl SlashCommand for DefaultSlashCommand {
cx: &mut App,
) -> Task<SlashCommandResult> {
let store = PromptStore::global(cx);
- cx.background_spawn(async move {
+ cx.spawn(async move |cx| {
let store = store.await?;
- let prompts = store.default_prompt_metadata();
+ let prompts = store.read_with(cx, |store, _cx| store.default_prompt_metadata())?;
let mut text = String::new();
text.push('\n');
@@ -5,7 +5,7 @@ use assistant_slash_command::{
};
use gpui::{Task, WeakEntity};
use language::{BufferSnapshot, LspAdapterDelegate};
-use prompt_store::PromptStore;
+use prompt_store::{PromptMetadata, PromptStore};
use std::sync::{Arc, atomic::AtomicBool};
use ui::prelude::*;
use workspace::Workspace;
@@ -43,8 +43,11 @@ impl SlashCommand for PromptSlashCommand {
) -> Task<Result<Vec<ArgumentCompletion>>> {
let store = PromptStore::global(cx);
let query = arguments.to_owned().join(" ");
- cx.background_spawn(async move {
- let prompts = store.await?.search(query).await;
+ cx.spawn(async move |cx| {
+ let prompts: Vec<PromptMetadata> = store
+ .await?
+ .read_with(cx, |store, cx| store.search(query, cx))?
+ .await;
Ok(prompts
.into_iter()
.filter_map(|prompt| {
@@ -77,14 +80,18 @@ impl SlashCommand for PromptSlashCommand {
let store = PromptStore::global(cx);
let title = SharedString::from(title.clone());
- let prompt = cx.background_spawn({
+ let prompt = cx.spawn({
let title = title.clone();
- async move {
+ async move |cx| {
let store = store.await?;
- let prompt_id = store
- .id_for_title(&title)
- .with_context(|| format!("no prompt found with title {:?}", title))?;
- let body = store.load(prompt_id).await?;
+ let body = store
+ .read_with(cx, |store, cx| {
+ let prompt_id = store
+ .id_for_title(&title)
+ .with_context(|| format!("no prompt found with title {:?}", title))?;
+ anyhow::Ok(store.load(prompt_id, cx))
+ })??
+ .await?;
anyhow::Ok(body)
}
});
@@ -309,7 +309,7 @@ impl Example {
return Err(anyhow!("Setup only mode"));
}
- let thread_store = thread_store.await;
+ let thread_store = thread_store.await?;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
@@ -136,7 +136,7 @@ pub fn open_prompt_library(
}
pub struct PromptLibrary {
- store: Arc<PromptStore>,
+ store: Entity<PromptStore>,
language_registry: Arc<LanguageRegistry>,
prompt_editors: HashMap<PromptId, PromptEditor>,
active_prompt_id: Option<PromptId>,
@@ -158,7 +158,7 @@ struct PromptEditor {
}
struct PromptPickerDelegate {
- store: Arc<PromptStore>,
+ store: Entity<PromptStore>,
selected_index: usize,
matches: Vec<PromptMetadata>,
}
@@ -179,8 +179,8 @@ impl PickerDelegate for PromptPickerDelegate {
self.matches.len()
}
- fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option<SharedString> {
- let text = if self.store.prompt_count() == 0 {
+ fn no_matches_text(&self, _window: &mut Window, cx: &mut App) -> Option<SharedString> {
+ let text = if self.store.read(cx).prompt_count() == 0 {
"No prompts.".into()
} else {
"No prompts found matching your search.".into()
@@ -211,7 +211,7 @@ impl PickerDelegate for PromptPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
- let search = self.store.search(query);
+ let search = self.store.read(cx).search(query, cx);
let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id);
cx.spawn_in(window, async move |this, cx| {
let (matches, selected_index) = cx
@@ -339,7 +339,7 @@ impl PickerDelegate for PromptPickerDelegate {
impl PromptLibrary {
fn new(
- store: Arc<PromptStore>,
+ store: Entity<PromptStore>,
language_registry: Arc<LanguageRegistry>,
inline_assist_delegate: Box<dyn InlineAssistDelegate>,
make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
@@ -398,7 +398,7 @@ impl PromptLibrary {
pub fn new_prompt(&mut self, window: &mut Window, cx: &mut Context<Self>) {
// If we already have an untitled prompt, use that instead
// of creating a new one.
- if let Some(metadata) = self.store.first() {
+ if let Some(metadata) = self.store.read(cx).first() {
if metadata.title.is_none() {
self.load_prompt(metadata.id, true, window, cx);
return;
@@ -406,7 +406,9 @@ impl PromptLibrary {
}
let prompt_id = PromptId::new();
- let save = self.store.save(prompt_id, None, false, "".into());
+ let save = self.store.update(cx, |store, cx| {
+ store.save(prompt_id, None, false, "".into(), cx)
+ });
self.picker
.update(cx, |picker, cx| picker.refresh(window, cx));
cx.spawn_in(window, async move |this, cx| {
@@ -430,7 +432,7 @@ impl PromptLibrary {
return;
}
- let prompt_metadata = self.store.metadata(prompt_id).unwrap();
+ let prompt_metadata = self.store.read(cx).metadata(prompt_id).unwrap();
let prompt_editor = self.prompt_editors.get_mut(&prompt_id).unwrap();
let title = prompt_editor.title_editor.read(cx).text(cx);
let body = prompt_editor.body_editor.update(cx, |editor, cx| {
@@ -465,10 +467,13 @@ impl PromptLibrary {
} else {
Some(SharedString::from(title))
};
- store
- .save(prompt_id, title, prompt_metadata.default, body)
- .await
- .log_err();
+ cx.update(|_window, cx| {
+ store.update(cx, |store, cx| {
+ store.save(prompt_id, title, prompt_metadata.default, body, cx)
+ })
+ })?
+ .await
+ .log_err();
this.update_in(cx, |this, window, cx| {
this.picker
.update(cx, |picker, cx| picker.refresh(window, cx));
@@ -521,14 +526,21 @@ impl PromptLibrary {
window: &mut Window,
cx: &mut Context<Self>,
) {
- if let Some(prompt_metadata) = self.store.metadata(prompt_id) {
- self.store
- .save_metadata(prompt_id, prompt_metadata.title, !prompt_metadata.default)
- .detach_and_log_err(cx);
- self.picker
- .update(cx, |picker, cx| picker.refresh(window, cx));
- cx.notify();
- }
+ self.store.update(cx, move |store, cx| {
+ if let Some(prompt_metadata) = store.metadata(prompt_id) {
+ store
+ .save_metadata(
+ prompt_id,
+ prompt_metadata.title,
+ !prompt_metadata.default,
+ cx,
+ )
+ .detach_and_log_err(cx);
+ }
+ });
+ self.picker
+ .update(cx, |picker, cx| picker.refresh(window, cx));
+ cx.notify();
}
pub fn load_prompt(
@@ -545,9 +557,9 @@ impl PromptLibrary {
.update(cx, |editor, cx| window.focus(&editor.focus_handle(cx)));
}
self.set_active_prompt(Some(prompt_id), window, cx);
- } else if let Some(prompt_metadata) = self.store.metadata(prompt_id) {
+ } else if let Some(prompt_metadata) = self.store.read(cx).metadata(prompt_id) {
let language_registry = self.language_registry.clone();
- let prompt = self.store.load(prompt_id);
+ let prompt = self.store.read(cx).load(prompt_id, cx);
let make_completion_provider = self.make_completion_provider.clone();
self.pending_load = cx.spawn_in(window, async move |this, cx| {
let prompt = prompt.await;
@@ -673,7 +685,7 @@ impl PromptLibrary {
window: &mut Window,
cx: &mut Context<Self>,
) {
- if let Some(metadata) = self.store.metadata(prompt_id) {
+ if let Some(metadata) = self.store.read(cx).metadata(prompt_id) {
let confirmation = window.prompt(
PromptLevel::Warning,
&format!(
@@ -692,7 +704,9 @@ impl PromptLibrary {
this.set_active_prompt(None, window, cx);
}
this.prompt_editors.remove(&prompt_id);
- this.store.delete(prompt_id).detach_and_log_err(cx);
+ this.store
+ .update(cx, |store, cx| store.delete(prompt_id, cx))
+ .detach_and_log_err(cx);
this.picker
.update(cx, |picker, cx| picker.refresh(window, cx));
cx.notify();
@@ -736,9 +750,9 @@ impl PromptLibrary {
let new_id = PromptId::new();
let body = prompt.body_editor.read(cx).text(cx);
- let save = self
- .store
- .save(new_id, Some(title.into()), false, body.into());
+ let save = self.store.update(cx, |store, cx| {
+ store.save(new_id, Some(title.into()), false, body.into(), cx)
+ });
self.picker
.update(cx, |picker, cx| picker.refresh(window, cx));
cx.spawn_in(window, async move |this, cx| {
@@ -968,7 +982,7 @@ impl PromptLibrary {
.flex_none()
.min_w_64()
.children(self.active_prompt_id.and_then(|prompt_id| {
- let prompt_metadata = self.store.metadata(prompt_id)?;
+ let prompt_metadata = self.store.read(cx).metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let model = LanguageModelRegistry::read_global(cx)
@@ -1238,7 +1252,7 @@ impl Render for PromptLibrary {
.text_color(theme.colors().text)
.child(self.render_prompt_list(cx))
.map(|el| {
- if self.store.prompt_count() == 0 {
+ if self.store.read(cx).prompt_count() == 0 {
el.child(
v_flex()
.w_2_3()
@@ -4,9 +4,11 @@ use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
use collections::HashMap;
use futures::FutureExt as _;
-use futures::future::{self, BoxFuture, Shared};
+use futures::future::Shared;
use fuzzy::StringMatchCandidate;
-use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task};
+use gpui::{
+ App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
+};
use heed::{
Database, RoTxn,
types::{SerdeBincode, SerdeJson, Str},
@@ -29,11 +31,16 @@ use uuid::Uuid;
/// a shared future to a global.
pub fn init(cx: &mut App) {
let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
- let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone())
- .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
- .boxed()
+ let prompt_store_task = PromptStore::new(db_path, cx);
+ let prompt_store_entity_task = cx
+ .spawn(async move |cx| {
+ prompt_store_task
+ .await
+ .and_then(|prompt_store| cx.new(|_cx| prompt_store))
+ .map_err(Arc::new)
+ })
.shared();
- cx.set_global(GlobalPromptStore(prompt_store_future))
+ cx.set_global(GlobalPromptStore(prompt_store_entity_task))
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -64,13 +71,16 @@ impl PromptId {
}
pub struct PromptStore {
- executor: BackgroundExecutor,
env: heed::Env,
metadata_cache: RwLock<MetadataCache>,
metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
bodies: Database<SerdeJson<PromptId>, Str>,
}
+pub struct PromptsUpdatedEvent;
+
+impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
+
#[derive(Default)]
struct MetadataCache {
metadata: Vec<PromptMetadata>,
@@ -117,49 +127,45 @@ impl MetadataCache {
}
impl PromptStore {
- pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> + use<> {
+ pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
let store = GlobalPromptStore::global(cx).0.clone();
async move { store.await.map_err(|err| anyhow!(err)) }
}
- pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> {
- executor.spawn({
- let executor = executor.clone();
- async move {
- std::fs::create_dir_all(&db_path)?;
+ pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
+ cx.background_spawn(async move {
+ std::fs::create_dir_all(&db_path)?;
- let db_env = unsafe {
- heed::EnvOpenOptions::new()
- .map_size(1024 * 1024 * 1024) // 1GB
- .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
- .open(db_path)?
- };
+ let db_env = unsafe {
+ heed::EnvOpenOptions::new()
+ .map_size(1024 * 1024 * 1024) // 1GB
+ .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
+ .open(db_path)?
+ };
- let mut txn = db_env.write_txn()?;
- let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
- let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
+ let mut txn = db_env.write_txn()?;
+ let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
+ let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
- // Remove edit workflow prompt, as we decided to opt into it using
- // a slash command instead.
- metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
- bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
+ // Remove edit workflow prompt, as we decided to opt into it using
+ // a slash command instead.
+ metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
+ bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
- txn.commit()?;
+ txn.commit()?;
- Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
+ Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
- let txn = db_env.read_txn()?;
- let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
- txn.commit()?;
+ let txn = db_env.read_txn()?;
+ let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
+ txn.commit()?;
- Ok(PromptStore {
- executor,
- env: db_env,
- metadata_cache: RwLock::new(metadata_cache),
- metadata,
- bodies,
- })
- }
+ Ok(PromptStore {
+ env: db_env,
+ metadata_cache: RwLock::new(metadata_cache),
+ metadata,
+ bodies,
+ })
})
}
@@ -237,10 +243,10 @@ impl PromptStore {
Ok(())
}
- pub fn load(&self, id: PromptId) -> Task<Result<String>> {
+ pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
let env = self.env.clone();
let bodies = self.bodies;
- self.executor.spawn(async move {
+ cx.background_spawn(async move {
let txn = env.read_txn()?;
let mut prompt = bodies
.get(&txn, &id)?
@@ -262,21 +268,27 @@ impl PromptStore {
.collect::<Vec<_>>();
}
- pub fn delete(&self, id: PromptId) -> Task<Result<()>> {
+ pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
self.metadata_cache.write().remove(id);
let db_connection = self.env.clone();
let bodies = self.bodies;
let metadata = self.metadata;
- self.executor.spawn(async move {
+ let task = cx.background_spawn(async move {
let mut txn = db_connection.write_txn()?;
metadata.delete(&mut txn, &id)?;
bodies.delete(&mut txn, &id)?;
txn.commit()?;
- Ok(())
+ anyhow::Ok(())
+ });
+
+ cx.spawn(async move |this, cx| {
+ task.await?;
+ this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
+ anyhow::Ok(())
})
}
@@ -302,10 +314,10 @@ impl PromptStore {
Some(metadata.id)
}
- pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> {
+ pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> {
let cached_metadata = self.metadata_cache.read().metadata.clone();
- let executor = self.executor.clone();
- self.executor.spawn(async move {
+ let executor = cx.background_executor().clone();
+ cx.background_spawn(async move {
let mut matches = if query.is_empty() {
cached_metadata
} else {
@@ -341,6 +353,7 @@ impl PromptStore {
title: Option<SharedString>,
default: bool,
body: Rope,
+ cx: &Context<Self>,
) -> Task<Result<()>> {
if id.is_built_in() {
return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
@@ -358,7 +371,7 @@ impl PromptStore {
let bodies = self.bodies;
let metadata = self.metadata;
- self.executor.spawn(async move {
+ let task = cx.background_spawn(async move {
let mut txn = db_connection.write_txn()?;
metadata.put(&mut txn, &id, &prompt_metadata)?;
@@ -366,7 +379,13 @@ impl PromptStore {
txn.commit()?;
- Ok(())
+ anyhow::Ok(())
+ });
+
+ cx.spawn(async move |this, cx| {
+ task.await?;
+ this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
+ anyhow::Ok(())
})
}
@@ -375,6 +394,7 @@ impl PromptStore {
id: PromptId,
mut title: Option<SharedString>,
default: bool,
+ cx: &Context<Self>,
) -> Task<Result<()>> {
let mut cache = self.metadata_cache.write();
@@ -397,19 +417,23 @@ impl PromptStore {
let db_connection = self.env.clone();
let metadata = self.metadata;
- self.executor.spawn(async move {
+ let task = cx.background_spawn(async move {
let mut txn = db_connection.write_txn()?;
metadata.put(&mut txn, &id, &prompt_metadata)?;
txn.commit()?;
- Ok(())
+ anyhow::Ok(())
+ });
+
+ cx.spawn(async move |this, cx| {
+ task.await?;
+ this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
+ anyhow::Ok(())
})
}
}
/// Wraps a shared future to a prompt store so it can be assigned as a context global.
-pub struct GlobalPromptStore(
- Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
-);
+pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
impl Global for GlobalPromptStore {}
@@ -19,20 +19,29 @@ use util::{ResultExt, get_system_shell};
#[derive(Debug, Clone, Serialize)]
pub struct ProjectContext {
pub worktrees: Vec<WorktreeContext>,
+ /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this.
pub has_rules: bool,
+ pub default_user_rules: Vec<DefaultUserRulesContext>,
+ /// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this.
+ pub has_default_user_rules: bool,
pub os: String,
pub arch: String,
pub shell: String,
}
impl ProjectContext {
- pub fn new(worktrees: Vec<WorktreeContext>) -> Self {
+ pub fn new(
+ worktrees: Vec<WorktreeContext>,
+ default_user_rules: Vec<DefaultUserRulesContext>,
+ ) -> Self {
let has_rules = worktrees
.iter()
.any(|worktree| worktree.rules_file.is_some());
Self {
worktrees,
has_rules,
+ has_default_user_rules: !default_user_rules.is_empty(),
+ default_user_rules,
os: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(),
shell: get_system_shell(),
@@ -40,6 +49,12 @@ impl ProjectContext {
}
}
+#[derive(Debug, Clone, Serialize)]
+pub struct DefaultUserRulesContext {
+ pub title: Option<String>,
+ pub contents: String,
+}
+
#[derive(Debug, Clone, Serialize)]
pub struct WorktreeContext {
pub root_name: String,
@@ -377,3 +392,30 @@ impl PromptBuilder {
self.handlebars.lock().render("suggest_edits", &())
}
}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_assistant_system_prompt_renders() {
+ let worktrees = vec![WorktreeContext {
+ root_name: "path".into(),
+ abs_path: Path::new("/some/path").into(),
+ rules_file: Some(RulesFileContext {
+ path_in_worktree: Path::new(".rules").into(),
+ abs_path: Path::new("/some/path/.rules").into(),
+ text: "".into(),
+ }),
+ }];
+ let default_user_rules = vec![DefaultUserRulesContext {
+ title: Some("Rules title".into()),
+ contents: "Rules contents".into(),
+ }];
+ let project_context = ProjectContext::new(worktrees, default_user_rules);
+ PromptBuilder::new(None)
+ .unwrap()
+ .generate_assistant_system_prompt(&project_context)
+ .unwrap();
+ }
+}