A couple new inline assistant tests (#45049)

Created by Michael Benfield and Max Brunsfeld.

Also adjust the code for streaming tool use so that the prompt always contains a
`<rewrite_this>` section (`rewrite_section` is no longer optional); remove the
`<insert_here>` marker entirely.

Release Notes:

- N/A

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

assets/prompts/content_prompt_v2.hbs    |  3 
crates/agent_ui/src/buffer_codegen.rs   | 96 +++++++++++++++++++++++++-
crates/agent_ui/src/inline_assistant.rs | 30 ++++++++
crates/prompt_store/src/prompts.rs      | 28 ++-----
4 files changed, 130 insertions(+), 27 deletions(-)

Detailed changes

assets/prompts/content_prompt_v2.hbs 🔗

@@ -14,7 +14,6 @@ The section you'll need to rewrite is marked with <rewrite_this></rewrite_this>
 The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
 {{/if}}
 
-{{#if rewrite_section}}
 And here's the section to rewrite based on that prompt again for reference:
 
 <rewrite_this>
@@ -33,8 +32,6 @@ Below are the diagnostic errors visible to the user.  If the user requests probl
 {{/each}}
 {{/if}}
 
-{{/if}}
-
 Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
 
 Start at the indentation level in the original file in the rewritten {{content_type}}.

crates/agent_ui/src/buffer_codegen.rs 🔗

@@ -75,6 +75,9 @@ pub struct BufferCodegen {
     session_id: Uuid,
 }
 
+pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
+pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
+
 impl BufferCodegen {
     pub fn new(
         buffer: Entity<MultiBuffer>,
@@ -522,12 +525,12 @@ impl CodegenAlternative {
 
             let tools = vec![
                 LanguageModelRequestTool {
-                    name: "rewrite_section".to_string(),
+                    name: REWRITE_SECTION_TOOL_NAME.to_string(),
                     description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
                     input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
                 },
                 LanguageModelRequestTool {
-                    name: "failure_message".to_string(),
+                    name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
                     description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
                     input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
                 },
@@ -1167,7 +1170,7 @@ impl CodegenAlternative {
             let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
                 let mut chars_read_so_far = chars_read_so_far.lock();
                 match tool_use.name.as_ref() {
-                    "rewrite_section" => {
+                    REWRITE_SECTION_TOOL_NAME => {
                         let Ok(input) =
                             serde_json::from_value::<RewriteSectionInput>(tool_use.input)
                         else {
@@ -1180,7 +1183,7 @@ impl CodegenAlternative {
                             description: None,
                         })
                     }
-                    "failure_message" => {
+                    FAILURE_MESSAGE_TOOL_NAME => {
                         let Ok(mut input) =
                             serde_json::from_value::<FailureMessageInput>(tool_use.input)
                         else {
@@ -1493,7 +1496,10 @@ mod tests {
     use indoc::indoc;
     use language::{Buffer, Point};
     use language_model::fake_provider::FakeLanguageModel;
-    use language_model::{LanguageModelRegistry, TokenUsage};
+    use language_model::{
+        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
+        LanguageModelToolUse, StopReason, TokenUsage,
+    };
     use languages::rust_lang;
     use rand::prelude::*;
     use settings::SettingsStore;
@@ -1805,6 +1811,51 @@ mod tests {
         );
     }
 
+    // When not streaming tool calls, we strip backticks as part of parsing the model's
+    // plain text response. This is a regression test for a bug where we stripped
+    // backticks incorrectly.
+    #[gpui::test]
+    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
+        init_test(cx);
+        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
+        let buffer = cx.new(|cx| Buffer::local("", cx));
+        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+        let range = buffer.read_with(cx, |buffer, cx| {
+            let snapshot = buffer.snapshot(cx);
+            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
+        });
+        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+        let codegen = cx.new(|cx| {
+            CodegenAlternative::new(
+                buffer.clone(),
+                range.clone(),
+                true,
+                prompt_builder,
+                Uuid::new_v4(),
+                cx,
+            )
+        });
+
+        let events_tx = simulate_tool_based_completion(&codegen, cx);
+        let chunk_len = text.find('`').unwrap();
+        events_tx
+            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
+            .unwrap();
+        events_tx
+            .unbounded_send(rewrite_tool_use("tool_2", &text, true))
+            .unwrap();
+        events_tx
+            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
+            .unwrap();
+        drop(events_tx);
+        cx.run_until_parked();
+
+        assert_eq!(
+            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+            text
+        );
+    }
+
     #[gpui::test]
     async fn test_strip_invalid_spans_from_codeblock() {
         assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@@ -1870,4 +1921,39 @@ mod tests {
         });
         chunks_tx
     }
+
+    fn simulate_tool_based_completion(
+        codegen: &Entity<CodegenAlternative>,
+        cx: &mut TestAppContext,
+    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
+        let (events_tx, events_rx) = mpsc::unbounded();
+        let model = Arc::new(FakeLanguageModel::default());
+        codegen.update(cx, |codegen, cx| {
+            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
+                as BoxStream<
+                    'static,
+                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+                >));
+            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
+        });
+        events_tx
+    }
+
+    fn rewrite_tool_use(
+        id: &str,
+        replacement_text: &str,
+        is_complete: bool,
+    ) -> LanguageModelCompletionEvent {
+        let input = RewriteSectionInput {
+            replacement_text: replacement_text.into(),
+        };
+        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+            id: id.into(),
+            name: REWRITE_SECTION_TOOL_NAME.into(),
+            raw_input: serde_json::to_string(&input).unwrap(),
+            input: serde_json::to_value(&input).unwrap(),
+            is_input_complete: is_complete,
+            thought_signature: None,
+        })
+    }
 }

crates/agent_ui/src/inline_assistant.rs 🔗

@@ -2271,6 +2271,36 @@ pub mod evals {
         );
     }
 
+    #[test]
+    #[cfg_attr(not(feature = "unit-eval"), ignore)]
+    fn eval_empty_buffer() {
+        run_eval(
+            20,
+            1.0,
+            "Write a Python hello, world program".to_string(),
+            "ˇ".to_string(),
+            |output| match output {
+                InlineAssistantOutput::Success {
+                    full_buffer_text, ..
+                } => {
+                    if full_buffer_text.is_empty() {
+                        EvalOutput::failed("expected some output".to_string())
+                    } else {
+                        EvalOutput::passed(format!("Produced {full_buffer_text}"))
+                    }
+                }
+                o @ InlineAssistantOutput::Failure { .. } => EvalOutput::failed(format!(
+                    "Assistant output does not match expected output: {:?}",
+                    o
+                )),
+                o @ InlineAssistantOutput::Malformed { .. } => EvalOutput::failed(format!(
+                    "Assistant output does not match expected output: {:?}",
+                    o
+                )),
+            },
+        );
+    }
+
     fn run_eval(
         iterations: usize,
         expected_pass_ratio: f32,

crates/prompt_store/src/prompts.rs 🔗

@@ -112,7 +112,7 @@ pub struct ContentPromptContextV2 {
     pub language_name: Option<String>,
     pub is_truncated: bool,
     pub document_content: String,
-    pub rewrite_section: Option<String>,
+    pub rewrite_section: String,
     pub diagnostic_errors: Vec<ContentPromptDiagnosticContext>,
 }
 
@@ -310,7 +310,6 @@ impl PromptBuilder {
         };
 
         const MAX_CTX: usize = 50000;
-        let is_insert = range.is_empty();
         let mut is_truncated = false;
 
         let before_range = 0..range.start;
@@ -335,28 +334,19 @@ impl PromptBuilder {
         for chunk in buffer.text_for_range(truncated_before) {
             document_content.push_str(chunk);
         }
-        if is_insert {
-            document_content.push_str("<insert_here></insert_here>");
-        } else {
-            document_content.push_str("<rewrite_this>\n");
-            for chunk in buffer.text_for_range(range.clone()) {
-                document_content.push_str(chunk);
-            }
-            document_content.push_str("\n</rewrite_this>");
+
+        document_content.push_str("<rewrite_this>\n");
+        for chunk in buffer.text_for_range(range.clone()) {
+            document_content.push_str(chunk);
         }
+        document_content.push_str("\n</rewrite_this>");
+
         for chunk in buffer.text_for_range(truncated_after) {
             document_content.push_str(chunk);
         }
 
-        let rewrite_section = if !is_insert {
-            let mut section = String::new();
-            for chunk in buffer.text_for_range(range.clone()) {
-                section.push_str(chunk);
-            }
-            Some(section)
-        } else {
-            None
-        };
+        let rewrite_section: String = buffer.text_for_range(range.clone()).collect();
+
         let diagnostics = buffer.diagnostics_in_range::<_, Point>(range, false);
         let diagnostic_errors: Vec<ContentPromptDiagnosticContext> = diagnostics
             .map(|entry| {