1mod models;
2
3use anyhow::{Context, Error, Result, anyhow};
4use aws_sdk_bedrockruntime as bedrock;
5pub use aws_sdk_bedrockruntime as bedrock_client;
6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
7pub use aws_sdk_bedrockruntime::types::{
8 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
9 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
10 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
11 ToolSpecification as BedrockToolSpec,
12};
13pub use aws_smithy_types::Blob as BedrockBlob;
14use aws_smithy_types::{Document, Number as AwsNumber};
15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
16pub use bedrock::types::{
17 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
18 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
19 ImageBlock as BedrockImageBlock, Message as BedrockMessage,
20 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
21 ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
22 ToolResultBlock as BedrockToolResultBlock,
23 ToolResultContentBlock as BedrockToolResultContentBlock,
24 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
25};
26use futures::stream::{self, BoxStream};
27use serde::{Deserialize, Serialize};
28use serde_json::{Number, Value};
29use std::collections::HashMap;
30use thiserror::Error;
31
32pub use crate::models::*;
33
34pub async fn stream_completion(
35 client: bedrock::Client,
36 request: Request,
37) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
38 let mut response = bedrock::Client::converse_stream(&client)
39 .model_id(request.model.clone())
40 .set_messages(request.messages.into());
41
42 if let Some(Thinking::Enabled {
43 budget_tokens: Some(budget_tokens),
44 }) = request.thinking
45 {
46 let thinking_config = HashMap::from([
47 ("type".to_string(), Document::String("enabled".to_string())),
48 (
49 "budget_tokens".to_string(),
50 Document::Number(AwsNumber::PosInt(budget_tokens)),
51 ),
52 ]);
53 response = response.additional_model_request_fields(Document::Object(HashMap::from([(
54 "thinking".to_string(),
55 Document::from(thinking_config),
56 )])));
57 }
58
59 if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
60 response = response.set_tool_config(request.tools);
61 }
62
63 let inference_config = InferenceConfiguration::builder()
64 .max_tokens(request.max_tokens as i32)
65 .set_temperature(request.temperature)
66 .set_top_p(request.top_p)
67 .build();
68
69 response = response.inference_config(inference_config);
70
71 if let Some(system) = request.system {
72 if !system.is_empty() {
73 response = response.system(BedrockSystemContentBlock::Text(system));
74 }
75 }
76
77 let output = response
78 .send()
79 .await
80 .context("Failed to send API request to Bedrock");
81
82 let stream = Box::pin(stream::unfold(
83 output?.stream,
84 move |mut stream| async move {
85 match stream.recv().await {
86 Ok(Some(output)) => Some((Ok(output), stream)),
87 Ok(None) => None,
88 Err(err) => Some((
89 Err(BedrockError::ClientError(anyhow!(
90 "{:?}",
91 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
92 ))),
93 stream,
94 )),
95 }
96 },
97 ));
98
99 Ok(stream)
100}
101
102pub fn aws_document_to_value(document: &Document) -> Value {
103 match document {
104 Document::Null => Value::Null,
105 Document::Bool(value) => Value::Bool(*value),
106 Document::Number(value) => match *value {
107 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
108 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
109 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
110 },
111 Document::String(value) => Value::String(value.clone()),
112 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
113 Document::Object(map) => Value::Object(
114 map.iter()
115 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
116 .collect(),
117 ),
118 }
119}
120
121pub fn value_to_aws_document(value: &Value) -> Document {
122 match value {
123 Value::Null => Document::Null,
124 Value::Bool(value) => Document::Bool(*value),
125 Value::Number(value) => {
126 if let Some(value) = value.as_u64() {
127 Document::Number(AwsNumber::PosInt(value))
128 } else if let Some(value) = value.as_i64() {
129 Document::Number(AwsNumber::NegInt(value))
130 } else if let Some(value) = value.as_f64() {
131 Document::Number(AwsNumber::Float(value))
132 } else {
133 Document::Null
134 }
135 }
136 Value::String(value) => Document::String(value.clone()),
137 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
138 Value::Object(map) => Document::Object(
139 map.iter()
140 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
141 .collect(),
142 ),
143 }
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147pub enum Thinking {
148 Enabled { budget_tokens: Option<u64> },
149}
150
151#[derive(Debug)]
152pub struct Request {
153 pub model: String,
154 pub max_tokens: u64,
155 pub messages: Vec<BedrockMessage>,
156 pub tools: Option<BedrockToolConfig>,
157 pub thinking: Option<Thinking>,
158 pub system: Option<String>,
159 pub metadata: Option<Metadata>,
160 pub stop_sequences: Vec<String>,
161 pub temperature: Option<f32>,
162 pub top_k: Option<u32>,
163 pub top_p: Option<f32>,
164}
165
166#[derive(Debug, Serialize, Deserialize)]
167pub struct Metadata {
168 pub user_id: Option<String>,
169}
170
171#[derive(Error, Debug)]
172pub enum BedrockError {
173 #[error("client error: {0}")]
174 ClientError(anyhow::Error),
175 #[error("extension error: {0}")]
176 ExtensionError(anyhow::Error),
177 #[error(transparent)]
178 Other(#[from] anyhow::Error),
179}