bedrock.rs

  1mod models;
  2
  3use anyhow::{Context, Error, Result, anyhow};
  4use aws_sdk_bedrockruntime as bedrock;
  5pub use aws_sdk_bedrockruntime as bedrock_client;
  6pub use aws_sdk_bedrockruntime::types::{
  7    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
  8    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
  9    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
 10    ToolSpecification as BedrockToolSpec,
 11};
 12pub use aws_smithy_types::Blob as BedrockBlob;
 13use aws_smithy_types::{Document, Number as AwsNumber};
 14pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 15pub use bedrock::types::{
 16    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 17    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 18    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
 19    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 20    ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
 21    ToolResultContentBlock as BedrockToolResultContentBlock,
 22    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 23};
 24use futures::stream::{self, BoxStream};
 25use serde::{Deserialize, Serialize};
 26use serde_json::{Number, Value};
 27use std::collections::HashMap;
 28use thiserror::Error;
 29
 30pub use crate::models::*;
 31
 32pub async fn stream_completion(
 33    client: bedrock::Client,
 34    request: Request,
 35) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 36    let mut response = bedrock::Client::converse_stream(&client)
 37        .model_id(request.model.clone())
 38        .set_messages(request.messages.into());
 39
 40    if let Some(Thinking::Enabled {
 41        budget_tokens: Some(budget_tokens),
 42    }) = request.thinking
 43    {
 44        let thinking_config = HashMap::from([
 45            ("type".to_string(), Document::String("enabled".to_string())),
 46            (
 47                "budget_tokens".to_string(),
 48                Document::Number(AwsNumber::PosInt(budget_tokens)),
 49            ),
 50        ]);
 51        response = response.additional_model_request_fields(Document::Object(HashMap::from([(
 52            "thinking".to_string(),
 53            Document::from(thinking_config),
 54        )])));
 55    }
 56
 57    if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
 58        response = response.set_tool_config(request.tools);
 59    }
 60
 61    let output = response
 62        .send()
 63        .await
 64        .context("Failed to send API request to Bedrock");
 65
 66    let stream = Box::pin(stream::unfold(
 67        output?.stream,
 68        move |mut stream| async move {
 69            match stream.recv().await {
 70                Ok(Some(output)) => Some((Ok(output), stream)),
 71                Ok(None) => None,
 72                Err(err) => Some((
 73                    Err(BedrockError::ClientError(anyhow!(
 74                        "{:?}",
 75                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 76                    ))),
 77                    stream,
 78                )),
 79            }
 80        },
 81    ));
 82
 83    Ok(stream)
 84}
 85
 86pub fn aws_document_to_value(document: &Document) -> Value {
 87    match document {
 88        Document::Null => Value::Null,
 89        Document::Bool(value) => Value::Bool(*value),
 90        Document::Number(value) => match *value {
 91            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
 92            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
 93            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
 94        },
 95        Document::String(value) => Value::String(value.clone()),
 96        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
 97        Document::Object(map) => Value::Object(
 98            map.iter()
 99                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
100                .collect(),
101        ),
102    }
103}
104
105pub fn value_to_aws_document(value: &Value) -> Document {
106    match value {
107        Value::Null => Document::Null,
108        Value::Bool(value) => Document::Bool(*value),
109        Value::Number(value) => {
110            if let Some(value) = value.as_u64() {
111                Document::Number(AwsNumber::PosInt(value))
112            } else if let Some(value) = value.as_i64() {
113                Document::Number(AwsNumber::NegInt(value))
114            } else if let Some(value) = value.as_f64() {
115                Document::Number(AwsNumber::Float(value))
116            } else {
117                Document::Null
118            }
119        }
120        Value::String(value) => Document::String(value.clone()),
121        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
122        Value::Object(map) => Document::Object(
123            map.iter()
124                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
125                .collect(),
126        ),
127    }
128}
129
/// Thinking (reasoning) settings attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
    /// Thinking is turned on.
    ///
    /// `stream_completion` only adds a `"thinking"` section to the request when
    /// `budget_tokens` is `Some`; with `None` no thinking configuration is sent.
    Enabled { budget_tokens: Option<u64> },
}
134
/// Parameters for a completion request, consumed by `stream_completion`.
///
/// NOTE(review): in this file `stream_completion` reads only `model`, `messages`,
/// `tools` and `thinking`. The other fields are not forwarded to Bedrock by any
/// code visible here — confirm whether that is intentional.
#[derive(Debug)]
pub struct Request {
    /// Passed to the Converse stream call as the model id.
    pub model: String,
    // Not read by `stream_completion` in this file.
    pub max_tokens: u64,
    /// Conversation messages sent with the request.
    pub messages: Vec<BedrockMessage>,
    /// Tool configuration; only forwarded when it contains at least one tool.
    pub tools: Option<BedrockToolConfig>,
    /// Optional thinking settings; see [`Thinking`].
    pub thinking: Option<Thinking>,
    // Not read by `stream_completion` in this file.
    pub system: Option<String>,
    // Not read by `stream_completion` in this file.
    pub metadata: Option<Metadata>,
    // Not read by `stream_completion` in this file.
    pub stop_sequences: Vec<String>,
    // Not read by `stream_completion` in this file.
    pub temperature: Option<f32>,
    // Not read by `stream_completion` in this file.
    pub top_k: Option<u32>,
    // Not read by `stream_completion` in this file.
    pub top_p: Option<f32>,
}
149
/// Optional metadata carried on a [`Request`].
///
/// Not read by any function visible in this file.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Optional user identifier — presumably identifies the end user making the
    /// request; verify against callers.
    pub user_id: Option<String>,
}
154
/// Errors yielded as items of the stream returned by `stream_completion`.
#[derive(Error, Debug)]
pub enum BedrockError {
    /// Produced by `stream_completion` when receiving an event from the response
    /// stream fails; wraps the debug-formatted SDK error context.
    #[error("client error: {0}")]
    ClientError(anyhow::Error),
    /// Not constructed anywhere in this file — presumably raised by callers
    /// elsewhere in the project; verify against usage.
    #[error("extension error: {0}")]
    ExtensionError(anyhow::Error),
    /// Any other error; displays as the wrapped error itself and can be created
    /// from an `anyhow::Error` via `From`/`?`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}