1mod models;
2
3use anyhow::{Context, Error, Result, anyhow};
4use aws_sdk_bedrockruntime as bedrock;
5pub use aws_sdk_bedrockruntime as bedrock_client;
6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
7pub use aws_sdk_bedrockruntime::types::{
8 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
9 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
10 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
11 ToolSpecification as BedrockToolSpec,
12};
13pub use aws_smithy_types::Blob as BedrockBlob;
14use aws_smithy_types::{Document, Number as AwsNumber};
15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
16pub use bedrock::types::{
17 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
18 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
19 ImageBlock as BedrockImageBlock, Message as BedrockMessage,
20 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
21 ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
22 ToolResultBlock as BedrockToolResultBlock,
23 ToolResultContentBlock as BedrockToolResultContentBlock,
24 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
25};
26use futures::stream::{self, BoxStream};
27use serde::{Deserialize, Serialize};
28use serde_json::{Number, Value};
29use std::collections::HashMap;
30use thiserror::Error;
31
32pub use crate::models::*;
33
34pub async fn stream_completion(
35 client: bedrock::Client,
36 request: Request,
37) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
38 let mut response = bedrock::Client::converse_stream(&client)
39 .model_id(request.model.clone())
40 .set_messages(request.messages.into());
41
42 match request.thinking {
43 Some(Thinking::Enabled {
44 budget_tokens: Some(budget_tokens),
45 }) => {
46 let thinking_config = HashMap::from([
47 ("type".to_string(), Document::String("enabled".to_string())),
48 (
49 "budget_tokens".to_string(),
50 Document::Number(AwsNumber::PosInt(budget_tokens)),
51 ),
52 ]);
53 response =
54 response.additional_model_request_fields(Document::Object(HashMap::from([(
55 "thinking".to_string(),
56 Document::from(thinking_config),
57 )])));
58 }
59 Some(Thinking::Adaptive { effort: _ }) => {
60 let thinking_config =
61 HashMap::from([("type".to_string(), Document::String("adaptive".to_string()))]);
62 response =
63 response.additional_model_request_fields(Document::Object(HashMap::from([(
64 "thinking".to_string(),
65 Document::from(thinking_config),
66 )])));
67 }
68 _ => {}
69 }
70
71 if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
72 response = response.set_tool_config(request.tools);
73 }
74
75 let inference_config = InferenceConfiguration::builder()
76 .max_tokens(request.max_tokens as i32)
77 .set_temperature(request.temperature)
78 .set_top_p(request.top_p)
79 .build();
80
81 response = response.inference_config(inference_config);
82
83 if let Some(system) = request.system {
84 if !system.is_empty() {
85 response = response.system(BedrockSystemContentBlock::Text(system));
86 }
87 }
88
89 let output = response
90 .send()
91 .await
92 .context("Failed to send API request to Bedrock");
93
94 let stream = Box::pin(stream::unfold(
95 output?.stream,
96 move |mut stream| async move {
97 match stream.recv().await {
98 Ok(Some(output)) => Some((Ok(output), stream)),
99 Ok(None) => None,
100 Err(err) => Some((
101 Err(BedrockError::ClientError(anyhow!(
102 "{}",
103 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
104 ))),
105 stream,
106 )),
107 }
108 },
109 ));
110
111 Ok(stream)
112}
113
114pub fn aws_document_to_value(document: &Document) -> Value {
115 match document {
116 Document::Null => Value::Null,
117 Document::Bool(value) => Value::Bool(*value),
118 Document::Number(value) => match *value {
119 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
120 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
121 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
122 },
123 Document::String(value) => Value::String(value.clone()),
124 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
125 Document::Object(map) => Value::Object(
126 map.iter()
127 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
128 .collect(),
129 ),
130 }
131}
132
133pub fn value_to_aws_document(value: &Value) -> Document {
134 match value {
135 Value::Null => Document::Null,
136 Value::Bool(value) => Document::Bool(*value),
137 Value::Number(value) => {
138 if let Some(value) = value.as_u64() {
139 Document::Number(AwsNumber::PosInt(value))
140 } else if let Some(value) = value.as_i64() {
141 Document::Number(AwsNumber::NegInt(value))
142 } else if let Some(value) = value.as_f64() {
143 Document::Number(AwsNumber::Float(value))
144 } else {
145 Document::Null
146 }
147 }
148 Value::String(value) => Document::String(value.clone()),
149 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
150 Value::Object(map) => Document::Object(
151 map.iter()
152 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
153 .collect(),
154 ),
155 }
156}
157
158#[derive(Debug, Serialize, Deserialize)]
159pub enum Thinking {
160 Enabled {
161 budget_tokens: Option<u64>,
162 },
163 Adaptive {
164 effort: BedrockAdaptiveThinkingEffort,
165 },
166}
167
168#[derive(Debug)]
169pub struct Request {
170 pub model: String,
171 pub max_tokens: u64,
172 pub messages: Vec<BedrockMessage>,
173 pub tools: Option<BedrockToolConfig>,
174 pub thinking: Option<Thinking>,
175 pub system: Option<String>,
176 pub metadata: Option<Metadata>,
177 pub stop_sequences: Vec<String>,
178 pub temperature: Option<f32>,
179 pub top_k: Option<u32>,
180 pub top_p: Option<f32>,
181}
182
183#[derive(Debug, Serialize, Deserialize)]
184pub struct Metadata {
185 pub user_id: Option<String>,
186}
187
188#[derive(Error, Debug)]
189pub enum BedrockError {
190 #[error("client error: {0}")]
191 ClientError(anyhow::Error),
192 #[error("extension error: {0}")]
193 ExtensionError(anyhow::Error),
194 #[error(transparent)]
195 Other(#[from] anyhow::Error),
196}