1mod models;
2
3use anyhow::{Context, Error, Result, anyhow};
4use aws_sdk_bedrockruntime as bedrock;
5pub use aws_sdk_bedrockruntime as bedrock_client;
6pub use aws_sdk_bedrockruntime::types::{
7 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
8 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
9 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
10 ToolSpecification as BedrockToolSpec,
11};
12pub use aws_smithy_types::Blob as BedrockBlob;
13use aws_smithy_types::{Document, Number as AwsNumber};
14pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
15pub use bedrock::types::{
16 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
17 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
18 ImageBlock as BedrockImageBlock, Message as BedrockMessage,
19 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
20 ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
21 ToolResultContentBlock as BedrockToolResultContentBlock,
22 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
23};
24use futures::stream::{self, BoxStream};
25use serde::{Deserialize, Serialize};
26use serde_json::{Number, Value};
27use std::collections::HashMap;
28use thiserror::Error;
29
30pub use crate::models::*;
31
32pub async fn stream_completion(
33 client: bedrock::Client,
34 request: Request,
35) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
36 let mut response = bedrock::Client::converse_stream(&client)
37 .model_id(request.model.clone())
38 .set_messages(request.messages.into());
39
40 if let Some(Thinking::Enabled {
41 budget_tokens: Some(budget_tokens),
42 }) = request.thinking
43 {
44 let thinking_config = HashMap::from([
45 ("type".to_string(), Document::String("enabled".to_string())),
46 (
47 "budget_tokens".to_string(),
48 Document::Number(AwsNumber::PosInt(budget_tokens)),
49 ),
50 ]);
51 response = response.additional_model_request_fields(Document::Object(HashMap::from([(
52 "thinking".to_string(),
53 Document::from(thinking_config),
54 )])));
55 }
56
57 if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
58 response = response.set_tool_config(request.tools);
59 }
60
61 let output = response
62 .send()
63 .await
64 .context("Failed to send API request to Bedrock");
65
66 let stream = Box::pin(stream::unfold(
67 output?.stream,
68 move |mut stream| async move {
69 match stream.recv().await {
70 Ok(Some(output)) => Some((Ok(output), stream)),
71 Ok(None) => None,
72 Err(err) => Some((
73 Err(BedrockError::ClientError(anyhow!(
74 "{:?}",
75 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
76 ))),
77 stream,
78 )),
79 }
80 },
81 ));
82
83 Ok(stream)
84}
85
86pub fn aws_document_to_value(document: &Document) -> Value {
87 match document {
88 Document::Null => Value::Null,
89 Document::Bool(value) => Value::Bool(*value),
90 Document::Number(value) => match *value {
91 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
92 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
93 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
94 },
95 Document::String(value) => Value::String(value.clone()),
96 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
97 Document::Object(map) => Value::Object(
98 map.iter()
99 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
100 .collect(),
101 ),
102 }
103}
104
105pub fn value_to_aws_document(value: &Value) -> Document {
106 match value {
107 Value::Null => Document::Null,
108 Value::Bool(value) => Document::Bool(*value),
109 Value::Number(value) => {
110 if let Some(value) = value.as_u64() {
111 Document::Number(AwsNumber::PosInt(value))
112 } else if let Some(value) = value.as_i64() {
113 Document::Number(AwsNumber::NegInt(value))
114 } else if let Some(value) = value.as_f64() {
115 Document::Number(AwsNumber::Float(value))
116 } else {
117 Document::Null
118 }
119 }
120 Value::String(value) => Document::String(value.clone()),
121 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
122 Value::Object(map) => Document::Object(
123 map.iter()
124 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
125 .collect(),
126 ),
127 }
128}
129
/// Extended-thinking setting carried on a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
 /// Thinking is enabled. `stream_completion` only adds the `thinking` request
 /// fields when `budget_tokens` is `Some`; with `None` nothing extra is sent.
 Enabled { budget_tokens: Option<u64> },
}
134
/// Parameters for a completion request.
///
/// NOTE(review): `stream_completion` as written in this file reads only `model`,
/// `messages`, `tools`, and `thinking`; the remaining fields are not forwarded by
/// any code visible here — confirm whether that is intentional.
#[derive(Debug)]
pub struct Request {
 /// Passed to the request builder as the model id.
 pub model: String,
 pub max_tokens: u64,
 /// Conversation messages sent with the request.
 pub messages: Vec<BedrockMessage>,
 /// Tool configuration; attached only when it contains at least one tool.
 pub tools: Option<BedrockToolConfig>,
 /// Extended-thinking setting; see [`Thinking`].
 pub thinking: Option<Thinking>,
 pub system: Option<String>,
 pub metadata: Option<Metadata>,
 pub stop_sequences: Vec<String>,
 pub temperature: Option<f32>,
 pub top_k: Option<u32>,
 pub top_p: Option<f32>,
}
149
/// Optional metadata attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
 /// Optional user identifier. Not read by any code visible in this file.
 pub user_id: Option<String>,
}
154
/// Errors surfaced by this crate's Bedrock helpers.
#[derive(Error, Debug)]
pub enum BedrockError {
 /// Produced by `stream_completion` when receiving the next stream event fails;
 /// the wrapped error holds the debug-formatted SDK error context.
 #[error("client error: {0}")]
 ClientError(anyhow::Error),
 /// Not constructed by any code visible in this file.
 #[error("extension error: {0}")]
 ExtensionError(anyhow::Error),
 /// Any other `anyhow::Error`; displayed transparently and convertible via `From`/`?`.
 #[error(transparent)]
 Other(#[from] anyhow::Error),
}