@@ -441,7 +441,8 @@ impl CodegenAlternative {
})
.boxed_local()
};
- self.generation = self.handle_stream(model, stream, cx);
+ self.generation =
+ self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
}
Ok(())
@@ -629,6 +630,7 @@ impl CodegenAlternative {
pub fn handle_stream(
&mut self,
model: Arc<dyn LanguageModel>,
+ strip_invalid_spans: bool,
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
cx: &mut Context<Self>,
) -> Task<()> {
@@ -713,10 +715,16 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
- let chunks = StripInvalidSpans::new(
- stream?.stream.map_err(|error| error.into()),
- );
- futures::pin_mut!(chunks);
+ let raw_stream = stream?.stream.map_err(|error| error.into());
+
+ let stripped;
+ let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
+ if strip_invalid_spans {
+ stripped = StripInvalidSpans::new(raw_stream);
+ Box::pin(stripped)
+ } else {
+ Box::pin(raw_stream)
+ };
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -1307,7 +1315,12 @@ impl CodegenAlternative {
let Some(task) = codegen
.update(cx, move |codegen, cx| {
- codegen.handle_stream(model, async { Ok(language_model_text_stream) }, cx)
+ codegen.handle_stream(
+ model,
+ /* strip_invalid_spans: */ false,
+ async { Ok(language_model_text_stream) },
+ cx,
+ )
})
.ok()
else {
@@ -1846,6 +1859,7 @@ mod tests {
codegen.update(cx, |codegen, cx| {
codegen.generation = codegen.handle_stream(
model,
+ /* strip_invalid_spans: */ false,
future::ready(Ok(LanguageModelTextStream {
message_id: None,
stream: chunks_rx.map(Ok).boxed(),