bedrock.rs

  1mod models;
  2
  3use anyhow::{Context, Error, Result, anyhow};
  4use aws_sdk_bedrockruntime as bedrock;
  5pub use aws_sdk_bedrockruntime as bedrock_client;
  6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
  7pub use aws_sdk_bedrockruntime::types::{
  8    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
  9    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
 10    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
 11    ToolSpecification as BedrockToolSpec,
 12};
 13pub use aws_smithy_types::Blob as BedrockBlob;
 14use aws_smithy_types::{Document, Number as AwsNumber};
 15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 16pub use bedrock::types::{
 17    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 18    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 19    ImageBlock as BedrockImageBlock, ImageFormat as BedrockImageFormat,
 20    ImageSource as BedrockImageSource, Message as BedrockMessage,
 21    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 22    ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
 23    ToolResultBlock as BedrockToolResultBlock,
 24    ToolResultContentBlock as BedrockToolResultContentBlock,
 25    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 26};
 27use futures::stream::{self, BoxStream};
 28use serde::{Deserialize, Serialize};
 29use serde_json::{Number, Value};
 30use std::collections::HashMap;
 31use thiserror::Error;
 32
 33pub use crate::models::*;
 34
/// Beta flag sent in the `anthropic_beta` additional model request field when a
/// [`Request`] sets `allow_extended_context`, opting into the 1M-token context window.
pub const CONTEXT_1M_BETA_HEADER: &str = "context-1m-2025-08-07";
 36
 37pub async fn stream_completion(
 38    client: bedrock::Client,
 39    request: Request,
 40) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 41    let mut response = bedrock::Client::converse_stream(&client)
 42        .model_id(request.model.clone())
 43        .set_messages(request.messages.into());
 44
 45    let mut additional_fields: HashMap<String, Document> = HashMap::new();
 46
 47    match request.thinking {
 48        Some(Thinking::Enabled {
 49            budget_tokens: Some(budget_tokens),
 50        }) => {
 51            let thinking_config = HashMap::from([
 52                ("type".to_string(), Document::String("enabled".to_string())),
 53                (
 54                    "budget_tokens".to_string(),
 55                    Document::Number(AwsNumber::PosInt(budget_tokens)),
 56                ),
 57            ]);
 58            additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
 59        }
 60        Some(Thinking::Adaptive { effort: _ }) => {
 61            let thinking_config =
 62                HashMap::from([("type".to_string(), Document::String("adaptive".to_string()))]);
 63            additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
 64        }
 65        _ => {}
 66    }
 67
 68    if request.allow_extended_context {
 69        additional_fields.insert(
 70            "anthropic_beta".to_string(),
 71            Document::Array(vec![Document::String(CONTEXT_1M_BETA_HEADER.to_string())]),
 72        );
 73    }
 74
 75    if !additional_fields.is_empty() {
 76        response = response.additional_model_request_fields(Document::Object(additional_fields));
 77    }
 78
 79    if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
 80        response = response.set_tool_config(request.tools);
 81    }
 82
 83    let inference_config = InferenceConfiguration::builder()
 84        .max_tokens(request.max_tokens as i32)
 85        .set_temperature(request.temperature)
 86        .set_top_p(request.top_p)
 87        .build();
 88
 89    response = response.inference_config(inference_config);
 90
 91    if let Some(system) = request.system {
 92        if !system.is_empty() {
 93            response = response.system(BedrockSystemContentBlock::Text(system));
 94        }
 95    }
 96
 97    let output = response
 98        .send()
 99        .await
