Don't apply StripInvalidSpans for tool using inline assistant (#45040)

Michael Benfield created

It can occasionally mutilate the text when used with the tool format.

Release Notes:

- N/A

Change summary

crates/agent_ui/src/buffer_codegen.rs | 26 ++++++++++++++++++++------
1 file changed, 20 insertions(+), 6 deletions(-)

Detailed changes

crates/agent_ui/src/buffer_codegen.rs 🔗

@@ -441,7 +441,8 @@ impl CodegenAlternative {
                     })
                     .boxed_local()
                 };
-            self.generation = self.handle_stream(model, stream, cx);
+            self.generation =
+                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
         }
 
         Ok(())
@@ -629,6 +630,7 @@ impl CodegenAlternative {
     pub fn handle_stream(
         &mut self,
         model: Arc<dyn LanguageModel>,
+        strip_invalid_spans: bool,
         stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
         cx: &mut Context<Self>,
     ) -> Task<()> {
@@ -713,10 +715,16 @@ impl CodegenAlternative {
                         let mut response_latency = None;
                         let request_start = Instant::now();
                         let diff = async {
-                            let chunks = StripInvalidSpans::new(
-                                stream?.stream.map_err(|error| error.into()),
-                            );
-                            futures::pin_mut!(chunks);
+                            let raw_stream = stream?.stream.map_err(|error| error.into());
+
+                            let stripped;
+                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
+                                if strip_invalid_spans {
+                                    stripped = StripInvalidSpans::new(raw_stream);
+                                    Box::pin(stripped)
+                                } else {
+                                    Box::pin(raw_stream)
+                                };
 
                             let mut diff = StreamingDiff::new(selected_text.to_string());
                             let mut line_diff = LineDiff::default();
@@ -1307,7 +1315,12 @@ impl CodegenAlternative {
 
             let Some(task) = codegen
                 .update(cx, move |codegen, cx| {
-                    codegen.handle_stream(model, async { Ok(language_model_text_stream) }, cx)
+                    codegen.handle_stream(
+                        model,
+                        /* strip_invalid_spans: */ false,
+                        async { Ok(language_model_text_stream) },
+                        cx,
+                    )
                 })
                 .ok()
             else {
@@ -1846,6 +1859,7 @@ mod tests {
         codegen.update(cx, |codegen, cx| {
             codegen.generation = codegen.handle_stream(
                 model,
+                /* strip_invalid_spans: */ false,
                 future::ready(Ok(LanguageModelTextStream {
                     message_id: None,
                     stream: chunks_rx.map(Ok).boxed(),