1mod models;
2
3use std::collections::HashMap;
4use std::pin::Pin;
5
6use anyhow::{Context as _, Error, Result, anyhow};
7use aws_sdk_bedrockruntime as bedrock;
8pub use aws_sdk_bedrockruntime as bedrock_client;
9pub use aws_sdk_bedrockruntime::types::{
10 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
11 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
12 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
13 ToolSpecification as BedrockToolSpec,
14};
15pub use aws_smithy_types::Blob as BedrockBlob;
16use aws_smithy_types::{Document, Number as AwsNumber};
17pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
18pub use bedrock::types::{
19 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
20 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
21 ImageBlock as BedrockImageBlock, Message as BedrockMessage,
22 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
23 ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
24 ToolResultContentBlock as BedrockToolResultContentBlock,
25 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
26};
27use futures::stream::{self, BoxStream, Stream};
28use serde::{Deserialize, Serialize};
29use serde_json::{Number, Value};
30use thiserror::Error;
31
32pub use crate::models::*;
33
34pub async fn stream_completion(
35 client: bedrock::Client,
36 request: Request,
37 handle: tokio::runtime::Handle,
38) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
39 handle
40 .spawn(async move {
41 let mut response = bedrock::Client::converse_stream(&client)
42 .model_id(request.model.clone())
43 .set_messages(request.messages.into());
44
45 if let Some(Thinking::Enabled {
46 budget_tokens: Some(budget_tokens),
47 }) = request.thinking
48 {
49 response =
50 response.additional_model_request_fields(Document::Object(HashMap::from([(
51 "thinking".to_string(),
52 Document::from(HashMap::from([
53 ("type".to_string(), Document::String("enabled".to_string())),
54 (
55 "budget_tokens".to_string(),
56 Document::Number(AwsNumber::PosInt(budget_tokens)),
57 ),
58 ])),
59 )])));
60 }
61
62 if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
63 response = response.set_tool_config(request.tools);
64 }
65
66 let response = response.send().await;
67
68 match response {
69 Ok(output) => {
70 let stream: Pin<
71 Box<
72 dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
73 + Send,
74 >,
75 > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
76 match stream.recv().await {
77 Ok(Some(output)) => Some(({ Ok(output) }, stream)),
78 Ok(None) => None,
79 Err(err) => {
80 Some((
81 // TODO: Figure out how we can capture Throttling Exceptions
82 Err(BedrockError::ClientError(anyhow!(
83 "{:?}",
84 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
85 ))),
86 stream,
87 ))
88 }
89 }
90 }));
91 Ok(stream)
92 }
93 Err(err) => Err(anyhow!(
94 "{:?}",
95 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
96 )),
97 }
98 })
99 .await
100 .context("spawning a task")?
101}
102
103pub fn aws_document_to_value(document: &Document) -> Value {
104 match document {
105 Document::Null => Value::Null,
106 Document::Bool(value) => Value::Bool(*value),
107 Document::Number(value) => match *value {
108 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
109 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
110 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
111 },
112 Document::String(value) => Value::String(value.clone()),
113 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
114 Document::Object(map) => Value::Object(
115 map.iter()
116 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
117 .collect(),
118 ),
119 }
120}
121
122pub fn value_to_aws_document(value: &Value) -> Document {
123 match value {
124 Value::Null => Document::Null,
125 Value::Bool(value) => Document::Bool(*value),
126 Value::Number(value) => {
127 if let Some(value) = value.as_u64() {
128 Document::Number(AwsNumber::PosInt(value))
129 } else if let Some(value) = value.as_i64() {
130 Document::Number(AwsNumber::NegInt(value))
131 } else if let Some(value) = value.as_f64() {
132 Document::Number(AwsNumber::Float(value))
133 } else {
134 Document::Null
135 }
136 }
137 Value::String(value) => Document::String(value.clone()),
138 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
139 Value::Object(map) => Document::Object(
140 map.iter()
141 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
142 .collect(),
143 ),
144 }
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148pub enum Thinking {
149 Enabled { budget_tokens: Option<u64> },
150}
151
/// Inputs for a single `stream_completion` call.
#[derive(Debug)]
pub struct Request {
    /// Model identifier, sent as the request's `model_id`.
    pub model: String,
    /// Output-token limit. NOTE(review): confirm how `stream_completion` forwards this.
    pub max_tokens: u32,
    /// Conversation history, sent as the request's messages.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; only attached when it lists at least one tool.
    pub tools: Option<BedrockToolConfig>,
    /// Extended-thinking settings; only sent when a token budget is given.
    pub thinking: Option<Thinking>,
    /// System prompt text. NOTE(review): confirm how `stream_completion` forwards this.
    pub system: Option<String>,
    /// Caller metadata. NOTE(review): not forwarded by `stream_completion` — confirm intent.
    pub metadata: Option<Metadata>,
    /// Stop sequences. NOTE(review): confirm how `stream_completion` forwards these.
    pub stop_sequences: Vec<String>,
    /// Sampling temperature. NOTE(review): confirm how `stream_completion` forwards this.
    pub temperature: Option<f32>,
    /// Top-k sampling value. NOTE(review): not forwarded by `stream_completion` — confirm intent.
    pub top_k: Option<u32>,
    /// Nucleus-sampling value. NOTE(review): confirm how `stream_completion` forwards this.
    pub top_p: Option<f32>,
}
166
/// Caller-supplied request metadata.
///
/// NOTE(review): `stream_completion` does not send this to Bedrock — confirm where it is used.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Optional identifier for the end user.
    pub user_id: Option<String>,
}
171
/// Errors yielded while streaming a Bedrock completion.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// An error reported by the Bedrock client, e.g. while receiving stream events.
    #[error("client error: {0}")]
    ClientError(anyhow::Error),
    /// NOTE(review): not constructed in this file; presumably raised by a caller/extension
    /// layer — confirm.
    #[error("extension error: {0}")]
    ExtensionError(anyhow::Error),
    /// Any other error. `anyhow::Error` converts into this variant via `From`, and its message
    /// is displayed unchanged.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}