bedrock.rs

  1mod models;
  2
  3use std::collections::HashMap;
  4use std::pin::Pin;
  5
  6use anyhow::{Context as _, Error, Result, anyhow};
  7use aws_sdk_bedrockruntime as bedrock;
  8pub use aws_sdk_bedrockruntime as bedrock_client;
  9pub use aws_sdk_bedrockruntime::types::{
 10    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
 11    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
 12    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
 13    ToolSpecification as BedrockToolSpec,
 14};
 15pub use aws_smithy_types::Blob as BedrockBlob;
 16use aws_smithy_types::{Document, Number as AwsNumber};
 17pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 18pub use bedrock::types::{
 19    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 20    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 21    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
 22    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 23    ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
 24    ToolResultContentBlock as BedrockToolResultContentBlock,
 25    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 26};
 27use futures::stream::{self, BoxStream, Stream};
 28use serde::{Deserialize, Serialize};
 29use serde_json::{Number, Value};
 30use thiserror::Error;
 31
 32pub use crate::models::*;
 33
 34pub async fn stream_completion(
 35    client: bedrock::Client,
 36    request: Request,
 37    handle: tokio::runtime::Handle,
 38) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 39    handle
 40        .spawn(async move {
 41            let mut response = bedrock::Client::converse_stream(&client)
 42                .model_id(request.model.clone())
 43                .set_messages(request.messages.into());
 44
 45            if let Some(Thinking::Enabled {
 46                budget_tokens: Some(budget_tokens),
 47            }) = request.thinking
 48            {
 49                response =
 50                    response.additional_model_request_fields(Document::Object(HashMap::from([(
 51                        "thinking".to_string(),
 52                        Document::from(HashMap::from([
 53                            ("type".to_string(), Document::String("enabled".to_string())),
 54                            (
 55                                "budget_tokens".to_string(),
 56                                Document::Number(AwsNumber::PosInt(budget_tokens)),
 57                            ),
 58                        ])),
 59                    )])));
 60            }
 61
 62            if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
 63                response = response.set_tool_config(request.tools);
 64            }
 65
 66            let response = response.send().await;
 67
 68            match response {
 69                Ok(output) => {
 70                    let stream: Pin<
 71                        Box<
 72                            dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
 73                                + Send,
 74                        >,
 75                    > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
 76                        match stream.recv().await {
 77                            Ok(Some(output)) => Some(({ Ok(output) }, stream)),
 78                            Ok(None) => None,
 79                            Err(err) => {
 80                                Some((
 81                                    // TODO: Figure out how we can capture Throttling Exceptions
 82                                    Err(BedrockError::ClientError(anyhow!(
 83                                        "{:?}",
 84                                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 85                                    ))),
 86                                    stream,
 87                                ))
 88                            }
 89                        }
 90                    }));
 91                    Ok(stream)
 92                }
 93                Err(err) => Err(anyhow!(
 94                    "{:?}",
 95                    aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 96                )),
 97            }
 98        })
 99        .await
100        .context("spawning a task")?
101}
102
103pub fn aws_document_to_value(document: &Document) -> Value {
104    match document {
105        Document::Null => Value::Null,
106        Document::Bool(value) => Value::Bool(*value),
107        Document::Number(value) => match *value {
108            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
109            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
110            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
111        },
112        Document::String(value) => Value::String(value.clone()),
113        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
114        Document::Object(map) => Value::Object(
115            map.iter()
116                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
117                .collect(),
118        ),
119    }
120}
121
122pub fn value_to_aws_document(value: &Value) -> Document {
123    match value {
124        Value::Null => Document::Null,
125        Value::Bool(value) => Document::Bool(*value),
126        Value::Number(value) => {
127            if let Some(value) = value.as_u64() {
128                Document::Number(AwsNumber::PosInt(value))
129            } else if let Some(value) = value.as_i64() {
130                Document::Number(AwsNumber::NegInt(value))
131            } else if let Some(value) = value.as_f64() {
132                Document::Number(AwsNumber::Float(value))
133            } else {
134                Document::Null
135            }
136        }
137        Value::String(value) => Document::String(value.clone()),
138        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
139        Value::Object(map) => Document::Object(
140            map.iter()
141                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
142                .collect(),
143        ),
144    }
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148pub enum Thinking {
149    Enabled { budget_tokens: Option<u64> },
150}
151
152#[derive(Debug)]
153pub struct Request {
154    pub model: String,
155    pub max_tokens: u32,
156    pub messages: Vec<BedrockMessage>,
157    pub tools: Option<BedrockToolConfig>,
158    pub thinking: Option<Thinking>,
159    pub system: Option<String>,
160    pub metadata: Option<Metadata>,
161    pub stop_sequences: Vec<String>,
162    pub temperature: Option<f32>,
163    pub top_k: Option<u32>,
164    pub top_p: Option<f32>,
165}
166
167#[derive(Debug, Serialize, Deserialize)]
168pub struct Metadata {
169    pub user_id: Option<String>,
170}
171
172#[derive(Error, Debug)]
173pub enum BedrockError {
174    #[error("client error: {0}")]
175    ClientError(anyhow::Error),
176    #[error("extension error: {0}")]
177    ExtensionError(anyhow::Error),
178    #[error(transparent)]
179    Other(#[from] anyhow::Error),
180}