1mod models;
2
3use anyhow::{Context, Error, Result, anyhow};
4use aws_sdk_bedrockruntime as bedrock;
5pub use aws_sdk_bedrockruntime as bedrock_client;
6pub use aws_sdk_bedrockruntime::types::{
7 AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
8 ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
9 ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
10 ToolSpecification as BedrockToolSpec,
11};
12pub use aws_smithy_types::Blob as BedrockBlob;
13use aws_smithy_types::{Document, Number as AwsNumber};
14pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
15pub use bedrock::types::{
16 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
17 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
18 ImageBlock as BedrockImageBlock, Message as BedrockMessage,
19 ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
20 ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
21 ToolResultContentBlock as BedrockToolResultContentBlock,
22 ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
23};
24use futures::stream::{self, BoxStream};
25use serde::{Deserialize, Serialize};
26use serde_json::{Number, Value};
27use std::collections::HashMap;
28use thiserror::Error;
29
30pub use crate::models::*;
31
32pub async fn stream_completion(
33 client: bedrock::Client,
34 request: Request,
35) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
36 let mut response = bedrock::Client::converse_stream(&client)
37 .model_id(request.model.clone())
38 .set_messages(request.messages.into());
39
40 if let Some(Thinking::Enabled {
41 budget_tokens: Some(budget_tokens),
42 }) = request.thinking
43 {
44 let thinking_config = HashMap::from([
45 ("type".to_string(), Document::String("enabled".to_string())),
46 (
47 "budget_tokens".to_string(),
48 Document::Number(AwsNumber::PosInt(budget_tokens)),
49 ),
50 ]);
51 response = response.additional_model_request_fields(Document::Object(HashMap::from([(
52 "thinking".to_string(),
53 Document::from(thinking_config),
54 )])));
55 }
56
57 if request
58 .tools
59 .as_ref()
60 .map_or(false, |t| !t.tools.is_empty())
61 {
62 response = response.set_tool_config(request.tools);
63 }
64
65 let output = response
66 .send()
67 .await
68 .context("Failed to send API request to Bedrock");
69
70 let stream = Box::pin(stream::unfold(
71 output?.stream,
72 move |mut stream| async move {
73 match stream.recv().await {
74 Ok(Some(output)) => Some((Ok(output), stream)),
75 Ok(None) => None,
76 Err(err) => Some((
77 Err(BedrockError::ClientError(anyhow!(
78 "{:?}",
79 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
80 ))),
81 stream,
82 )),
83 }
84 },
85 ));
86
87 Ok(stream)
88}
89
90pub fn aws_document_to_value(document: &Document) -> Value {
91 match document {
92 Document::Null => Value::Null,
93 Document::Bool(value) => Value::Bool(*value),
94 Document::Number(value) => match *value {
95 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
96 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
97 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
98 },
99 Document::String(value) => Value::String(value.clone()),
100 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
101 Document::Object(map) => Value::Object(
102 map.iter()
103 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
104 .collect(),
105 ),
106 }
107}
108
109pub fn value_to_aws_document(value: &Value) -> Document {
110 match value {
111 Value::Null => Document::Null,
112 Value::Bool(value) => Document::Bool(*value),
113 Value::Number(value) => {
114 if let Some(value) = value.as_u64() {
115 Document::Number(AwsNumber::PosInt(value))
116 } else if let Some(value) = value.as_i64() {
117 Document::Number(AwsNumber::NegInt(value))
118 } else if let Some(value) = value.as_f64() {
119 Document::Number(AwsNumber::Float(value))
120 } else {
121 Document::Null
122 }
123 }
124 Value::String(value) => Document::String(value.clone()),
125 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
126 Value::Object(map) => Document::Object(
127 map.iter()
128 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
129 .collect(),
130 ),
131 }
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135pub enum Thinking {
136 Enabled { budget_tokens: Option<u64> },
137}
138
139#[derive(Debug)]
140pub struct Request {
141 pub model: String,
142 pub max_tokens: u64,
143 pub messages: Vec<BedrockMessage>,
144 pub tools: Option<BedrockToolConfig>,
145 pub thinking: Option<Thinking>,
146 pub system: Option<String>,
147 pub metadata: Option<Metadata>,
148 pub stop_sequences: Vec<String>,
149 pub temperature: Option<f32>,
150 pub top_k: Option<u32>,
151 pub top_p: Option<f32>,
152}
153
154#[derive(Debug, Serialize, Deserialize)]
155pub struct Metadata {
156 pub user_id: Option<String>,
157}
158
159#[derive(Error, Debug)]
160pub enum BedrockError {
161 #[error("client error: {0}")]
162 ClientError(anyhow::Error),
163 #[error("extension error: {0}")]
164 ExtensionError(anyhow::Error),
165 #[error(transparent)]
166 Other(#[from] anyhow::Error),
167}