bedrock: Fix bedrock not streaming (#28281)

Created by Shardul Vaidya and Peter Tripp

Closes #26030 

Release Notes:

- Fixed a Bedrock bug that caused streaming responses to arrive as one big
chunk instead of incrementally

---------

Co-authored-by: Peter Tripp <peter@zed.dev>

Change summary

Cargo.lock                                     |   1 
crates/bedrock/Cargo.toml                      |   1 
crates/bedrock/src/bedrock.rs                  | 111 ++---
crates/language_models/src/provider/bedrock.rs | 335 +++++++------------
4 files changed, 177 insertions(+), 271 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1911,7 +1911,6 @@ dependencies = [
  "serde_json",
  "strum 0.27.1",
  "thiserror 2.0.12",
- "tokio",
  "workspace-hack",
 ]
 

crates/bedrock/Cargo.toml 🔗

@@ -25,5 +25,4 @@ serde.workspace = true
 serde_json.workspace = true
 strum.workspace = true
 thiserror.workspace = true
-tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
 workspace-hack.workspace = true

crates/bedrock/src/bedrock.rs 🔗

@@ -1,9 +1,6 @@
 mod models;
 
-use std::collections::HashMap;
-use std::pin::Pin;
-
-use anyhow::{Context as _, Error, Result, anyhow};
+use anyhow::{Context, Error, Result, anyhow};
 use aws_sdk_bedrockruntime as bedrock;
 pub use aws_sdk_bedrockruntime as bedrock_client;
 pub use aws_sdk_bedrockruntime::types::{
@@ -24,9 +21,10 @@ pub use bedrock::types::{
     ToolResultContentBlock as BedrockToolResultContentBlock,
     ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 };
-use futures::stream::{self, BoxStream, Stream};
+use futures::stream::{self, BoxStream};
 use serde::{Deserialize, Serialize};
 use serde_json::{Number, Value};
+use std::collections::HashMap;
 use thiserror::Error;
 
 pub use crate::models::*;
@@ -34,70 +32,59 @@ pub use crate::models::*;
 pub async fn stream_completion(
     client: bedrock::Client,
     request: Request,
-    handle: tokio::runtime::Handle,
 ) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
-    handle
-        .spawn(async move {
-            let mut response = bedrock::Client::converse_stream(&client)
-                .model_id(request.model.clone())
-                .set_messages(request.messages.into());
+    let mut response = bedrock::Client::converse_stream(&client)
+        .model_id(request.model.clone())
+        .set_messages(request.messages.into());
 
-            if let Some(Thinking::Enabled {
-                budget_tokens: Some(budget_tokens),
-            }) = request.thinking
-            {
-                response =
-                    response.additional_model_request_fields(Document::Object(HashMap::from([(
-                        "thinking".to_string(),
-                        Document::from(HashMap::from([
-                            ("type".to_string(), Document::String("enabled".to_string())),
-                            (
-                                "budget_tokens".to_string(),
-                                Document::Number(AwsNumber::PosInt(budget_tokens)),
-                            ),
-                        ])),
-                    )])));
-            }
+    if let Some(Thinking::Enabled {
+        budget_tokens: Some(budget_tokens),
+    }) = request.thinking
+    {
+        let thinking_config = HashMap::from([
+            ("type".to_string(), Document::String("enabled".to_string())),
+            (
+                "budget_tokens".to_string(),
+                Document::Number(AwsNumber::PosInt(budget_tokens)),
+            ),
+        ]);
+        response = response.additional_model_request_fields(Document::Object(HashMap::from([(
+            "thinking".to_string(),
+            Document::from(thinking_config),
+        )])));
+    }
 
-            if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
-                response = response.set_tool_config(request.tools);
-            }
+    if request
+        .tools
+        .as_ref()
+        .map_or(false, |t| !t.tools.is_empty())
+    {
+        response = response.set_tool_config(request.tools);
+    }
 
-            let response = response.send().await;
+    let output = response
+        .send()
+        .await
+        .context("Failed to send API request to Bedrock");
 
-            match response {
-                Ok(output) => {
-                    let stream: Pin<
-                        Box<
-                            dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
-                                + Send,
-                        >,
-                    > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
-                        match stream.recv().await {
-                            Ok(Some(output)) => Some(({ Ok(output) }, stream)),
-                            Ok(None) => None,
-                            Err(err) => {
-                                Some((
-                                    // TODO: Figure out how we can capture Throttling Exceptions
-                                    Err(BedrockError::ClientError(anyhow!(
-                                        "{:?}",
-                                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
-                                    ))),
-                                    stream,
-                                ))
-                            }
-                        }
-                    }));
-                    Ok(stream)
-                }
-                Err(err) => Err(anyhow!(
-                    "{:?}",
-                    aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+    let stream = Box::pin(stream::unfold(
+        output?.stream,
+        move |mut stream| async move {
+            match stream.recv().await {
+                Ok(Some(output)) => Some((Ok(output), stream)),
+                Ok(None) => None,
+                Err(err) => Some((
+                    Err(BedrockError::ClientError(anyhow!(
+                        "{:?}",
+                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+                    ))),
+                    stream,
                 )),
             }
-        })
-        .await
-        .context("spawning a task")?
+        },
+    ));
+
+    Ok(stream)
 }
 
 pub fn aws_document_to_value(document: &Document) -> Value {

crates/language_models/src/provider/bedrock.rs 🔗

@@ -46,7 +46,6 @@ use settings::{Settings, SettingsStore};
 use smol::lock::OnceCell;
 use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
 use theme::ThemeSettings;
-use tokio::runtime::Handle;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
 use util::ResultExt;
 
@@ -460,22 +459,22 @@ impl BedrockModel {
         &self,
         request: bedrock::Request,
         cx: &AsyncApp,
-    ) -> Result<
-        BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
+    ) -> BoxFuture<
+        'static,
+        Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
     > {
-        let runtime_client = self
-            .get_or_init_client(cx)
+        let Ok(runtime_client) = self
+            .get_or_init_client(&cx)
             .cloned()
-            .context("Bedrock client not initialized")?;
-        let owned_handle = self.handler.clone();
+            .context("Bedrock client not initialized")
+        else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
 
-        Ok(async move {
-            let request = bedrock::stream_completion(runtime_client, request, owned_handle);
-            request.await.unwrap_or_else(|e| {
-                futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
-            })
+        match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) {
+            Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(),
+            Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(),
         }
-        .boxed())
     }
 }
 
