1mod models;
2
3use std::collections::HashMap;
4use std::pin::Pin;
5
6use anyhow::{Error, Result, anyhow};
7use aws_sdk_bedrockruntime as bedrock;
8pub use aws_sdk_bedrockruntime as bedrock_client;
9pub use aws_sdk_bedrockruntime::types::{
10 AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
11 Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
12 ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
13};
14use aws_smithy_types::{Document, Number as AwsNumber};
15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
16pub use bedrock::types::{
17 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
18 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
19 ImageBlock as BedrockImageBlock, Message as BedrockMessage,
20 ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
21 ToolResultContentBlock as BedrockToolResultContentBlock,
22 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
23};
24use futures::stream::{self, BoxStream, Stream};
25use serde::{Deserialize, Serialize};
26use serde_json::{Number, Value};
27use thiserror::Error;
28
29pub use crate::models::*;
30
31pub async fn stream_completion(
32 client: bedrock::Client,
33 request: Request,
34 handle: tokio::runtime::Handle,
35) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
36 handle
37 .spawn(async move {
38 let mut response = bedrock::Client::converse_stream(&client)
39 .model_id(request.model.clone())
40 .set_messages(request.messages.into());
41
42 if let Some(Thinking::Enabled {
43 budget_tokens: Some(budget_tokens),
44 }) = request.thinking
45 {
46 response =
47 response.additional_model_request_fields(Document::Object(HashMap::from([(
48 "thinking".to_string(),
49 Document::from(HashMap::from([
50 ("type".to_string(), Document::String("enabled".to_string())),
51 (
52 "budget_tokens".to_string(),
53 Document::Number(AwsNumber::PosInt(budget_tokens)),
54 ),
55 ])),
56 )])));
57 }
58
59 if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
60 response = response.set_tool_config(request.tools);
61 }
62
63 let response = response.send().await;
64
65 match response {
66 Ok(output) => {
67 let stream: Pin<
68 Box<
69 dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
70 + Send,
71 >,
72 > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
73 match stream.recv().await {
74 Ok(Some(output)) => Some(({ Ok(output) }, stream)),
75 Ok(None) => None,
76 Err(err) => {
77 Some((
78 // TODO: Figure out how we can capture Throttling Exceptions
79 Err(BedrockError::ClientError(anyhow!(
80 "{:?}",
81 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
82 ))),
83 stream,
84 ))
85 }
86 }
87 }));
88 Ok(stream)
89 }
90 Err(err) => Err(anyhow!(
91 "{:?}",
92 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
93 )),
94 }
95 })
96 .await
97 .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
98}
99
100pub fn aws_document_to_value(document: &Document) -> Value {
101 match document {
102 Document::Null => Value::Null,
103 Document::Bool(value) => Value::Bool(*value),
104 Document::Number(value) => match *value {
105 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
106 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
107 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
108 },
109 Document::String(value) => Value::String(value.clone()),
110 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
111 Document::Object(map) => Value::Object(
112 map.iter()
113 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
114 .collect(),
115 ),
116 }
117}
118
119pub fn value_to_aws_document(value: &Value) -> Document {
120 match value {
121 Value::Null => Document::Null,
122 Value::Bool(value) => Document::Bool(*value),
123 Value::Number(value) => {
124 if let Some(value) = value.as_u64() {
125 Document::Number(AwsNumber::PosInt(value))
126 } else if let Some(value) = value.as_i64() {
127 Document::Number(AwsNumber::NegInt(value))
128 } else if let Some(value) = value.as_f64() {
129 Document::Number(AwsNumber::Float(value))
130 } else {
131 Document::Null
132 }
133 }
134 Value::String(value) => Document::String(value.clone()),
135 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
136 Value::Object(map) => Document::Object(
137 map.iter()
138 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
139 .collect(),
140 ),
141 }
142}
143
144#[derive(Debug, Serialize, Deserialize)]
145pub enum Thinking {
146 Enabled { budget_tokens: Option<u64> },
147}
148
149#[derive(Debug)]
150pub struct Request {
151 pub model: String,
152 pub max_tokens: u32,
153 pub messages: Vec<BedrockMessage>,
154 pub tools: Option<BedrockToolConfig>,
155 pub thinking: Option<Thinking>,
156 pub system: Option<String>,
157 pub metadata: Option<Metadata>,
158 pub stop_sequences: Vec<String>,
159 pub temperature: Option<f32>,
160 pub top_k: Option<u32>,
161 pub top_p: Option<f32>,
162}
163
164#[derive(Debug, Serialize, Deserialize)]
165pub struct Metadata {
166 pub user_id: Option<String>,
167}
168
169#[derive(Error, Debug)]
170pub enum BedrockError {
171 #[error("client error: {0}")]
172 ClientError(anyhow::Error),
173 #[error("extension error: {0}")]
174 ExtensionError(anyhow::Error),
175 #[error(transparent)]
176 Other(#[from] anyhow::Error),
177}