1mod models;
2
3use anyhow::{Result, anyhow};
4use aws_sdk_bedrockruntime as bedrock;
5pub use aws_sdk_bedrockruntime as bedrock_client;
6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
7pub use aws_sdk_bedrockruntime::types::{
8 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
9 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
10 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
11 ToolSpecification as BedrockToolSpec,
12};
13pub use aws_smithy_types::Blob as BedrockBlob;
14use aws_smithy_types::{Document, Number as AwsNumber};
15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
16pub use bedrock::types::{
17 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
18 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
19 ImageBlock as BedrockImageBlock, ImageFormat as BedrockImageFormat,
20 ImageSource as BedrockImageSource, Message as BedrockMessage,
21 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
22 ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
23 ToolResultBlock as BedrockToolResultBlock,
24 ToolResultContentBlock as BedrockToolResultContentBlock,
25 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
26};
27use futures::stream::{self, BoxStream};
28use serde::{Deserialize, Serialize};
29use serde_json::{Number, Value};
30use std::collections::HashMap;
31use thiserror::Error;
32
33pub use crate::models::*;
34
/// Beta flag value sent to the model when extended context is requested.
///
/// `stream_completion` places this string in the `anthropic_beta` array of the additional
/// model request fields whenever `Request::allow_extended_context` is `true`.
// NOTE(review): the name suggests this opts in to a 1M-token context window — confirm against
// the provider's beta documentation.
pub const CONTEXT_1M_BETA_HEADER: &str = "context-1m-2025-08-07";
36
37pub async fn stream_completion(
38 client: bedrock::Client,
39 request: Request,
40) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, anyhow::Error>>, BedrockError> {
41 let mut response = bedrock::Client::converse_stream(&client)
42 .model_id(request.model.clone())
43 .set_messages(request.messages.into());
44
45 let mut additional_fields: HashMap<String, Document> = HashMap::new();
46
47 match request.thinking {
48 Some(Thinking::Enabled {
49 budget_tokens: Some(budget_tokens),
50 }) => {
51 let thinking_config = HashMap::from([
52 ("type".to_string(), Document::String("enabled".to_string())),
53 (
54 "budget_tokens".to_string(),
55 Document::Number(AwsNumber::PosInt(budget_tokens)),
56 ),
57 ]);
58 additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
59 }
60 Some(Thinking::Adaptive { effort: _ }) => {
61 let thinking_config =
62 HashMap::from([("type".to_string(), Document::String("adaptive".to_string()))]);
63 additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
64 }
65 _ => {}
66 }
67
68 if request.allow_extended_context {
69 additional_fields.insert(
70 "anthropic_beta".to_string(),
71 Document::Array(vec![Document::String(CONTEXT_1M_BETA_HEADER.to_string())]),
72 );
73 }
74
75 if !additional_fields.is_empty() {
76 response = response.additional_model_request_fields(Document::Object(additional_fields));
77 }
78
79 if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
80 response = response.set_tool_config(request.tools);
81 }
82
83 let inference_config = InferenceConfiguration::builder()
84 .max_tokens(request.max_tokens as i32)
85 .set_temperature(request.temperature)
86 .set_top_p(request.top_p)
87 .build();
88
89 response = response.inference_config(inference_config);
90
91 if let Some(system) = request.system {
92 if !system.is_empty() {
93 response = response.system(BedrockSystemContentBlock::Text(system));
94 }
95 }
96
97 let output = response.send().await.map_err(|err| match err {
98 bedrock::error::SdkError::ServiceError(ctx) => {
99 use bedrock::operation::converse_stream::ConverseStreamError;
100 let err = ctx.into_err();
101 match &err {
102 ConverseStreamError::ValidationException(e) => {
103 BedrockError::Validation(e.message().unwrap_or("validation error").to_string())
104 }
105 ConverseStreamError::ThrottlingException(_) => BedrockError::RateLimited,
106 ConverseStreamError::ServiceUnavailableException(_)
107 | ConverseStreamError::ModelNotReadyException(_) => {
108 BedrockError::ServiceUnavailable
109 }
110 ConverseStreamError::AccessDeniedException(e) => {
111 BedrockError::AccessDenied(e.message().unwrap_or("access denied").to_string())
112 }
113 ConverseStreamError::InternalServerException(e) => BedrockError::InternalServer(
114 e.message().unwrap_or("internal server error").to_string(),
115 ),
116 _ => BedrockError::Other(err.into()),
117 }
118 }
119 other => BedrockError::Other(other.into()),
120 });
121
122 let stream = Box::pin(stream::unfold(
123 output?.stream,
124 move |mut stream| async move {
125 match stream.recv().await {
126 Ok(Some(output)) => Some((Ok(output), stream)),
127 Ok(None) => None,
128 Err(err) => Some((
129 Err(anyhow!(
130 "{}",
131 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
132 )),
133 stream,
134 )),
135 }
136 },
137 ));
138
139 Ok(stream)
140}
141
142pub fn aws_document_to_value(document: &Document) -> Value {
143 match document {
144 Document::Null => Value::Null,
145 Document::Bool(value) => Value::Bool(*value),
146 Document::Number(value) => match *value {
147 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
148 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
149 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
150 },
151 Document::String(value) => Value::String(value.clone()),
152 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
153 Document::Object(map) => Value::Object(
154 map.iter()
155 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
156 .collect(),
157 ),
158 }
159}
160
161pub fn value_to_aws_document(value: &Value) -> Document {
162 match value {
163 Value::Null => Document::Null,
164 Value::Bool(value) => Document::Bool(*value),
165 Value::Number(value) => {
166 if let Some(value) = value.as_u64() {
167 Document::Number(AwsNumber::PosInt(value))
168 } else if let Some(value) = value.as_i64() {
169 Document::Number(AwsNumber::NegInt(value))
170 } else if let Some(value) = value.as_f64() {
171 Document::Number(AwsNumber::Float(value))
172 } else {
173 Document::Null
174 }
175 }
176 Value::String(value) => Document::String(value.clone()),
177 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
178 Value::Object(map) => Document::Object(
179 map.iter()
180 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
181 .collect(),
182 ),
183 }
184}
185
/// Thinking (reasoning) mode requested for a completion.
///
/// `stream_completion` translates this into the `thinking` entry of the additional model
/// request fields.
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
    /// Thinking with an explicit token budget.
    ///
    /// Sent as `{"type": "enabled", "budget_tokens": <n>}` when a budget is given; when
    /// `budget_tokens` is `None`, `stream_completion` sends no `thinking` entry at all.
    Enabled {
        budget_tokens: Option<u64>,
    },
    /// Adaptive thinking, sent as `{"type": "adaptive"}`.
    ///
    /// NOTE(review): `stream_completion` currently ignores `effort` — confirm whether it
    /// should be forwarded.
    Adaptive {
        effort: BedrockAdaptiveThinkingEffort,
    },
}
195
/// Parameters for a single streaming completion, consumed by `stream_completion`.
#[derive(Debug)]
pub struct Request {
    /// Passed to the request builder as the model id.
    pub model: String,
    /// Maximum tokens for the inference configuration; converted to the `i32` the SDK takes.
    pub max_tokens: u64,
    /// Conversation messages, passed to the request builder as-is.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; attached only when present and listing at least one tool.
    pub tools: Option<BedrockToolConfig>,
    /// Thinking mode, encoded into the additional model request fields.
    pub thinking: Option<Thinking>,
    /// System prompt; sent as a single text block when present and non-empty.
    pub system: Option<String>,
    // NOTE(review): not read by `stream_completion` — confirm whether it should be forwarded.
    pub metadata: Option<Metadata>,
    // NOTE(review): not read by `stream_completion` — confirm whether it should be forwarded.
    pub stop_sequences: Vec<String>,
    /// Sampling temperature for the inference configuration; left unset when `None`.
    pub temperature: Option<f32>,
    // NOTE(review): not read by `stream_completion` — confirm whether it should be forwarded.
    pub top_k: Option<u32>,
    /// Top-p value for the inference configuration; left unset when `None`.
    pub top_p: Option<f32>,
    /// When `true`, `CONTEXT_1M_BETA_HEADER` is sent in the `anthropic_beta` additional field.
    pub allow_extended_context: bool,
}
211
/// Optional metadata carried on a [`Request`].
// NOTE(review): `stream_completion` does not read `Request::metadata`, so this is not sent to
// the service by the code in view — confirm intended use.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    // Presumably an identifier for the end user making the request — TODO confirm.
    pub user_id: Option<String>,
}
216
/// Errors returned by `stream_completion` when a request cannot be started.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// The request was rejected as invalid, e.g. a `ValidationException` from the service.
    /// Carries the service message, or `"validation error"` when none was provided.
    #[error("{0}")]
    Validation(String),
    /// The service reported a `ThrottlingException`.
    #[error("rate limited")]
    RateLimited,
    /// The service reported a `ServiceUnavailableException` or a `ModelNotReadyException`.
    #[error("service unavailable")]
    ServiceUnavailable,
    /// The service reported an `AccessDeniedException`.
    /// Carries the service message, or `"access denied"` when none was provided.
    #[error("{0}")]
    AccessDenied(String),
    /// The service reported an `InternalServerException`.
    /// Carries the service message, or `"internal server error"` when none was provided.
    #[error("{0}")]
    InternalServer(String),
    /// Any other service error, or an SDK error that is not a service error; displayed
    /// transparently as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}