100        .context("Failed to send API request to Bedrock");
101
102    let stream = Box::pin(stream::unfold(
103        output?.stream,
104        move |mut stream| async move {
105            match stream.recv().await {
106                Ok(Some(output)) => Some((Ok(output), stream)),
107                Ok(None) => None,
108                Err(err) => Some((
109                    Err(BedrockError::ClientError(anyhow!(
110                        "{}",
111                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
112                    ))),
113                    stream,
114                )),
115            }
116        },
117    ));
118
119    Ok(stream)
120}
121
122pub fn aws_document_to_value(document: &Document) -> Value {
123    match document {
124        Document::Null => Value::Null,
125        Document::Bool(value) => Value::Bool(*value),
126        Document::Number(value) => match *value {
127            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
128            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
129            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
130        },
131        Document::String(value) => Value::String(value.clone()),
132        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
133        Document::Object(map) => Value::Object(
134            map.iter()
135                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
136                .collect(),
137        ),
138    }
139}
140
141pub fn value_to_aws_document(value: &Value) -> Document {
142    match value {
143        Value::Null => Document::Null,
144        Value::Bool(value) => Document::Bool(*value),
145        Value::Number(value) => {
146            if let Some(value) = value.as_u64() {
147                Document::Number(AwsNumber::PosInt(value))
148            } else if let Some(value) = value.as_i64() {
149                Document::Number(AwsNumber::NegInt(value))
150            } else if let Some(value) = value.as_f64() {
151                Document::Number(AwsNumber::Float(value))
152            } else {
153                Document::Null
154            }
155        }
156        Value::String(value) => Document::String(value.clone()),
157        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
158        Value::Object(map) => Document::Object(
159            map.iter()
160                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
161                .collect(),
162        ),
163    }
164}
165
/// Extended-thinking configuration for a [`Request`].
///
/// `stream_completion` turns this into the `thinking` entry of the Bedrock
/// `additionalModelRequestFields` document.
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
    /// Thinking with an explicit token budget, sent as `type: "enabled"`.
    Enabled {
        /// Token budget for thinking.
        ///
        /// When `None`, `stream_completion` sends no thinking configuration at all.
        budget_tokens: Option<u64>,
    },
    /// Adaptive thinking, sent as `type: "adaptive"`.
    Adaptive {
        /// Requested effort level.
        ///
        /// NOTE(review): `stream_completion` does not currently forward this value.
        effort: BedrockAdaptiveThinkingEffort,
    },
}
175
/// Parameters for a single [`stream_completion`] call.
#[derive(Debug)]
pub struct Request {
    /// Bedrock model ID passed to `ConverseStream`.
    pub model: String,
    /// Token limit passed as `max_tokens` in the inference configuration.
    pub max_tokens: u64,
    /// Conversation history sent to the model.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; only sent when it contains at least one tool.
    pub tools: Option<BedrockToolConfig>,
    /// Optional extended-thinking configuration.
    pub thinking: Option<Thinking>,
    /// System prompt; an empty string is treated the same as `None`.
    pub system: Option<String>,
    /// Request metadata.
    ///
    /// NOTE(review): not read by `stream_completion` in this file.
    pub metadata: Option<Metadata>,
    /// Stop sequences requested by the caller.
    pub stop_sequences: Vec<String>,
    /// Sampling temperature, passed through to the inference configuration when set.
    pub temperature: Option<f32>,
    /// Top-k sampling parameter.
    ///
    /// NOTE(review): not read by `stream_completion` in this file.
    pub top_k: Option<u32>,
    /// Nucleus (top-p) sampling parameter, passed through to the inference configuration
    /// when set.
    pub top_p: Option<f32>,
    /// When true, requests the 1M-token context beta (`CONTEXT_1M_BETA_HEADER`) via the
    /// `anthropic_beta` additional field.
    pub allow_extended_context: bool,
}
191
/// Optional per-request metadata carried on a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Caller-supplied user identifier.
    ///
    /// NOTE(review): semantics and usage are not visible in this file; confirm.
    pub user_id: Option<String>,
}
196
/// Errors produced by the Bedrock helpers in this module.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// Error reported by the Bedrock SDK client.
    ///
    /// `stream_completion` yields this variant for failures while reading the
    /// response stream.
    #[error("client error: {0}")]
    ClientError(anyhow::Error),
    /// NOTE(review): not constructed in this file; presumably used by callers for errors
    /// from an extension layer. Confirm against usages.
    #[error("extension error: {0}")]
    ExtensionError(anyhow::Error),
    /// Any other error.
    ///
    /// `anyhow::Error` converts into this variant via `From`, and its message is
    /// displayed unchanged.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}