Created by Shardul Vaidya and Peter Tripp
Closes #26030
Release Notes:
- Fixed a Bedrock bug causing streaming responses to be returned as one big chunk
---------
Co-authored-by: Peter Tripp <peter@zed.dev>
Cargo.lock                                     |   1 -
crates/bedrock/Cargo.toml                      |   1 -
crates/bedrock/src/bedrock.rs                  | 111 ++---
crates/language_models/src/provider/bedrock.rs | 335 +++++++------------
4 files changed, 177 insertions(+), 271 deletions(-)
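
The core of the change is in `stream_completion` (crates/bedrock/src/bedrock.rs): instead of spawning the request onto a dedicated tokio runtime handle and awaiting the join handle, the function now builds the response stream directly with `futures::stream::unfold`, so each `ConverseStreamOutput` is handed to the caller as soon as `stream.recv()` yields it. Below is a minimal, self-contained sketch of that unfold pattern; the `mpsc` channel is only a stand-in for the AWS SDK's event receiver, and none of the names are taken from the PR.

```rust
use futures::channel::mpsc;
use futures::stream::{self, StreamExt};

// Stand-in for the AWS SDK event receiver: a channel of already-received chunks.
// In the real code the state is `output.stream` and the await is `stream.recv()`.
async fn demo() {
    let (tx, rx) = mpsc::unbounded::<Result<String, String>>();
    tx.unbounded_send(Ok("chunk 1".into())).unwrap();
    tx.unbounded_send(Ok("chunk 2".into())).unwrap();
    drop(tx); // closing the sender ends the stream

    // unfold: pull one item per step and hand it to the consumer immediately,
    // threading the receiver back through as the state for the next step.
    let events = stream::unfold(rx, |mut rx| async move {
        rx.next().await.map(|item| (item, rx))
    });

    futures::pin_mut!(events);
    while let Some(event) = events.next().await {
        println!("{event:?}"); // each chunk arrives as its own stream item
    }
}

fn main() {
    futures::executor::block_on(demo());
}
```
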
Cargo.lock
@@ -1911,7 +1911,6 @@ dependencies = [
"serde_json",
"strum 0.27.1",
"thiserror 2.0.12",
- "tokio",
"workspace-hack",
]
crates/bedrock/Cargo.toml
@@ -25,5 +25,4 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
-tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
workspace-hack.workspace = true
crates/bedrock/src/bedrock.rs
@@ -1,9 +1,6 @@
mod models;
-use std::collections::HashMap;
-use std::pin::Pin;
-
-use anyhow::{Context as _, Error, Result, anyhow};
+use anyhow::{Context, Error, Result, anyhow};
use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{
@@ -24,9 +21,10 @@ pub use bedrock::types::{
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
};
-use futures::stream::{self, BoxStream, Stream};
+use futures::stream::{self, BoxStream};
use serde::{Deserialize, Serialize};
use serde_json::{Number, Value};
+use std::collections::HashMap;
use thiserror::Error;
pub use crate::models::*;
@@ -34,70 +32,59 @@ pub use crate::models::*;
pub async fn stream_completion(
client: bedrock::Client,
request: Request,
- handle: tokio::runtime::Handle,
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
- handle
- .spawn(async move {
- let mut response = bedrock::Client::converse_stream(&client)
- .model_id(request.model.clone())
- .set_messages(request.messages.into());
+ let mut response = bedrock::Client::converse_stream(&client)
+ .model_id(request.model.clone())
+ .set_messages(request.messages.into());
- if let Some(Thinking::Enabled {
- budget_tokens: Some(budget_tokens),
- }) = request.thinking
- {
- response =
- response.additional_model_request_fields(Document::Object(HashMap::from([(
- "thinking".to_string(),
- Document::from(HashMap::from([
- ("type".to_string(), Document::String("enabled".to_string())),
- (
- "budget_tokens".to_string(),
- Document::Number(AwsNumber::PosInt(budget_tokens)),
- ),
- ])),
- )])));
- }
+ if let Some(Thinking::Enabled {
+ budget_tokens: Some(budget_tokens),
+ }) = request.thinking
+ {
+ let thinking_config = HashMap::from([
+ ("type".to_string(), Document::String("enabled".to_string())),
+ (
+ "budget_tokens".to_string(),
+ Document::Number(AwsNumber::PosInt(budget_tokens)),
+ ),
+ ]);
+ response = response.additional_model_request_fields(Document::Object(HashMap::from([(
+ "thinking".to_string(),
+ Document::from(thinking_config),
+ )])));
+ }
- if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
- response = response.set_tool_config(request.tools);
- }
+ if request
+ .tools
+ .as_ref()
+ .map_or(false, |t| !t.tools.is_empty())
+ {
+ response = response.set_tool_config(request.tools);
+ }
- let response = response.send().await;
+ let output = response
+ .send()
+ .await
+ .context("Failed to send API request to Bedrock");
- match response {
- Ok(output) => {
- let stream: Pin<
- Box<
- dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
- + Send,
- >,
- > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
- match stream.recv().await {
- Ok(Some(output)) => Some(({ Ok(output) }, stream)),
- Ok(None) => None,
- Err(err) => {
- Some((
- // TODO: Figure out how we can capture Throttling Exceptions
- Err(BedrockError::ClientError(anyhow!(
- "{:?}",
- aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
- ))),
- stream,
- ))
- }
- }
- }));
- Ok(stream)
- }
- Err(err) => Err(anyhow!(
- "{:?}",
- aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+ let stream = Box::pin(stream::unfold(
+ output?.stream,
+ move |mut stream| async move {
+ match stream.recv().await {
+ Ok(Some(output)) => Some((Ok(output), stream)),
+ Ok(None) => None,
+ Err(err) => Some((
+ Err(BedrockError::ClientError(anyhow!(
+ "{:?}",
+ aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+ ))),
+ stream,
)),
}
- })
- .await
- .context("spawning a task")?
+ },
+ ));
+
+ Ok(stream)
}
pub fn aws_document_to_value(document: &Document) -> Value {
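
For reference, the `thinking` `Document` that the rewritten `stream_completion` attaches via `additional_model_request_fields` corresponds to the JSON shape sketched below with `serde_json`. The helper name and the budget value are illustrative placeholders, not code from the PR.

```rust
use serde_json::json;

// Hypothetical helper mirroring the Document built in stream_completion:
// { "thinking": { "type": "enabled", "budget_tokens": <budget> } }
fn thinking_request_fields(budget_tokens: u64) -> serde_json::Value {
    json!({
        "thinking": {
            "type": "enabled",
            "budget_tokens": budget_tokens
        }
    })
}

fn main() {
    // The budget value here is purely a placeholder for illustration.
    println!("{}", thinking_request_fields(4096));
}
```
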
crates/language_models/src/provider/bedrock.rs
@@ -46,7 +46,6 @@ use settings::{Settings, SettingsStore};
use smol::lock::OnceCell;
use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use theme::ThemeSettings;
-use tokio::runtime::Handle;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
@@ -460,22 +459,22 @@ impl BedrockModel {
&self,
request: bedrock::Request,
cx: &AsyncApp,
- ) -> Result<
- BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
+ ) -> BoxFuture<
+ 'static,
+ Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
> {
- let runtime_client = self
- .get_or_init_client(cx)
+ let Ok(runtime_client) = self
+ .get_or_init_client(&cx)
.cloned()
- .context("Bedrock client not initialized")?;
- let owned_handle = self.handler.clone();
+ .context("Bedrock client not initialized")
+ else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
- Ok(async move {
- let request = bedrock::stream_completion(runtime_client, request, owned_handle);
- request.await.unwrap_or_else(|e| {
- futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
- })
+ match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) {
+ Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(),
+ Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(),
}
- .boxed())
}
}
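
On the provider side, the return type moves from `Result<BoxFuture<'static, BoxStream<'static, ...>>>` to `BoxFuture<'static, Result<BoxStream<'static, ...>>>`, so a missing client or a failed spawn surfaces when the future resolves rather than when it is constructed. The sketch below shows that shape with stand-in types; `Event` replaces `Result<BedrockStreamingResponse, BedrockError>`, and nothing here is the actual Zed or AWS API.

```rust
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};

// Stand-in event type for Result<BedrockStreamingResponse, BedrockError>.
type Event = Result<u32, String>;

// Errors are reported inside the future, mirroring the new return type.
fn stream_numbers(fail: bool) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, Event>>> {
    async move {
        if fail {
            anyhow::bail!("client not initialized");
        }
        let items: [Event; 3] = [Ok(1), Ok(2), Ok(3)];
        Ok(futures::stream::iter(items).boxed())
    }
    .boxed()
}

fn main() -> anyhow::Result<()> {
    futures::executor::block_on(async {
        // Await first, then handle the Result, then consume the stream.
        let mut stream = stream_numbers(false).await?;
        while let Some(event) = stream.next().await {
            println!("{event:?}");
        }
        Ok(())
    })
}
```
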
@@ -570,12 +569,10 @@ impl LanguageModel for BedrockModel {
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
};
- let owned_handle = self.handler.clone();
-
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
- let response = request.map_err(|err| anyhow!(err))?.await;
- let events = map_to_language_model_completion_events(response, owned_handle);
+ let response = request.await.map_err(|err| anyhow!(err))?;
+ let events = map_to_language_model_completion_events(response);
if deny_tool_calls {
Ok(deny_tool_use_events(events).boxed())
@@ -879,7 +876,6 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
- handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
@@ -892,198 +888,123 @@ pub fn map_to_language_model_completion_events(
tool_uses_by_index: HashMap<i32, RawToolUse>,
}
- futures::stream::unfold(
- State {
- events,
- tool_uses_by_index: HashMap::default(),
- },
- move |mut state: State| {
- let inner_handle = handle.clone();
- async move {
- inner_handle
- .spawn(async {
- while let Some(event) = state.events.next().await {
- match event {
- Ok(event) => match event {
- ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
- match cb_delta.delta {
- Some(ContentBlockDelta::Text(text_out)) => {
- let completion_event =
- LanguageModelCompletionEvent::Text(text_out);
- return Some((Some(Ok(completion_event)), state));
- }
-
- Some(ContentBlockDelta::ToolUse(text_out)) => {
- if let Some(tool_use) = state
- .tool_uses_by_index
- .get_mut(&cb_delta.content_block_index)
- {
- tool_use.input_json.push_str(text_out.input());
- }
- }
-
- Some(ContentBlockDelta::ReasoningContent(thinking)) => {
- match thinking {
- ReasoningContentBlockDelta::RedactedContent(
- redacted,
- ) => {
- let thinking_event =
- LanguageModelCompletionEvent::Thinking {
- text: String::from_utf8(
- redacted.into_inner(),
- )
- .unwrap_or("REDACTED".to_string()),
- signature: None,
- };
-
- return Some((
- Some(Ok(thinking_event)),
- state,
- ));
- }
- ReasoningContentBlockDelta::Signature(
- signature,
- ) => {
- return Some((
- Some(Ok(LanguageModelCompletionEvent::Thinking {
- text: "".to_string(),
- signature: Some(signature)
- })),
- state,
- ));
- }
- ReasoningContentBlockDelta::Text(thoughts) => {
- let thinking_event =
- LanguageModelCompletionEvent::Thinking {
- text: thoughts.to_string(),
- signature: None
- };
-
- return Some((
- Some(Ok(thinking_event)),
- state,
- ));
- }
- _ => {}
- }
- }
- _ => {}
- }
- }
- ConverseStreamOutput::ContentBlockStart(cb_start) => {
- if let Some(ContentBlockStart::ToolUse(text_out)) =
- cb_start.start
- {
- let tool_use = RawToolUse {
- id: text_out.tool_use_id,
- name: text_out.name,
- input_json: String::new(),
- };
-
- state
- .tool_uses_by_index
- .insert(cb_start.content_block_index, tool_use);
- }
- }
- ConverseStreamOutput::ContentBlockStop(cb_stop) => {
- if let Some(tool_use) = state
- .tool_uses_by_index
- .remove(&cb_stop.content_block_index)
- {
- let tool_use_event = LanguageModelToolUse {
- id: tool_use.id.into(),
- name: tool_use.name.into(),
- is_input_complete: true,
- raw_input: tool_use.input_json.clone(),
- input: if tool_use.input_json.is_empty() {
- Value::Null
- } else {
- serde_json::Value::from_str(
- &tool_use.input_json,
- )
- .map_err(|err| anyhow!(err))
- .unwrap()
- },
- };
-
- return Some((
- Some(Ok(LanguageModelCompletionEvent::ToolUse(
- tool_use_event,
- ))),
- state,
- ));
- }
- }
-
- ConverseStreamOutput::Metadata(cb_meta) => {
- if let Some(metadata) = cb_meta.usage {
- let completion_event =
- LanguageModelCompletionEvent::UsageUpdate(
- TokenUsage {
- input_tokens: metadata.input_tokens as u64,
- output_tokens: metadata.output_tokens as u64,
- cache_creation_input_tokens:
- metadata.cache_write_input_tokens.unwrap_or_default() as u64,
- cache_read_input_tokens:
- metadata.cache_read_input_tokens.unwrap_or_default() as u64,
- },
- );
- return Some((Some(Ok(completion_event)), state));
- }
- }
- ConverseStreamOutput::MessageStop(message_stop) => {
- let reason = match message_stop.stop_reason {
- StopReason::ContentFiltered => {
- LanguageModelCompletionEvent::Stop(
- language_model::StopReason::EndTurn,
- )
- }
- StopReason::EndTurn => {
- LanguageModelCompletionEvent::Stop(
- language_model::StopReason::EndTurn,
- )
- }
- StopReason::GuardrailIntervened => {
- LanguageModelCompletionEvent::Stop(
- language_model::StopReason::EndTurn,
- )
- }
- StopReason::MaxTokens => {
- LanguageModelCompletionEvent::Stop(
- language_model::StopReason::EndTurn,
- )
- }
- StopReason::StopSequence => {
- LanguageModelCompletionEvent::Stop(
- language_model::StopReason::EndTurn,
- )
- }
- StopReason::ToolUse => {
- LanguageModelCompletionEvent::Stop(
- language_model::StopReason::ToolUse,
- )
- }
- _ => LanguageModelCompletionEvent::Stop(
- language_model::StopReason::EndTurn,
- ),
- };
- return Some((Some(Ok(reason)), state));
- }
- _ => {}
- },
+ let initial_state = State {
+ events,
+ tool_uses_by_index: HashMap::default(),
+ };
- Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
+ futures::stream::unfold(initial_state, |mut state| async move {
+ match state.events.next().await {
+ Some(event_result) => match event_result {
+ Ok(event) => {
+ let result = match event {
+ ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
+ Some(ContentBlockDelta::Text(text)) => {
+ Some(Ok(LanguageModelCompletionEvent::Text(text)))
+ }
+ Some(ContentBlockDelta::ToolUse(tool_output)) => {
+ if let Some(tool_use) = state
+ .tool_uses_by_index
+ .get_mut(&cb_delta.content_block_index)
+ {
+ tool_use.input_json.push_str(tool_output.input());
+ }
+ None
}
+ Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
+ ReasoningContentBlockDelta::Text(thoughts) => {
+ Some(Ok(LanguageModelCompletionEvent::Thinking {
+ text: thoughts.clone(),
+ signature: None,
+ }))
+ }
+ ReasoningContentBlockDelta::Signature(sig) => {
+ Some(Ok(LanguageModelCompletionEvent::Thinking {
+ text: "".into(),
+ signature: Some(sig),
+ }))
+ }
+ ReasoningContentBlockDelta::RedactedContent(redacted) => {
+ let content = String::from_utf8(redacted.into_inner())
+ .unwrap_or("REDACTED".to_string());
+ Some(Ok(LanguageModelCompletionEvent::Thinking {
+ text: content,
+ signature: None,
+ }))
+ }
+ _ => None,
+ },
+ _ => None,
+ },
+ ConverseStreamOutput::ContentBlockStart(cb_start) => {
+ if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
+ state.tool_uses_by_index.insert(
+ cb_start.content_block_index,
+ RawToolUse {
+ id: tool_start.tool_use_id,
+ name: tool_start.name,
+ input_json: String::new(),
+ },
+ );
+ }
+ None
}
- None
- })
- .await
- .log_err()
- .flatten()
- }
- },
- )
- .filter_map(|event| async move { event })
+ ConverseStreamOutput::ContentBlockStop(cb_stop) => state
+ .tool_uses_by_index
+ .remove(&cb_stop.content_block_index)
+ .map(|tool_use| {
+ let input = if tool_use.input_json.is_empty() {
+ Value::Null
+ } else {
+ serde_json::Value::from_str(&tool_use.input_json)
+ .unwrap_or(Value::Null)
+ };
+
+ Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.into(),
+ name: tool_use.name.into(),
+ is_input_complete: true,
+ raw_input: tool_use.input_json.clone(),
+ input,
+ },
+ ))
+ }),
+ ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
+ Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: metadata.input_tokens as u64,
+ output_tokens: metadata.output_tokens as u64,
+ cache_creation_input_tokens: metadata
+ .cache_write_input_tokens
+ .unwrap_or_default()
+ as u64,
+ cache_read_input_tokens: metadata
+ .cache_read_input_tokens
+ .unwrap_or_default()
+ as u64,
+ }))
+ }),
+ ConverseStreamOutput::MessageStop(message_stop) => {
+ let stop_reason = match message_stop.stop_reason {
+ StopReason::ToolUse => language_model::StopReason::ToolUse,
+ _ => language_model::StopReason::EndTurn,
+ };
+ Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
+ }
+ _ => None,
+ };
+
+ Some((result, state))
+ }
+ Err(err) => Some((
+ Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
+ state,
+ )),
+ },
+ None => None,
+ }
+ })
+ .filter_map(|result| async move { result })
}
struct ConfigurationView {
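
The rewritten mapper keeps the same shape throughout: `stream::unfold` turns each raw `ConverseStreamOutput` into an `Option<Result<...>>`, bookkeeping-only events (such as tool-use input deltas, which merely accumulate JSON in the state) map to `None`, and the trailing `filter_map` drops those. A reduced sketch of that unfold-plus-filter_map shape with toy types, assuming nothing beyond the `futures` crate:

```rust
use futures::stream::{self, StreamExt};

// Toy event mapper: even numbers produce output events, odd numbers only
// "update state" and are filtered out, mirroring how tool-use deltas are
// accumulated silently until ContentBlockStop emits the finished tool call.
fn map_events(raw: Vec<i32>) -> impl futures::Stream<Item = Result<String, String>> {
    stream::unfold(raw.into_iter(), |mut events| async move {
        events.next().map(|event| {
            let out: Option<Result<String, String>> =
                (event % 2 == 0).then(|| Ok(format!("event {event}")));
            (out, events)
        })
    })
    .filter_map(|out| async move { out })
}

fn main() {
    let collected: Vec<Result<String, String>> =
        futures::executor::block_on(map_events(vec![1, 2, 3, 4]).collect());
    println!("{collected:?}"); // only "event 2" and "event 4" survive
}
```
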