bedrock.rs

  1mod models;
  2
  3use std::collections::HashMap;
  4use std::pin::Pin;
  5
  6use anyhow::{Error, Result, anyhow};
  7use aws_sdk_bedrockruntime as bedrock;
  8pub use aws_sdk_bedrockruntime as bedrock_client;
  9pub use aws_sdk_bedrockruntime::types::{
 10    AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
 11    Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
 12    ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
 13};
 14use aws_smithy_types::{Document, Number as AwsNumber};
 15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 16pub use bedrock::types::{
 17    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 18    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 19    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
 20    ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
 21    ToolResultContentBlock as BedrockToolResultContentBlock,
 22    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 23};
 24use futures::stream::{self, BoxStream, Stream};
 25use serde::{Deserialize, Serialize};
 26use serde_json::{Number, Value};
 27use thiserror::Error;
 28
 29pub use crate::models::*;
 30
 31pub async fn stream_completion(
 32    client: bedrock::Client,
 33    request: Request,
 34    handle: tokio::runtime::Handle,
 35) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 36    handle
 37        .spawn(async move {
 38            let mut response = bedrock::Client::converse_stream(&client)
 39                .model_id(request.model.clone())
 40                .set_messages(request.messages.into());
 41
 42            if let Some(Thinking::Enabled {
 43                budget_tokens: Some(budget_tokens),
 44            }) = request.thinking
 45            {
 46                response =
 47                    response.additional_model_request_fields(Document::Object(HashMap::from([(
 48                        "thinking".to_string(),
 49                        Document::from(HashMap::from([
 50                            ("type".to_string(), Document::String("enabled".to_string())),
 51                            (
 52                                "budget_tokens".to_string(),
 53                                Document::Number(AwsNumber::PosInt(budget_tokens)),
 54                            ),
 55                        ])),
 56                    )])));
 57            }
 58
 59            if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
 60                response = response.set_tool_config(request.tools);
 61            }
 62
 63            let response = response.send().await;
 64
 65            match response {
 66                Ok(output) => {
 67                    let stream: Pin<
 68                        Box<
 69                            dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
 70                                + Send,
 71                        >,
 72                    > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
 73                        match stream.recv().await {
 74                            Ok(Some(output)) => Some(({ Ok(output) }, stream)),
 75                            Ok(None) => None,
 76                            Err(err) => {
 77                                Some((
 78                                    // TODO: Figure out how we can capture Throttling Exceptions
 79                                    Err(BedrockError::ClientError(anyhow!(
 80                                        "{:?}",
 81                                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 82                                    ))),
 83                                    stream,
 84                                ))
 85                            }
 86                        }
 87                    }));
 88                    Ok(stream)
 89                }
 90                Err(err) => Err(anyhow!(
 91                    "{:?}",
 92                    aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 93                )),
 94            }
 95        })
 96        .await
 97        .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
 98}
 99
100pub fn aws_document_to_value(document: &Document) -> Value {
101    match document {
102        Document::Null => Value::Null,
103        Document::Bool(value) => Value::Bool(*value),
104        Document::Number(value) => match *value {
105            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
106            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
107            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
108        },
109        Document::String(value) => Value::String(value.clone()),
110        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
111        Document::Object(map) => Value::Object(
112            map.iter()
113                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
114                .collect(),
115        ),
116    }
117}
118
119pub fn value_to_aws_document(value: &Value) -> Document {
120    match value {
121        Value::Null => Document::Null,
122        Value::Bool(value) => Document::Bool(*value),
123        Value::Number(value) => {
124            if let Some(value) = value.as_u64() {
125                Document::Number(AwsNumber::PosInt(value))
126            } else if let Some(value) = value.as_i64() {
127                Document::Number(AwsNumber::NegInt(value))
128            } else if let Some(value) = value.as_f64() {
129                Document::Number(AwsNumber::Float(value))
130            } else {
131                Document::Null
132            }
133        }
134        Value::String(value) => Document::String(value.clone()),
135        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
136        Value::Object(map) => Document::Object(
137            map.iter()
138                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
139                .collect(),
140        ),
141    }
142}
143
/// Thinking configuration carried on a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
    /// Thinking is enabled. `stream_completion` only sends a `thinking` field
    /// to the model when `budget_tokens` is `Some`; with `None` nothing is added.
    Enabled { budget_tokens: Option<u64> },
}
148
/// Parameters for a Bedrock completion request.
///
/// NOTE(review): `stream_completion` in this file reads only `model`,
/// `messages`, `tools` and `thinking`; the other fields are not read by it.
/// Confirm whether they are consumed elsewhere.
#[derive(Debug)]
pub struct Request {
    /// Passed to the SDK builder as the model id.
    pub model: String,
    pub max_tokens: u32,
    /// Conversation messages sent with the request.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; `stream_completion` attaches it only when it
    /// contains at least one tool.
    pub tools: Option<BedrockToolConfig>,
    /// Optional thinking configuration (see [`Thinking`]).
    pub thinking: Option<Thinking>,
    pub system: Option<String>,
    pub metadata: Option<Metadata>,
    pub stop_sequences: Vec<String>,
    pub temperature: Option<f32>,
    pub top_k: Option<u32>,
    pub top_p: Option<f32>,
}
163
/// Optional metadata that can be attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Optional user identifier — presumably identifies the end user making
    /// the request; TODO confirm against callers.
    pub user_id: Option<String>,
}
168
/// Errors yielded as items of the stream returned by `stream_completion`.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// Produced by `stream_completion` when receiving an event from the
    /// response stream fails.
    #[error("client error: {0}")]
    ClientError(anyhow::Error),
    /// Not constructed anywhere in this file — presumably raised by callers;
    /// TODO confirm.
    #[error("extension error: {0}")]
    ExtensionError(anyhow::Error),
    /// Any other `anyhow::Error`; displayed transparently and convertible via `From`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}