bedrock.rs

  1mod models;
  2
  3use std::collections::HashMap;
  4use std::pin::Pin;
  5
  6use anyhow::{Error, Result, anyhow};
  7use aws_sdk_bedrockruntime as bedrock;
  8pub use aws_sdk_bedrockruntime as bedrock_client;
  9pub use aws_sdk_bedrockruntime::types::{
 10    AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
 11    Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
 12    ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
 13};
 14pub use aws_smithy_types::Blob as BedrockBlob;
 15use aws_smithy_types::{Document, Number as AwsNumber};
 16pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 17pub use bedrock::types::{
 18    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 19    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 20    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
 21    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 22    ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
 23    ToolResultContentBlock as BedrockToolResultContentBlock,
 24    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 25};
 26use futures::stream::{self, BoxStream, Stream};
 27use serde::{Deserialize, Serialize};
 28use serde_json::{Number, Value};
 29use thiserror::Error;
 30
 31pub use crate::models::*;
 32
 33pub async fn stream_completion(
 34    client: bedrock::Client,
 35    request: Request,
 36    handle: tokio::runtime::Handle,
 37) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 38    handle
 39        .spawn(async move {
 40            let mut response = bedrock::Client::converse_stream(&client)
 41                .model_id(request.model.clone())
 42                .set_messages(request.messages.into());
 43
 44            if let Some(Thinking::Enabled {
 45                budget_tokens: Some(budget_tokens),
 46            }) = request.thinking
 47            {
 48                response =
 49                    response.additional_model_request_fields(Document::Object(HashMap::from([(
 50                        "thinking".to_string(),
 51                        Document::from(HashMap::from([
 52                            ("type".to_string(), Document::String("enabled".to_string())),
 53                            (
 54                                "budget_tokens".to_string(),
 55                                Document::Number(AwsNumber::PosInt(budget_tokens)),
 56                            ),
 57                        ])),
 58                    )])));
 59            }
 60
 61            if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
 62                response = response.set_tool_config(request.tools);
 63            }
 64
 65            let response = response.send().await;
 66
 67            match response {
 68                Ok(output) => {
 69                    let stream: Pin<
 70                        Box<
 71                            dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
 72                                + Send,
 73                        >,
 74                    > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
 75                        match stream.recv().await {
 76                            Ok(Some(output)) => Some(({ Ok(output) }, stream)),
 77                            Ok(None) => None,
 78                            Err(err) => {
 79                                Some((
 80                                    // TODO: Figure out how we can capture Throttling Exceptions
 81                                    Err(BedrockError::ClientError(anyhow!(
 82                                        "{:?}",
 83                                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 84                                    ))),
 85                                    stream,
 86                                ))
 87                            }
 88                        }
 89                    }));
 90                    Ok(stream)
 91                }
 92                Err(err) => Err(anyhow!(
 93                    "{:?}",
 94                    aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 95                )),
 96            }
 97        })
 98        .await
 99        .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
100}
101
102pub fn aws_document_to_value(document: &Document) -> Value {
103    match document {
104        Document::Null => Value::Null,
105        Document::Bool(value) => Value::Bool(*value),
106        Document::Number(value) => match *value {
107            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
108            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
109            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
110        },
111        Document::String(value) => Value::String(value.clone()),
112        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
113        Document::Object(map) => Value::Object(
114            map.iter()
115                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
116                .collect(),
117        ),
118    }
119}
120
121pub fn value_to_aws_document(value: &Value) -> Document {
122    match value {
123        Value::Null => Document::Null,
124        Value::Bool(value) => Document::Bool(*value),
125        Value::Number(value) => {
126            if let Some(value) = value.as_u64() {
127                Document::Number(AwsNumber::PosInt(value))
128            } else if let Some(value) = value.as_i64() {
129                Document::Number(AwsNumber::NegInt(value))
130            } else if let Some(value) = value.as_f64() {
131                Document::Number(AwsNumber::Float(value))
132            } else {
133                Document::Null
134            }
135        }
136        Value::String(value) => Document::String(value.clone()),
137        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
138        Value::Object(map) => Document::Object(
139            map.iter()
140                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
141                .collect(),
142        ),
143    }
144}
145
/// Thinking configuration carried by a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
    /// Thinking is requested.
    ///
    /// `stream_completion` sends a `thinking` additional model request field (with
    /// `type = "enabled"` and this `budget_tokens` value) only when `budget_tokens` is `Some`;
    /// with `None` it sends no `thinking` field at all.
    Enabled { budget_tokens: Option<u64> },
}
150
/// Input to `stream_completion`.
///
/// NOTE(review): within this file only `model`, `messages`, `tools` and `thinking` are read;
/// the remaining fields are not forwarded to Bedrock by `stream_completion` — confirm whether
/// they are consumed elsewhere.
#[derive(Debug)]
pub struct Request {
    /// Passed to the converse-stream call as its model id.
    pub model: String,
    /// Not read in this file; presumably the output token limit — TODO confirm.
    pub max_tokens: u32,
    /// Conversation messages sent with the request.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; attached to the request only when present and non-empty.
    pub tools: Option<BedrockToolConfig>,
    /// Thinking configuration; see [`Thinking`] for when it is actually sent.
    pub thinking: Option<Thinking>,
    /// Not read in this file; presumably a system prompt — TODO confirm.
    pub system: Option<String>,
    /// Not read in this file.
    pub metadata: Option<Metadata>,
    /// Not read in this file.
    pub stop_sequences: Vec<String>,
    /// Not read in this file; presumably a sampling parameter — TODO confirm.
    pub temperature: Option<f32>,
    /// Not read in this file; presumably a sampling parameter — TODO confirm.
    pub top_k: Option<u32>,
    /// Not read in this file; presumably a sampling parameter — TODO confirm.
    pub top_p: Option<f32>,
}
165
/// Optional metadata attached to a [`Request`].
///
/// NOTE(review): nothing in this file reads this type; confirm how it is used by callers.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Optional user identifier; its exact meaning is not established in this file.
    pub user_id: Option<String>,
}
170
/// Error type yielded as the `Err` item of the stream returned by `stream_completion`.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// Wraps an error from the Bedrock client; `stream_completion` produces this when receiving
    /// an event from the response stream fails.
    #[error("client error: {0}")]
    ClientError(anyhow::Error),
    /// Not constructed anywhere in this file; presumably raised by extension code — TODO confirm.
    #[error("extension error: {0}")]
    ExtensionError(anyhow::Error),
    /// Any other error. Displays exactly as the wrapped error and can be created from an
    /// `anyhow::Error` via `From` (and therefore via `?`).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
178    Other(#[from] anyhow::Error),
179}