1mod models;
2
3use anyhow::{Context, Error, Result, anyhow};
4use aws_sdk_bedrockruntime as bedrock;
5pub use aws_sdk_bedrockruntime as bedrock_client;
6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
7pub use aws_sdk_bedrockruntime::types::{
8 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
9 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
10 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
11 ToolSpecification as BedrockToolSpec,
12};
13pub use aws_smithy_types::Blob as BedrockBlob;
14use aws_smithy_types::{Document, Number as AwsNumber};
15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
16pub use bedrock::types::{
17 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
18 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
19 ImageBlock as BedrockImageBlock, ImageFormat as BedrockImageFormat,
20 ImageSource as BedrockImageSource, Message as BedrockMessage,
21 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
22 ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
23 ToolResultBlock as BedrockToolResultBlock,
24 ToolResultContentBlock as BedrockToolResultContentBlock,
25 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
26};
27use futures::stream::{self, BoxStream};
28use serde::{Deserialize, Serialize};
29use serde_json::{Number, Value};
30use std::collections::HashMap;
31use thiserror::Error;
32
33pub use crate::models::*;
34
/// Beta flag value that [`stream_completion`] sends under the `anthropic_beta`
/// additional model request field when `Request::allow_extended_context` is set.
pub const CONTEXT_1M_BETA_HEADER: &str = "context-1m-2025-08-07";
36
37pub async fn stream_completion(
38 client: bedrock::Client,
39 request: Request,
40) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
41 let mut response = bedrock::Client::converse_stream(&client)
42 .model_id(request.model.clone())
43 .set_messages(request.messages.into());
44
45 let mut additional_fields: HashMap<String, Document> = HashMap::new();
46
47 match request.thinking {
48 Some(Thinking::Enabled {
49 budget_tokens: Some(budget_tokens),
50 }) => {
51 let thinking_config = HashMap::from([
52 ("type".to_string(), Document::String("enabled".to_string())),
53 (
54 "budget_tokens".to_string(),
55 Document::Number(AwsNumber::PosInt(budget_tokens)),
56 ),
57 ]);
58 additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
59 }
60 Some(Thinking::Adaptive { effort: _ }) => {
61 let thinking_config =
62 HashMap::from([("type".to_string(), Document::String("adaptive".to_string()))]);
63 additional_fields.insert("thinking".to_string(), Document::from(thinking_config));
64 }
65 _ => {}
66 }
67
68 if request.allow_extended_context {
69 additional_fields.insert(
70 "anthropic_beta".to_string(),
71 Document::Array(vec![Document::String(CONTEXT_1M_BETA_HEADER.to_string())]),
72 );
73 }
74
75 if !additional_fields.is_empty() {
76 response = response.additional_model_request_fields(Document::Object(additional_fields));
77 }
78
79 if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
80 response = response.set_tool_config(request.tools);
81 }
82
83 let inference_config = InferenceConfiguration::builder()
84 .max_tokens(request.max_tokens as i32)
85 .set_temperature(request.temperature)
86 .set_top_p(request.top_p)
87 .build();
88
89 response = response.inference_config(inference_config);
90
91 if let Some(system) = request.system {
92 if !system.is_empty() {
93 response = response.system(BedrockSystemContentBlock::Text(system));
94 }
95 }
96
97 let output = response
98 .send()
99 .await
100 .context("Failed to send API request to Bedrock");
101
102 let stream = Box::pin(stream::unfold(
103 output?.stream,
104 move |mut stream| async move {
105 match stream.recv().await {
106 Ok(Some(output)) => Some((Ok(output), stream)),
107 Ok(None) => None,
108 Err(err) => Some((
109 Err(BedrockError::ClientError(anyhow!(
110 "{}",
111 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
112 ))),
113 stream,
114 )),
115 }
116 },
117 ));
118
119 Ok(stream)
120}
121
122pub fn aws_document_to_value(document: &Document) -> Value {
123 match document {
124 Document::Null => Value::Null,
125 Document::Bool(value) => Value::Bool(*value),
126 Document::Number(value) => match *value {
127 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
128 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
129 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
130 },
131 Document::String(value) => Value::String(value.clone()),
132 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
133 Document::Object(map) => Value::Object(
134 map.iter()
135 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
136 .collect(),
137 ),
138 }
139}
140
141pub fn value_to_aws_document(value: &Value) -> Document {
142 match value {
143 Value::Null => Document::Null,
144 Value::Bool(value) => Document::Bool(*value),
145 Value::Number(value) => {
146 if let Some(value) = value.as_u64() {
147 Document::Number(AwsNumber::PosInt(value))
148 } else if let Some(value) = value.as_i64() {
149 Document::Number(AwsNumber::NegInt(value))
150 } else if let Some(value) = value.as_f64() {
151 Document::Number(AwsNumber::Float(value))
152 } else {
153 Document::Null
154 }
155 }
156 Value::String(value) => Document::String(value.clone()),
157 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
158 Value::Object(map) => Document::Object(
159 map.iter()
160 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
161 .collect(),
162 ),
163 }
164}
165
166#[derive(Debug, Serialize, Deserialize)]
167pub enum Thinking {
168 Enabled {
169 budget_tokens: Option<u64>,
170 },
171 Adaptive {
172 effort: BedrockAdaptiveThinkingEffort,
173 },
174}
175
176#[derive(Debug)]
177pub struct Request {
178 pub model: String,
179 pub max_tokens: u64,
180 pub messages: Vec<BedrockMessage>,
181 pub tools: Option<BedrockToolConfig>,
182 pub thinking: Option<Thinking>,
183 pub system: Option<String>,
184 pub metadata: Option<Metadata>,
185 pub stop_sequences: Vec<String>,
186 pub temperature: Option<f32>,
187 pub top_k: Option<u32>,
188 pub top_p: Option<f32>,
189 pub allow_extended_context: bool,
190}
191
192#[derive(Debug, Serialize, Deserialize)]
193pub struct Metadata {
194 pub user_id: Option<String>,
195}
196
197#[derive(Error, Debug)]
198pub enum BedrockError {
199 #[error("client error: {0}")]
200 ClientError(anyhow::Error),
201 #[error("extension error: {0}")]
202 ExtensionError(anyhow::Error),
203 #[error(transparent)]
204 Other(#[from] anyhow::Error),
205}