bedrock.rs

  1mod models;
  2
  3use anyhow::{Result, anyhow};
  4use aws_sdk_bedrockruntime as bedrock;
  5pub use aws_sdk_bedrockruntime as bedrock_client;
  6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
  7pub use aws_sdk_bedrockruntime::types::{
  8    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
  9    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
 10    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
 11    ToolSpecification as BedrockToolSpec,
 12};
 13pub use aws_smithy_types::Blob as BedrockBlob;
 14use aws_smithy_types::{Document, Number as AwsNumber};
 15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 16pub use bedrock::types::{
 17    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 18    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 19    ImageBlock as BedrockImageBlock, ImageFormat as BedrockImageFormat,
 20    ImageSource as BedrockImageSource, Message as BedrockMessage,
 21    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 22    ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
 23    ToolResultBlock as BedrockToolResultBlock,
 24    ToolResultContentBlock as BedrockToolResultContentBlock,
 25    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 26};
 27use futures::stream::{self, BoxStream};
 28use serde::{Deserialize, Serialize};
 29use serde_json::{Number, Value};
 30use std::collections::HashMap;
 31use thiserror::Error;
 32
 33pub use crate::models::*;
 34
/// Beta flag placed in the `anthropic_beta` additional model request field when a
/// [`Request`] has `allow_extended_context` set.
pub const CONTEXT_1M_BETA_HEADER: &str = "context-1m-2025-08-07";
 36
 37pub async fn stream_completion(
 38    client: bedrock::Client,
 39    request: Request,
 40) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, anyhow::Error>>, BedrockError> {
 41    let mut response = bedrock::Client::converse_stream(&client)
 42        .model_id(request.model.clone())
 43        .set_messages(request.messages.into());
 44
 45    let mut additional_fields: HashMap<String, Document> = HashMap::new();
 46
 47    match request.thinking {
 48        Some(Thinking::Enabled {
 49            budget_tokens: Some(budget_tokens),
 50        }) => {
 51            let thinking_config = HashMap::from([
 52                ("type".to_string(), Document::String("enabled".to_string())),
 53                (
 54                    "budget_tokens".to_string(),
 55                    Document::Number(AwsNumber::PosInt(budget_tokens)),
 56                ),
 57            ]);
 58            additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
 59        }
 60        Some(Thinking::Adaptive { effort: _ }) => {
 61            let thinking_config =
 62                HashMap::from([("type".to_string(), Document::String("adaptive".to_string()))]);
 63            additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
 64        }
 65        _ => {}
 66    }
 67
 68    if request.allow_extended_context {
 69        additional_fields.insert(
 70            "anthropic_beta".to_string(),
 71            Document::Array(vec![Document::String(CONTEXT_1M_BETA_HEADER.to_string())]),
 72        );
 73    }
 74
 75    if !additional_fields.is_empty() {
 76        response = response.additional_model_request_fields(Document::Object(additional_fields));
 77    }
 78
 79    if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
 80        response = response.set_tool_config(request.tools);
 81    }
 82
 83    let inference_config = InferenceConfiguration::builder()
 84        .max_tokens(request.max_tokens as i32)
 85        .set_temperature(request.temperature)
 86        .set_top_p(request.top_p)
 87        .build();
 88
 89    response = response.inference_config(inference_config);
 90
 91    if let Some(system) = request.system {
 92        if !system.is_empty() {
 93            response = response.system(BedrockSystemContentBlock::Text(system));
 94        }
 95    }
 96
 97    let output = response.send().await.map_err(|err| match err {
 98        bedrock::error::SdkError::ServiceError(ctx) => {
 99            use bedrock::operation::converse_stream::ConverseStreamError;
100            let err = ctx.into_err();
101            match &err {
102                ConverseStreamError::ValidationException(e) => {
103                    BedrockError::Validation(e.message().unwrap_or("validation error").to_string())
104                }
105                ConverseStreamError::ThrottlingException(_) => BedrockError::RateLimited,
106                ConverseStreamError::ServiceUnavailableException(_)
107                | ConverseStreamError::ModelNotReadyException(_) => {
108                    BedrockError::ServiceUnavailable
109                }
110                ConverseStreamError::AccessDeniedException(e) => {
111                    BedrockError::AccessDenied(e.message().unwrap_or("access denied").to_string())
112                }
113                ConverseStreamError::InternalServerException(e) => BedrockError::InternalServer(
114                    e.message().unwrap_or("internal server error").to_string(),
115                ),
116                _ => BedrockError::Other(err.into()),
117            }
118        }
119        other => BedrockError::Other(other.into()),
120    });
121
122    let stream = Box::pin(stream::unfold(
123        output?.stream,
124        move |mut stream| async move {
125            match stream.recv().await {
126                Ok(Some(output)) => Some((Ok(output), stream)),
127                Ok(None) => None,
128                Err(err) => Some((
129                    Err(anyhow!(
130                        "{}",
131                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
132                    )),
133                    stream,
134                )),
135            }
136        },
137    ));
138
139    Ok(stream)
140}
141
142pub fn aws_document_to_value(document: &Document) -> Value {
143    match document {
144        Document::Null => Value::Null,
145        Document::Bool(value) => Value::Bool(*value),
146        Document::Number(value) => match *value {
147            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
148            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
149            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
150        },
151        Document::String(value) => Value::String(value.clone()),
152        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
153        Document::Object(map) => Value::Object(
154            map.iter()
155                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
156                .collect(),
157        ),
158    }
159}
160
161pub fn value_to_aws_document(value: &Value) -> Document {
162    match value {
163        Value::Null => Document::Null,
164        Value::Bool(value) => Document::Bool(*value),
165        Value::Number(value) => {
166            if let Some(value) = value.as_u64() {
167                Document::Number(AwsNumber::PosInt(value))
168            } else if let Some(value) = value.as_i64() {
169                Document::Number(AwsNumber::NegInt(value))
170            } else if let Some(value) = value.as_f64() {
171                Document::Number(AwsNumber::Float(value))
172            } else {
173                Document::Null
174            }
175        }
176        Value::String(value) => Document::String(value.clone()),
177        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
178        Value::Object(map) => Document::Object(
179            map.iter()
180                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
181                .collect(),
182        ),
183    }
184}
185
/// Extended-thinking configuration for a [`Request`].
///
/// `stream_completion` turns this into the `thinking` entry of the additional model request
/// fields sent to Bedrock.
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
    /// Thinking with an explicit token budget, sent as `{"type": "enabled", "budget_tokens": N}`.
    /// When `budget_tokens` is `None`, `stream_completion` sends no thinking configuration.
    Enabled {
        budget_tokens: Option<u64>,
    },
    /// Adaptive thinking, sent as `{"type": "adaptive"}`.
    /// NOTE(review): `effort` is not currently forwarded by `stream_completion`.
    Adaptive {
        effort: BedrockAdaptiveThinkingEffort,
    },
}
195
/// Parameters for a single Bedrock `ConverseStream` call, consumed by [`stream_completion`].
#[derive(Debug)]
pub struct Request {
    /// Bedrock model ID, sent as the request's model id.
    pub model: String,
    /// Maximum number of tokens to generate; the SDK field it feeds is an `i32`.
    pub max_tokens: u64,
    /// Conversation history sent as the request messages.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; only sent when it contains at least one tool.
    pub tools: Option<BedrockToolConfig>,
    /// Optional extended-thinking configuration; see [`Thinking`].
    pub thinking: Option<Thinking>,
    /// System prompt; only sent when present and non-empty.
    pub system: Option<String>,
    /// Request metadata. Not read by `stream_completion`.
    pub metadata: Option<Metadata>,
    /// Stop sequences. NOTE(review): confirm how `stream_completion` forwards these to Bedrock.
    pub stop_sequences: Vec<String>,
    /// Sampling temperature, passed through to the inference configuration.
    pub temperature: Option<f32>,
    /// Top-k sampling parameter. Not read by `stream_completion`.
    pub top_k: Option<u32>,
    /// Top-p sampling parameter, passed through to the inference configuration.
    pub top_p: Option<f32>,
    /// When `true`, adds [`CONTEXT_1M_BETA_HEADER`] under `anthropic_beta` in the additional
    /// model request fields.
    pub allow_extended_context: bool,
}
211
/// Optional request metadata carried on a [`Request`]; not read by `stream_completion`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// NOTE(review): presumably an end-user identifier; confirm with callers.
    pub user_id: Option<String>,
}
216
/// Errors returned by [`stream_completion`] when the initial `ConverseStream` request fails.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// `ValidationException`; holds the service message (or "validation error").
    #[error("{0}")]
    Validation(String),
    /// `ThrottlingException`.
    #[error("rate limited")]
    RateLimited,
    /// `ServiceUnavailableException` or `ModelNotReadyException`.
    #[error("service unavailable")]
    ServiceUnavailable,
    /// `AccessDeniedException`; holds the service message (or "access denied").
    #[error("{0}")]
    AccessDenied(String),
    /// `InternalServerException`; holds the service message (or "internal server error").
    #[error("{0}")]
    InternalServer(String),
    /// Any other service exception, or a non-service SDK failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}