diff --git a/assets/prompts/content_prompt_v2.hbs b/assets/prompts/content_prompt_v2.hbs
index 87376f49f12f0e27cc61e9f9747d9de6bfde43cb..826aada8c04863c21d756cf99beb64e582ed4906 100644
--- a/assets/prompts/content_prompt_v2.hbs
+++ b/assets/prompts/content_prompt_v2.hbs
@@ -14,7 +14,6 @@ The section you'll need to rewrite is marked with
The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
{{/if}}
-{{#if rewrite_section}}
And here's the section to rewrite based on that prompt again for reference:
@@ -33,8 +32,6 @@ Below are the diagnostic errors visible to the user. If the user requests probl
{{/each}}
{{/if}}
-{{/if}}
-
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
Start at the indentation level in the original file in the rewritten {{content_type}}.
diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs
index 87ce6d386b38f31a0d7b550aab00bb766ce75010..a296d4d20918fba6eb32bfcf7fcc657f9db2b3ac 100644
--- a/crates/agent_ui/src/buffer_codegen.rs
+++ b/crates/agent_ui/src/buffer_codegen.rs
@@ -75,6 +75,9 @@ pub struct BufferCodegen {
session_id: Uuid,
}
+pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
+pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
+
impl BufferCodegen {
pub fn new(
buffer: Entity,
@@ -522,12 +525,12 @@ impl CodegenAlternative {
let tools = vec![
LanguageModelRequestTool {
- name: "rewrite_section".to_string(),
+ name: REWRITE_SECTION_TOOL_NAME.to_string(),
description: "Replaces text in tags with your replacement_text.".to_string(),
input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(),
},
LanguageModelRequestTool {
- name: "failure_message".to_string(),
+ name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(),
},
@@ -1167,7 +1170,7 @@ impl CodegenAlternative {
let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option {
let mut chars_read_so_far = chars_read_so_far.lock();
match tool_use.name.as_ref() {
- "rewrite_section" => {
+ REWRITE_SECTION_TOOL_NAME => {
let Ok(input) =
serde_json::from_value::(tool_use.input)
else {
@@ -1180,7 +1183,7 @@ impl CodegenAlternative {
description: None,
})
}
- "failure_message" => {
+ FAILURE_MESSAGE_TOOL_NAME => {
let Ok(mut input) =
serde_json::from_value::(tool_use.input)
else {
@@ -1493,7 +1496,10 @@ mod tests {
use indoc::indoc;
use language::{Buffer, Point};
use language_model::fake_provider::FakeLanguageModel;
- use language_model::{LanguageModelRegistry, TokenUsage};
+ use language_model::{
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
+ LanguageModelToolUse, StopReason, TokenUsage,
+ };
use languages::rust_lang;
use rand::prelude::*;
use settings::SettingsStore;
@@ -1805,6 +1811,51 @@ mod tests {
);
}
+ // When not streaming tool calls, we strip backticks as part of parsing the model's
+ // plain text response. This is a regression test for a bug where we stripped
+ // backticks incorrectly.
+ #[gpui::test]
+ async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
+ init_test(cx);
+ let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
+ let buffer = cx.new(|cx| Buffer::local("", cx));
+ let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ true,
+ prompt_builder,
+ Uuid::new_v4(),
+ cx,
+ )
+ });
+
+ let events_tx = simulate_tool_based_completion(&codegen, cx);
+ let chunk_len = text.find('`').unwrap();
+ events_tx
+ .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
+ .unwrap();
+ events_tx
+ .unbounded_send(rewrite_tool_use("tool_2", &text, true))
+ .unwrap();
+ events_tx
+ .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
+ .unwrap();
+ drop(events_tx);
+ cx.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ text
+ );
+ }
+
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@@ -1870,4 +1921,39 @@ mod tests {
});
chunks_tx
}
+
+ fn simulate_tool_based_completion(
+ codegen: &Entity,
+ cx: &mut TestAppContext,
+ ) -> mpsc::UnboundedSender {
+ let (events_tx, events_rx) = mpsc::unbounded();
+ let model = Arc::new(FakeLanguageModel::default());
+ codegen.update(cx, |codegen, cx| {
+ let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
+ as BoxStream<
+ 'static,
+ Result,
+ >));
+ codegen.generation = codegen.handle_completion(model, completion_stream, cx);
+ });
+ events_tx
+ }
+
+ fn rewrite_tool_use(
+ id: &str,
+ replacement_text: &str,
+ is_complete: bool,
+ ) -> LanguageModelCompletionEvent {
+ let input = RewriteSectionInput {
+ replacement_text: replacement_text.into(),
+ };
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ id: id.into(),
+ name: REWRITE_SECTION_TOOL_NAME.into(),
+ raw_input: serde_json::to_string(&input).unwrap(),
+ input: serde_json::to_value(&input).unwrap(),
+ is_input_complete: is_complete,
+ thought_signature: None,
+ })
+ }
}
diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs
index 052d8598a76d1044c6d97b5378041b5cd12e23b3..671579f9ef018b495b7993279a852595c78d3e02 100644
--- a/crates/agent_ui/src/inline_assistant.rs
+++ b/crates/agent_ui/src/inline_assistant.rs
@@ -2271,6 +2271,36 @@ pub mod evals {
);
}
+ #[test]
+ #[cfg_attr(not(feature = "unit-eval"), ignore)]
+ fn eval_empty_buffer() {
+ run_eval(
+ 20,
+ 1.0,
+ "Write a Python hello, world program".to_string(),
+ "ˇ".to_string(),
+ |output| match output {
+ InlineAssistantOutput::Success {
+ full_buffer_text, ..
+ } => {
+ if full_buffer_text.is_empty() {
+ EvalOutput::failed("expected some output".to_string())
+ } else {
+ EvalOutput::passed(format!("Produced {full_buffer_text}"))
+ }
+ }
+ o @ InlineAssistantOutput::Failure { .. } => EvalOutput::failed(format!(
+ "Assistant output does not match expected output: {:?}",
+ o
+ )),
+ o @ InlineAssistantOutput::Malformed { .. } => EvalOutput::failed(format!(
+ "Assistant output does not match expected output: {:?}",
+ o
+ )),
+ },
+ );
+ }
+
fn run_eval(
iterations: usize,
expected_pass_ratio: f32,
diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs
index 674d4869e9825fd700dde3db510fbf68c6b4d5cc..6a845bb8dd394f8a1ff26a8a0e130156a2a158bd 100644
--- a/crates/prompt_store/src/prompts.rs
+++ b/crates/prompt_store/src/prompts.rs
@@ -112,7 +112,7 @@ pub struct ContentPromptContextV2 {
pub language_name: Option,
pub is_truncated: bool,
pub document_content: String,
- pub rewrite_section: Option,
+ pub rewrite_section: String,
pub diagnostic_errors: Vec,
}
@@ -310,7 +310,6 @@ impl PromptBuilder {
};
const MAX_CTX: usize = 50000;
- let is_insert = range.is_empty();
let mut is_truncated = false;
let before_range = 0..range.start;
@@ -335,28 +334,19 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
- if is_insert {
- document_content.push_str("");
- } else {
- document_content.push_str("\n");
- for chunk in buffer.text_for_range(range.clone()) {
- document_content.push_str(chunk);
- }
- document_content.push_str("\n");
+
+ document_content.push_str("\n");
+ for chunk in buffer.text_for_range(range.clone()) {
+ document_content.push_str(chunk);
}
+ document_content.push_str("\n");
+
for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk);
}
- let rewrite_section = if !is_insert {
- let mut section = String::new();
- for chunk in buffer.text_for_range(range.clone()) {
- section.push_str(chunk);
- }
- Some(section)
- } else {
- None
- };
+ let rewrite_section: String = buffer.text_for_range(range.clone()).collect();
+
let diagnostics = buffer.diagnostics_in_range::<_, Point>(range, false);
let diagnostic_errors: Vec = diagnostics
.map(|entry| {