diff --git a/assets/prompts/content_prompt_v2.hbs b/assets/prompts/content_prompt_v2.hbs index 87376f49f12f0e27cc61e9f9747d9de6bfde43cb..826aada8c04863c21d756cf99beb64e582ed4906 100644 --- a/assets/prompts/content_prompt_v2.hbs +++ b/assets/prompts/content_prompt_v2.hbs @@ -14,7 +14,6 @@ The section you'll need to rewrite is marked with The context around the relevant section has been truncated (possibly in the middle of a line) for brevity. {{/if}} -{{#if rewrite_section}} And here's the section to rewrite based on that prompt again for reference: @@ -33,8 +32,6 @@ Below are the diagnostic errors visible to the user. If the user requests probl {{/each}} {{/if}} -{{/if}} - Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. Start at the indentation level in the original file in the rewritten {{content_type}}. diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 87ce6d386b38f31a0d7b550aab00bb766ce75010..a296d4d20918fba6eb32bfcf7fcc657f9db2b3ac 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -75,6 +75,9 @@ pub struct BufferCodegen { session_id: Uuid, } +pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section"; +pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message"; + impl BufferCodegen { pub fn new( buffer: Entity, @@ -522,12 +525,12 @@ impl CodegenAlternative { let tools = vec![ LanguageModelRequestTool { - name: "rewrite_section".to_string(), + name: REWRITE_SECTION_TOOL_NAME.to_string(), description: "Replaces text in tags with your replacement_text.".to_string(), input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), }, LanguageModelRequestTool { - name: "failure_message".to_string(), + name: FAILURE_MESSAGE_TOOL_NAME.to_string(), description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(), input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), }, @@ -1167,7 +1170,7 @@ impl CodegenAlternative { let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option { let mut chars_read_so_far = chars_read_so_far.lock(); match tool_use.name.as_ref() { - "rewrite_section" => { + REWRITE_SECTION_TOOL_NAME => { let Ok(input) = serde_json::from_value::(tool_use.input) else { @@ -1180,7 +1183,7 @@ impl CodegenAlternative { description: None, }) } - "failure_message" => { + FAILURE_MESSAGE_TOOL_NAME => { let Ok(mut input) = serde_json::from_value::(tool_use.input) else { @@ -1493,7 +1496,10 @@ mod tests { use indoc::indoc; use language::{Buffer, Point}; use language_model::fake_provider::FakeLanguageModel; - use language_model::{LanguageModelRegistry, TokenUsage}; + use language_model::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry, + LanguageModelToolUse, StopReason, TokenUsage, + }; use languages::rust_lang; use rand::prelude::*; use settings::SettingsStore; @@ -1805,6 +1811,51 @@ mod tests { ); } + // When not streaming tool calls, we strip backticks as part of parsing the model's + // plain text response. This is a regression test for a bug where we stripped + // backticks incorrectly. + #[gpui::test] + async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) { + init_test(cx); + let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))"; + let buffer = cx.new(|cx| Buffer::local("", cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + true, + prompt_builder, + Uuid::new_v4(), + cx, + ) + }); + + let events_tx = simulate_tool_based_completion(&codegen, cx); + let chunk_len = text.find('`').unwrap(); + events_tx + .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false)) + .unwrap(); + events_tx + .unbounded_send(rewrite_tool_use("tool_2", &text, true)) + .unwrap(); + events_tx + .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)) + .unwrap(); + drop(events_tx); + cx.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + text + ); + } + #[gpui::test] async fn test_strip_invalid_spans_from_codeblock() { assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; @@ -1870,4 +1921,39 @@ mod tests { }); chunks_tx } + + fn simulate_tool_based_completion( + codegen: &Entity, + cx: &mut TestAppContext, + ) -> mpsc::UnboundedSender { + let (events_tx, events_rx) = mpsc::unbounded(); + let model = Arc::new(FakeLanguageModel::default()); + codegen.update(cx, |codegen, cx| { + let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed() + as BoxStream< + 'static, + Result, + >)); + codegen.generation = codegen.handle_completion(model, completion_stream, cx); + }); + events_tx + } + + fn rewrite_tool_use( + id: &str, + replacement_text: &str, + is_complete: bool, + ) -> LanguageModelCompletionEvent { + let input = RewriteSectionInput { + replacement_text: replacement_text.into(), + }; + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + id: id.into(), + name: REWRITE_SECTION_TOOL_NAME.into(), + raw_input: serde_json::to_string(&input).unwrap(), + input: serde_json::to_value(&input).unwrap(), + is_input_complete: is_complete, + thought_signature: None, + }) + } } diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 052d8598a76d1044c6d97b5378041b5cd12e23b3..671579f9ef018b495b7993279a852595c78d3e02 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -2271,6 +2271,36 @@ pub mod evals { ); } + #[test] + #[cfg_attr(not(feature = "unit-eval"), ignore)] + fn eval_empty_buffer() { + run_eval( + 20, + 1.0, + "Write a Python hello, world program".to_string(), + "ˇ".to_string(), + |output| match output { + InlineAssistantOutput::Success { + full_buffer_text, .. + } => { + if full_buffer_text.is_empty() { + EvalOutput::failed("expected some output".to_string()) + } else { + EvalOutput::passed(format!("Produced {full_buffer_text}")) + } + } + o @ InlineAssistantOutput::Failure { .. } => EvalOutput::failed(format!( + "Assistant output does not match expected output: {:?}", + o + )), + o @ InlineAssistantOutput::Malformed { .. } => EvalOutput::failed(format!( + "Assistant output does not match expected output: {:?}", + o + )), + }, + ); + } + fn run_eval( iterations: usize, expected_pass_ratio: f32, diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 674d4869e9825fd700dde3db510fbf68c6b4d5cc..6a845bb8dd394f8a1ff26a8a0e130156a2a158bd 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -112,7 +112,7 @@ pub struct ContentPromptContextV2 { pub language_name: Option, pub is_truncated: bool, pub document_content: String, - pub rewrite_section: Option, + pub rewrite_section: String, pub diagnostic_errors: Vec, } @@ -310,7 +310,6 @@ impl PromptBuilder { }; const MAX_CTX: usize = 50000; - let is_insert = range.is_empty(); let mut is_truncated = false; let before_range = 0..range.start; @@ -335,28 +334,19 @@ impl PromptBuilder { for chunk in buffer.text_for_range(truncated_before) { document_content.push_str(chunk); } - if is_insert { - document_content.push_str(""); - } else { - document_content.push_str("\n"); - for chunk in buffer.text_for_range(range.clone()) { - document_content.push_str(chunk); - } - document_content.push_str("\n"); + + document_content.push_str("\n"); + for chunk in buffer.text_for_range(range.clone()) { + document_content.push_str(chunk); } + document_content.push_str("\n"); + for chunk in buffer.text_for_range(truncated_after) { document_content.push_str(chunk); } - let rewrite_section = if !is_insert { - let mut section = String::new(); - for chunk in buffer.text_for_range(range.clone()) { - section.push_str(chunk); - } - Some(section) - } else { - None - }; + let rewrite_section: String = buffer.text_for_range(range.clone()).collect(); + let diagnostics = buffer.diagnostics_in_range::<_, Point>(range, false); let diagnostic_errors: Vec = diagnostics .map(|entry| {