@@ -570,12 +569,10 @@ impl LanguageModel for BedrockModel {
             Err(err) => return futures::future::ready(Err(err.into())).boxed(),
         };
 
-        let owned_handle = self.handler.clone();
-
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
-            let response = request.map_err(|err| anyhow!(err))?.await;
-            let events = map_to_language_model_completion_events(response, owned_handle);
+            let response = request.await.map_err(|err| anyhow!(err))?;
+            let events = map_to_language_model_completion_events(response);
 
             if deny_tool_calls {
                 Ok(deny_tool_use_events(events).boxed())
@@ -879,7 +876,6 @@ pub fn get_bedrock_tokens(
 
 pub fn map_to_language_model_completion_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
-    handle: Handle,
 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
     struct RawToolUse {
         id: String,
@@ -892,198 +888,123 @@ pub fn map_to_language_model_completion_events(
         tool_uses_by_index: HashMap<i32, RawToolUse>,
     }
 
-    futures::stream::unfold(
-        State {
-            events,
-            tool_uses_by_index: HashMap::default(),
-        },
-        move |mut state: State| {
-            let inner_handle = handle.clone();
-            async move {
-                inner_handle
-                    .spawn(async {
-                        while let Some(event) = state.events.next().await {
-                            match event {
-                                Ok(event) => match event {
-                                    ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
-                                        match cb_delta.delta {
-                                            Some(ContentBlockDelta::Text(text_out)) => {
-                                                let completion_event =
-                                                    LanguageModelCompletionEvent::Text(text_out);
-                                                return Some((Some(Ok(completion_event)), state));
-                                            }
-
-                                            Some(ContentBlockDelta::ToolUse(text_out)) => {
-                                                if let Some(tool_use) = state
-                                                    .tool_uses_by_index
-                                                    .get_mut(&cb_delta.content_block_index)
-                                                {
-                                                    tool_use.input_json.push_str(text_out.input());
-                                                }
-                                            }
-
-                                            Some(ContentBlockDelta::ReasoningContent(thinking)) => {
-                                                match thinking {
-                                                    ReasoningContentBlockDelta::RedactedContent(
-                                                        redacted,
-                                                    ) => {
-                                                        let thinking_event =
-                                                            LanguageModelCompletionEvent::Thinking {
-                                                                text: String::from_utf8(
-                                                                    redacted.into_inner(),
-                                                                )
-                                                                .unwrap_or("REDACTED".to_string()),
-                                                                signature: None,
-                                                            };
-
-                                                        return Some((
-                                                            Some(Ok(thinking_event)),
-                                                            state,
-                                                        ));
-                                                    }
-                                                    ReasoningContentBlockDelta::Signature(
-                                                        signature,
-                                                    ) => {
-                                                        return Some((
-                                                            Some(Ok(LanguageModelCompletionEvent::Thinking {
-                                                                text: "".to_string(),
-                                                                signature: Some(signature)
-                                                            })),
-                                                            state,
-                                                        ));
-                                                    }
-                                                    ReasoningContentBlockDelta::Text(thoughts) => {
-                                                        let thinking_event =
-                                                            LanguageModelCompletionEvent::Thinking {
-                                                                text: thoughts.to_string(),
-                                                                signature: None
-                                                            };
-
-                                                        return Some((
-                                                            Some(Ok(thinking_event)),
-                                                            state,
-                                                        ));
-                                                    }
-                                                    _ => {}
-                                                }
-                                            }
-                                            _ => {}
-                                        }
-                                    }
-                                    ConverseStreamOutput::ContentBlockStart(cb_start) => {
-                                        if let Some(ContentBlockStart::ToolUse(text_out)) =
-                                            cb_start.start
-                                        {
-                                            let tool_use = RawToolUse {
-                                                id: text_out.tool_use_id,
-                                                name: text_out.name,
-                                                input_json: String::new(),
-                                            };
-
-                                            state
-                                                .tool_uses_by_index
-                                                .insert(cb_start.content_block_index, tool_use);
-                                        }
-                                    }
-                                    ConverseStreamOutput::ContentBlockStop(cb_stop) => {
-                                        if let Some(tool_use) = state
-                                            .tool_uses_by_index
-                                            .remove(&cb_stop.content_block_index)
-                                        {
-                                            let tool_use_event = LanguageModelToolUse {
-                                                id: tool_use.id.into(),
-                                                name: tool_use.name.into(),
-                                                is_input_complete: true,
-                                                raw_input: tool_use.input_json.clone(),
-                                                input: if tool_use.input_json.is_empty() {
-                                                    Value::Null
-                                                } else {
-                                                    serde_json::Value::from_str(
-                                                        &tool_use.input_json,
-                                                    )
-                                                    .map_err(|err| anyhow!(err))
-                                                    .unwrap()
-                                                },
-                                            };
-
-                                            return Some((
-                                                Some(Ok(LanguageModelCompletionEvent::ToolUse(
-                                                    tool_use_event,
-                                                ))),
-                                                state,
-                                            ));
-                                        }
-                                    }
-
-                                    ConverseStreamOutput::Metadata(cb_meta) => {
-                                        if let Some(metadata) = cb_meta.usage {
-                                            let completion_event =
-                                                LanguageModelCompletionEvent::UsageUpdate(
-                                                    TokenUsage {
-                                                        input_tokens: metadata.input_tokens as u64,
-                                                        output_tokens: metadata.output_tokens as u64,
-                                                        cache_creation_input_tokens:
-                                                            metadata.cache_write_input_tokens.unwrap_or_default() as u64,
-                                                        cache_read_input_tokens:
-                                                            metadata.cache_read_input_tokens.unwrap_or_default() as u64,
-                                                    },
-                                                );
-                                            return Some((Some(Ok(completion_event)), state));
-                                        }
-                                    }
-                                    ConverseStreamOutput::MessageStop(message_stop) => {
-                                        let reason = match message_stop.stop_reason {
-                                            StopReason::ContentFiltered => {
-                                                LanguageModelCompletionEvent::Stop(
-                                                    language_model::StopReason::EndTurn,
-                                                )
-                                            }
-                                            StopReason::EndTurn => {
-                                                LanguageModelCompletionEvent::Stop(
-                                                    language_model::StopReason::EndTurn,
-                                                )
-                                            }
-                                            StopReason::GuardrailIntervened => {
-                                                LanguageModelCompletionEvent::Stop(
-                                                    language_model::StopReason::EndTurn,
-                                                )
-                                            }
-                                            StopReason::MaxTokens => {
-                                                LanguageModelCompletionEvent::Stop(
-                                                    language_model::StopReason::EndTurn,
-                                                )
-                                            }
-                                            StopReason::StopSequence => {
-                                                LanguageModelCompletionEvent::Stop(
-                                                    language_model::StopReason::EndTurn,
-                                                )
-                                            }
-                                            StopReason::ToolUse => {
-                                                LanguageModelCompletionEvent::Stop(
-                                                    language_model::StopReason::ToolUse,
-                                                )
-                                            }
-                                            _ => LanguageModelCompletionEvent::Stop(
-                                                language_model::StopReason::EndTurn,
-                                            ),
-                                        };
-                                        return Some((Some(Ok(reason)), state));
-                                    }
-                                    _ => {}
-                                },
+    let initial_state = State {
+        events,
+        tool_uses_by_index: HashMap::default(),
+    };
 
-                                Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
+    futures::stream::unfold(initial_state, |mut state| async move {
+        match state.events.next().await {
+            Some(event_result) => match event_result {
+                Ok(event) => {
+                    let result = match event {
+                        ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
+                            Some(ContentBlockDelta::Text(text)) => {
+                                Some(Ok(LanguageModelCompletionEvent::Text(text)))
+                            }
+                            Some(ContentBlockDelta::ToolUse(tool_output)) => {
+                                if let Some(tool_use) = state
+                                    .tool_uses_by_index
+                                    .get_mut(&cb_delta.content_block_index)
+                                {
+                                    tool_use.input_json.push_str(tool_output.input());
+                                }
+                                None
                             }
+                            Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
+                                ReasoningContentBlockDelta::Text(thoughts) => {
+                                    Some(Ok(LanguageModelCompletionEvent::Thinking {
+                                        text: thoughts.clone(),
+                                        signature: None,
+                                    }))
+                                }
+                                ReasoningContentBlockDelta::Signature(sig) => {
+                                    Some(Ok(LanguageModelCompletionEvent::Thinking {
+                                        text: "".into(),
+                                        signature: Some(sig),
+                                    }))
+                                }
+                                ReasoningContentBlockDelta::RedactedContent(redacted) => {
+                                    let content = String::from_utf8(redacted.into_inner())
+                                        .unwrap_or("REDACTED".to_string());
+                                    Some(Ok(LanguageModelCompletionEvent::Thinking {
+                                        text: content,
+                                        signature: None,
+                                    }))
+                                }
+                                _ => None,
+                            },
+                            _ => None,
+                        },
+                        ConverseStreamOutput::ContentBlockStart(cb_start) => {
+                            if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
+                                state.tool_uses_by_index.insert(
+                                    cb_start.content_block_index,
+                                    RawToolUse {
+                                        id: tool_start.tool_use_id,
+                                        name: tool_start.name,
+                                        input_json: String::new(),
+                                    },
+                                );
+                            }
+                            None
                         }
-                        None
-                    })
-                    .await
-                    .log_err()
-                    .flatten()
-            }
-        },
-    )
-    .filter_map(|event| async move { event })
+                        ConverseStreamOutput::ContentBlockStop(cb_stop) => state
+                            .tool_uses_by_index
+                            .remove(&cb_stop.content_block_index)
+                            .map(|tool_use| {
+                                let input = if tool_use.input_json.is_empty() {
+                                    Value::Null
+                                } else {
+                                    serde_json::Value::from_str(&tool_use.input_json)
+                                        .unwrap_or(Value::Null)
+                                };
+
+                                Ok(LanguageModelCompletionEvent::ToolUse(
+                                    LanguageModelToolUse {
+                                        id: tool_use.id.into(),
+                                        name: tool_use.name.into(),
+                                        is_input_complete: true,
+                                        raw_input: tool_use.input_json.clone(),
+                                        input,
+                                    },
+                                ))
+                            }),
+                        ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
+                            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                                input_tokens: metadata.input_tokens as u64,
+                                output_tokens: metadata.output_tokens as u64,
+                                cache_creation_input_tokens: metadata
+                                    .cache_write_input_tokens
+                                    .unwrap_or_default()
+                                    as u64,
+                                cache_read_input_tokens: metadata
+                                    .cache_read_input_tokens
+                                    .unwrap_or_default()
+                                    as u64,
+                            }))
+                        }),
+                        ConverseStreamOutput::MessageStop(message_stop) => {
+                            let stop_reason = match message_stop.stop_reason {
+                                StopReason::ToolUse => language_model::StopReason::ToolUse,
+                                _ => language_model::StopReason::EndTurn,
+                            };
+                            Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
+                        }
+                        _ => None,
+                    };
+
+                    Some((result, state))
+                }
+                Err(err) => Some((
+                    Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
+                    state,
+                )),
+            },
+            None => None,
+        }
+    })
+    .filter_map(|result| async move { result })
 }
 
 struct ConfigurationView {