From 990774247ecb61ddf69f0325d28eb3b203fc3699 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 6 Aug 2024 21:47:42 -0600 Subject: [PATCH] Allow /workflow and step resolution prompts to be overridden (#15892) This will help us as we hit issues with the /workflow and step resolution. We can override the baked-in prompts and make tweaks, then import our refinements back into the source tree when we're ready. Release Notes: - N/A --- .../{edit_workflow.md => edit_workflow.hbs} | 0 ...step_resolution.md => step_resolution.hbs} | 0 crates/assistant/src/assistant.rs | 42 +++++++++----- crates/assistant/src/assistant_panel.rs | 6 +- crates/assistant/src/context.rs | 53 ++++++++++++++---- crates/assistant/src/context_store.rs | 19 ++++++- crates/assistant/src/prompt_library.rs | 16 +----- crates/assistant/src/prompts.rs | 38 +++++++------ .../src/slash_command/workflow_command.rs | 55 ++++++++++--------- crates/collab/src/tests/integration_tests.rs | 7 ++- crates/zed/src/main.rs | 34 ++++++++---- crates/zed/src/zed.rs | 15 +++-- crates/zed/src/zed/open_listener.rs | 11 +++- docs/src/language-model-integration.md | 6 +- 14 files changed, 197 insertions(+), 105 deletions(-) rename assets/prompts/{edit_workflow.md => edit_workflow.hbs} (100%) rename assets/prompts/{step_resolution.md => step_resolution.hbs} (100%) diff --git a/assets/prompts/edit_workflow.md b/assets/prompts/edit_workflow.hbs similarity index 100% rename from assets/prompts/edit_workflow.md rename to assets/prompts/edit_workflow.hbs diff --git a/assets/prompts/step_resolution.md b/assets/prompts/step_resolution.hbs similarity index 100% rename from assets/prompts/step_resolution.md rename to assets/prompts/step_resolution.hbs diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 4896aadeec5cfe7d702a6cd37ab3dccdabfbcf6a..04e4f69505c8c8b46e8218dbff3e2e9d9cff4ddb 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -25,6 +25,7 @@ use language_model::{ LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage, }; pub(crate) use model_selector::*; +pub use prompts::PromptBuilder; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; @@ -163,7 +164,7 @@ impl Assistant { } } -pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { +pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) -> Arc { cx.set_global(Assistant::default()); AssistantSettings::register(cx); @@ -196,19 +197,25 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { prompt_library::init(cx); init_language_model_settings(cx); assistant_slash_command::init(cx); - register_slash_commands(cx); assistant_panel::init(cx); - if let Some(prompt_builder) = prompts::PromptBuilder::new(Some((fs.clone(), cx))).log_err() { - let prompt_builder = Arc::new(prompt_builder); - inline_assistant::init( - fs.clone(), - prompt_builder.clone(), - client.telemetry().clone(), - cx, - ); - terminal_inline_assistant::init(fs.clone(), prompt_builder, client.telemetry().clone(), cx); - } + let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx))) + .log_err() + .map(Arc::new) + .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); + register_slash_commands(Some(prompt_builder.clone()), cx); + inline_assistant::init( + fs.clone(), + prompt_builder.clone(), + client.telemetry().clone(), + cx, + ); + terminal_inline_assistant::init( + fs.clone(), + prompt_builder.clone(), + client.telemetry().clone(), + cx, + ); IndexedDocsRegistry::init_global(cx); CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -226,6 +233,8 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { }); }) .detach(); + + prompt_builder } fn init_language_model_settings(cx: &mut AppContext) { @@ -256,7 +265,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) { }); } -fn register_slash_commands(cx: &mut AppContext) { +fn register_slash_commands(prompt_builder: Option>, cx: &mut AppContext) { let slash_command_registry = SlashCommandRegistry::global(cx); slash_command_registry.register_command(file_command::FileSlashCommand, true); slash_command_registry.register_command(active_command::ActiveSlashCommand, true); @@ -270,7 +279,12 @@ fn register_slash_commands(cx: &mut AppContext) { slash_command_registry.register_command(now_command::NowSlashCommand, true); slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true); slash_command_registry.register_command(docs_command::DocsSlashCommand, true); - slash_command_registry.register_command(workflow_command::WorkflowSlashCommand, true); + if let Some(prompt_builder) = prompt_builder { + slash_command_registry.register_command( + workflow_command::WorkflowSlashCommand::new(prompt_builder), + true, + ); + } slash_command_registry.register_command(fetch_command::FetchSlashCommand, false); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9ccd8356fcddabdfe05cf9a50ee841f098c7d135..c5c504c4018023307cbcdf1975fac0841099f146 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -2,6 +2,7 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, humanize_token_count, prompt_library::open_prompt_library, + prompts::PromptBuilder, slash_command::{ default_command::DefaultSlashCommand, docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, @@ -315,14 +316,17 @@ impl PickerDelegate for SavedContextPickerDelegate { impl AssistantPanel { pub fn load( workspace: WeakView, + prompt_builder: Arc, cx: AsyncWindowContext, ) -> Task>> { cx.spawn(|mut cx| async move { let context_store = workspace .update(&mut cx, |workspace, cx| { - ContextStore::new(workspace.project().clone(), cx) + let project = workspace.project().clone(); + ContextStore::new(project, prompt_builder.clone(), cx) })? .await?; + workspace.update(&mut cx, |workspace, cx| { // TODO: deserialize state. cx.new_view(|cx| Self::new(workspace, context_store, cx)) diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 5b4e4a980c9fddc9993eab51a1ffa4ab507d35e1..b001794d9ed497a330a524256c4562568c94e62a 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,5 +1,5 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion, + prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion, InlineAssistId, InlineAssistant, MessageId, MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; @@ -611,6 +611,7 @@ pub struct Context { language_registry: Arc, workflow_steps: Vec, project: Option>, + prompt_builder: Arc, } impl EventEmitter for Context {} @@ -620,6 +621,7 @@ impl Context { language_registry: Arc, project: Option>, telemetry: Option>, + prompt_builder: Arc, cx: &mut ModelContext, ) -> Self { Self::new( @@ -627,17 +629,20 @@ impl Context { ReplicaId::default(), language::Capability::ReadWrite, language_registry, + prompt_builder, project, telemetry, cx, ) } + #[allow(clippy::too_many_arguments)] pub fn new( id: ContextId, replica_id: ReplicaId, capability: language::Capability, language_registry: Arc, + prompt_builder: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -680,6 +685,7 @@ impl Context { project, language_registry, workflow_steps: Vec::new(), + prompt_builder, }; let first_message_id = MessageId(clock::Lamport { @@ -749,6 +755,7 @@ impl Context { saved_context: SavedContext, path: PathBuf, language_registry: Arc, + prompt_builder: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -759,6 +766,7 @@ impl Context { ReplicaId::default(), language::Capability::ReadWrite, language_registry, + prompt_builder, project, telemetry, cx, @@ -1246,9 +1254,9 @@ impl Context { cx.spawn(|this, mut cx| { async move { - let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?; - - let mut prompt = prompt_store.step_resolution_prompt()?; + let mut prompt = this.update(&mut cx, |this, _| { + this.prompt_builder.generate_step_resolution_prompt() + })??; prompt.push_str(&step_text); request.messages.push(LanguageModelRequestMessage { @@ -2448,8 +2456,9 @@ mod tests { cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - - let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -2580,7 +2589,9 @@ mod tests { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -2673,7 +2684,9 @@ mod tests { cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -2778,7 +2791,10 @@ mod tests { slash_command_registry.register_command(active_command::ActiveSlashCommand, false); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local(registry.clone(), None, None, prompt_builder.clone(), cx) + }); let output_ranges = Rc::new(RefCell::new(HashSet::default())); context.update(cx, |_, cx| { @@ -2905,7 +2921,16 @@ mod tests { let registry = Arc::new(LanguageRegistry::test(cx.executor())); // Create a new context - let context = cx.new_model(|cx| Context::local(registry.clone(), Some(project), None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + Some(project), + None, + prompt_builder.clone(), + cx, + ) + }); let buffer = context.read_with(cx, |context, _| context.buffer.clone()); // Simulate user input @@ -3070,7 +3095,10 @@ mod tests { cx.update(LanguageModelRegistry::test); cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local(registry.clone(), None, None, prompt_builder.clone(), cx) + }); let buffer = context.read_with(cx, |context, _| context.buffer.clone()); let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); let message_1 = context.update(cx, |context, cx| { @@ -3109,6 +3137,7 @@ mod tests { serialized_context, Default::default(), registry.clone(), + prompt_builder.clone(), None, None, cx, @@ -3158,6 +3187,7 @@ mod tests { let num_peers = rng.gen_range(min_peers..=max_peers); let context_id = ContextId::new(); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); for i in 0..num_peers { let context = cx.new_model(|cx| { Context::new( @@ -3165,6 +3195,7 @@ mod tests { i as ReplicaId, language::Capability::ReadWrite, registry.clone(), + prompt_builder.clone(), None, None, cx, diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index dab709bd20fa1c157496dd8548edff8a32a78d11..ce82b5eca72951a7571ff416112bf844f6e85d8a 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -1,6 +1,6 @@ use crate::{ - Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, - SavedContextMetadata, + prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, + SavedContext, SavedContextMetadata, }; use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; @@ -52,6 +52,7 @@ pub struct ContextStore { project_is_shared: bool, client_subscription: Option, _project_subscriptions: Vec, + prompt_builder: Arc, } pub enum ContextStoreEvent { @@ -82,7 +83,11 @@ impl ContextHandle { } impl ContextStore { - pub fn new(project: Model, cx: &mut AppContext) -> Task>> { + pub fn new( + project: Model, + prompt_builder: Arc, + cx: &mut AppContext, + ) -> Task>> { let fs = project.read(cx).fs().clone(); let languages = project.read(cx).languages().clone(); let telemetry = project.read(cx).client().telemetry().clone(); @@ -117,6 +122,7 @@ impl ContextStore { project_is_shared: false, client: project.read(cx).client(), project: project.clone(), + prompt_builder, }; this.handle_project_changed(project, cx); this.synchronize_contexts(cx); @@ -334,6 +340,7 @@ impl ContextStore { self.languages.clone(), Some(self.project.clone()), Some(self.telemetry.clone()), + self.prompt_builder.clone(), cx, ) }); @@ -358,6 +365,7 @@ impl ContextStore { let language_registry = self.languages.clone(); let project = self.project.clone(); let telemetry = self.telemetry.clone(); + let prompt_builder = self.prompt_builder.clone(); let request = self.client.request(proto::CreateContext { project_id }); cx.spawn(|this, mut cx| async move { let response = request.await?; @@ -369,6 +377,7 @@ impl ContextStore { replica_id, capability, language_registry, + prompt_builder, Some(project), Some(telemetry), cx, @@ -417,6 +426,7 @@ impl ContextStore { SavedContext::from_json(&saved_context) } }); + let prompt_builder = self.prompt_builder.clone(); cx.spawn(|this, mut cx| async move { let saved_context = load.await?; @@ -425,6 +435,7 @@ impl ContextStore { saved_context, path.clone(), languages, + prompt_builder, Some(project), Some(telemetry), cx, @@ -493,6 +504,7 @@ impl ContextStore { project_id, context_id: context_id.to_proto(), }); + let prompt_builder = self.prompt_builder.clone(); cx.spawn(|this, mut cx| async move { let response = request.await?; let context_proto = response.context.context("invalid context")?; @@ -502,6 +514,7 @@ impl ContextStore { replica_id, capability, language_registry, + prompt_builder, Some(project), Some(telemetry), cx, diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 23f76177f7020055656d10fe21542373df398c33..a0b25bf679e95df56330770791f61c3b368c6809 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -2,7 +2,6 @@ use crate::{ slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant, }; use anyhow::{anyhow, Result}; -use assets::Assets; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet}; use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle}; @@ -12,8 +11,8 @@ use futures::{ }; use fuzzy::StringMatchCandidate; use gpui::{ - actions, point, size, transparent_black, AppContext, AssetSource, BackgroundExecutor, Bounds, - EventEmitter, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle, + actions, point, size, transparent_black, AppContext, BackgroundExecutor, Bounds, EventEmitter, + Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle, TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions, }; use heed::{ @@ -1466,17 +1465,6 @@ impl PromptStore { fn first(&self) -> Option { self.metadata_cache.read().metadata.first().cloned() } - - pub fn step_resolution_prompt(&self) -> Result { - let path = "prompts/step_resolution.md"; - - Ok(String::from_utf8( - Assets - .load(path)? - .ok_or_else(|| anyhow!("{path} not found"))? - .to_vec(), - )?) - } } /// Wraps a shared future to a prompt store so it can be assigned as a context global. diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index d61d1278802ffcaf7b1c978664d1f1f1ec698f52..f6fb203e96076c395a9f2d5f9a7b94bc712b0017 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -125,22 +125,20 @@ impl PromptBuilder { } fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box> { - let content_prompt = Assets::get("prompts/content_prompt.hbs") - .expect("Content prompt template not found") - .data; - let terminal_assistant_prompt = Assets::get("prompts/terminal_assistant_prompt.hbs") - .expect("Terminal assistant prompt template not found") - .data; - - handlebars - .register_template_string("content_prompt", String::from_utf8_lossy(&content_prompt)) - .map_err(Box::new)?; - handlebars - .register_template_string( - "terminal_assistant_prompt", - String::from_utf8_lossy(&terminal_assistant_prompt), - ) - .map_err(Box::new)?; + let mut register_template = |id: &str| { + let prompt = Assets::get(&format!("prompts/{}.hbs", id)) + .unwrap_or_else(|| panic!("{} prompt template not found", id)) + .data; + handlebars + .register_template_string(id, String::from_utf8_lossy(&prompt)) + .map_err(Box::new) + }; + + register_template("content_prompt")?; + register_template("terminal_assistant_prompt")?; + register_template("edit_workflow")?; + register_template("step_resolution")?; + Ok(()) } @@ -236,4 +234,12 @@ impl PromptBuilder { .lock() .render("terminal_assistant_prompt", &context) } + + pub fn generate_workflow_prompt(&self) -> Result { + self.handlebars.lock().render("edit_workflow", &()) + } + + pub fn generate_step_resolution_prompt(&self) -> Result { + self.handlebars.lock().render("step_resolution", &()) + } } diff --git a/crates/assistant/src/slash_command/workflow_command.rs b/crates/assistant/src/slash_command/workflow_command.rs index f55275f0114905f53cb185ed825f4c119bc7e3d7..d2708c38d2746b97cc903fde0d9043f721763aa0 100644 --- a/crates/assistant/src/slash_command/workflow_command.rs +++ b/crates/assistant/src/slash_command/workflow_command.rs @@ -1,18 +1,27 @@ -use std::sync::atomic::AtomicBool; +use crate::prompts::PromptBuilder; use std::sync::Arc; -use anyhow::{Context as _, Result}; -use assets::Assets; +use std::sync::atomic::AtomicBool; + +use anyhow::Result; use assistant_slash_command::{ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, }; -use gpui::{AppContext, AssetSource, Task, WeakView}; +use gpui::{AppContext, Task, WeakView}; use language::LspAdapterDelegate; -use text::LineEnding; use ui::prelude::*; + use workspace::Workspace; -pub(crate) struct WorkflowSlashCommand; +pub(crate) struct WorkflowSlashCommand { + prompt_builder: Arc, +} + +impl WorkflowSlashCommand { + pub fn new(prompt_builder: Arc) -> Self { + Self { prompt_builder } + } +} impl SlashCommand for WorkflowSlashCommand { fn name(&self) -> String { @@ -46,26 +55,22 @@ impl SlashCommand for WorkflowSlashCommand { _argument: Option<&str>, _workspace: WeakView, _delegate: Option>, - _cx: &mut WindowContext, + cx: &mut WindowContext, ) -> Task> { - let mut text = match Assets - .load("prompts/edit_workflow.md") - .and_then(|prompt| prompt.context("prompts/edit_workflow.md not found")) - { - Ok(prompt) => String::from_utf8_lossy(&prompt).into_owned(), - Err(error) => return Task::ready(Err(error)), - }; - LineEnding::normalize(&mut text); - let range = 0..text.len(); + let prompt_builder = self.prompt_builder.clone(); + cx.spawn(|_cx| async move { + let text = prompt_builder.generate_workflow_prompt()?; + let range = 0..text.len(); - Task::ready(Ok(SlashCommandOutput { - text, - sections: vec![SlashCommandOutputSection { - range, - icon: IconName::Route, - label: "Workflow".into(), - }], - run_commands_in_text: false, - })) + Ok(SlashCommandOutput { + text, + sections: vec![SlashCommandOutputSection { + range, + icon: IconName::Route, + label: "Workflow".into(), + }], + run_commands_in_text: false, + }) + }) } } diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index f5b7b8903b968bdac7c70f1b3c5058624068fe64..3e95ca76595660d47c0100fa3564299b4acaf6ee 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6,7 +6,7 @@ use crate::{ }, }; use anyhow::{anyhow, Result}; -use assistant::ContextStore; +use assistant::{ContextStore, PromptBuilder}; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; use collections::{HashMap, HashSet}; @@ -6485,12 +6485,13 @@ async fn test_context_collaboration_with_reconnect( assert_eq!(project.collaborators().len(), 1); }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context_store_a = cx_a - .update(|cx| ContextStore::new(project_a.clone(), cx)) + .update(|cx| ContextStore::new(project_a.clone(), prompt_builder.clone(), cx)) .await .unwrap(); let context_store_b = cx_b - .update(|cx| ContextStore::new(project_b.clone(), cx)) + .update(|cx| ContextStore::new(project_b.clone(), prompt_builder.clone(), cx)) .await .unwrap(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 1570d75b40c93b731871eeb830173392869d67cd..e56bfe5b9251cffd4fae4f4d1128e051bc1cd0f4 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -7,6 +7,7 @@ mod reliability; mod zed; use anyhow::{anyhow, Context as _, Result}; +use assistant::PromptBuilder; use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; use client::{parse_zed_link, Client, DevServerToken, UserStore}; @@ -161,7 +162,7 @@ fn init_headless( } // init_common is called for both headless and normal mode. -fn init_common(app_state: Arc, cx: &mut AppContext) { +fn init_common(app_state: Arc, cx: &mut AppContext) -> Arc { SystemAppearance::init(cx); theme::init(theme::LoadThemes::All(Box::new(Assets)), cx); command_palette::init(cx); @@ -182,7 +183,7 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { ); snippet_provider::init(cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx); - assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); + let prompt_builder = assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init( app_state.fs.clone(), app_state.client.telemetry().clone(), @@ -196,9 +197,14 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { ThemeRegistry::global(cx), cx, ); + prompt_builder } -fn init_ui(app_state: Arc, cx: &mut AppContext) -> Result<()> { +fn init_ui( + app_state: Arc, + prompt_builder: Arc, + cx: &mut AppContext, +) -> Result<()> { match cx.try_global::() { Some(AppMode::Headless(_)) => { return Err(anyhow!( @@ -289,7 +295,7 @@ fn init_ui(app_state: Arc, cx: &mut AppContext) -> Result<()> { watch_file_types(fs.clone(), cx); cx.set_menus(app_menus()); - initialize_workspace(app_state.clone(), cx); + initialize_workspace(app_state.clone(), prompt_builder, cx); cx.activate(true); @@ -467,7 +473,7 @@ fn main() { auto_update::init(client.http_client(), cx); reliability::init(client.http_client(), installation_id, cx); - init_common(app_state.clone(), cx); + let prompt_builder = init_common(app_state.clone(), cx); let args = Args::parse(); let urls: Vec<_> = args @@ -487,7 +493,7 @@ fn main() { .and_then(|urls| OpenRequest::parse(urls, cx).log_err()) { Some(request) => { - handle_open_request(request, app_state.clone(), cx); + handle_open_request(request, app_state.clone(), prompt_builder.clone(), cx); } None => { if let Some(dev_server_token) = args.dev_server_token { @@ -503,7 +509,7 @@ fn main() { }) .detach(); } else { - init_ui(app_state.clone(), cx).unwrap(); + init_ui(app_state.clone(), prompt_builder.clone(), cx).unwrap(); cx.spawn({ let app_state = app_state.clone(); |mut cx| async move { @@ -518,11 +524,12 @@ fn main() { } let app_state = app_state.clone(); + let prompt_builder = prompt_builder.clone(); cx.spawn(move |cx| async move { while let Some(urls) = open_rx.next().await { cx.update(|cx| { if let Some(request) = OpenRequest::parse(urls, cx).log_err() { - handle_open_request(request, app_state.clone(), cx); + handle_open_request(request, app_state.clone(), prompt_builder.clone(), cx); } }) .ok(); @@ -532,15 +539,20 @@ fn main() { }); } -fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut AppContext) { +fn handle_open_request( + request: OpenRequest, + app_state: Arc, + prompt_builder: Arc, + cx: &mut AppContext, +) { if let Some(connection) = request.cli_connection { let app_state = app_state.clone(); - cx.spawn(move |cx| handle_cli_connection(connection, app_state, cx)) + cx.spawn(move |cx| handle_cli_connection(connection, app_state, prompt_builder, cx)) .detach(); return; } - if let Err(e) = init_ui(app_state.clone(), cx) { + if let Err(e) = init_ui(app_state.clone(), prompt_builder, cx) { fail_to_open_window(e, cx); return; }; diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index e727f0e170cf3d8fca248462bd4c7e917f29f089..667e8c66bc2533881188cbec3b1be6984e3be55c 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -7,6 +7,7 @@ pub(crate) mod only_instance; mod open_listener; pub use app_menus::*; +use assistant::PromptBuilder; use breadcrumbs::Breadcrumbs; use client::ZED_URL_SCHEME; use collections::VecDeque; @@ -119,7 +120,11 @@ pub fn build_window_options(display_uuid: Option, cx: &mut AppContext) -> } } -pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { +pub fn initialize_workspace( + app_state: Arc, + prompt_builder: Arc, + cx: &mut AppContext, +) { cx.observe_new_views(move |workspace: &mut Workspace, cx| { let workspace_handle = cx.view().clone(); let center_pane = workspace.active_pane().clone(); @@ -238,9 +243,10 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }); } + let prompt_builder = prompt_builder.clone(); cx.spawn(|workspace_handle, mut cx| async move { let assistant_panel = - assistant::AssistantPanel::load(workspace_handle.clone(), cx.clone()); + assistant::AssistantPanel::load(workspace_handle.clone(), prompt_builder, cx.clone()); let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone()); let outline_panel = OutlinePanel::load(workspace_handle.clone(), cx.clone()); @@ -3474,14 +3480,15 @@ mod tests { app_state.fs.clone(), cx, ); - assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); + let prompt_builder = + assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init( app_state.fs.clone(), app_state.client.telemetry().clone(), cx, ); tasks_ui::init(cx); - initialize_workspace(app_state.clone(), cx); + initialize_workspace(app_state.clone(), prompt_builder, cx); app_state }) } diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index fb94f2430e2973cea93837c5c55596fc5f4d9a67..a9c741f7a0476431a1dccfbb552252d14a6b2528 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -1,6 +1,7 @@ use crate::restorable_workspace_locations; use crate::{handle_open_request, init_headless, init_ui}; use anyhow::{anyhow, Context, Result}; +use assistant::PromptBuilder; use cli::{ipc, IpcHandshake}; use cli::{ipc::IpcSender, CliRequest, CliResponse}; use client::parse_zed_link; @@ -245,6 +246,7 @@ pub async fn open_paths_with_positions( pub async fn handle_cli_connection( (mut requests, responses): (mpsc::Receiver, IpcSender), app_state: Arc, + prompt_builder: Arc, mut cx: AsyncAppContext, ) { if let Some(request) = requests.next().await { @@ -289,7 +291,12 @@ pub async fn handle_cli_connection( cx.update(|cx| { match OpenRequest::parse(urls, cx) { Ok(open_request) => { - handle_open_request(open_request, app_state.clone(), cx); + handle_open_request( + open_request, + app_state.clone(), + prompt_builder.clone(), + cx, + ); responses.send(CliResponse::Exit { status: 0 }).log_err(); } Err(e) => { @@ -307,7 +314,7 @@ pub async fn handle_cli_connection( } if let Err(e) = cx - .update(|cx| init_ui(app_state.clone(), cx)) + .update(|cx| init_ui(app_state.clone(), prompt_builder.clone(), cx)) .and_then(|r| r) { responses diff --git a/docs/src/language-model-integration.md b/docs/src/language-model-integration.md index bb8d0f2f533de59ac8e2a8efb9f9611e76564abe..f2a2fb7b7c38ef88fdac6c6ba089a08c25a5e8a3 100644 --- a/docs/src/language-model-integration.md +++ b/docs/src/language-model-integration.md @@ -221,6 +221,10 @@ Zed allows you to override the default prompts used for various assistant featur given system information and latest terminal output if relevant. ``` -You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. +3. `edit_workflow.hbs`: Used for generating the edit workflow prompt. + +4. `step_resolution.hbs`: Used for generating the step resolution prompt. + +You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. Consult Zed's assets/prompts directory for current versions you can play with. Be sure you want to override these, as you'll miss out on iteration on our built in features. This should be primarily used when developing Zed.