bedrock.rs

  1mod models;
  2
  3use anyhow::{Context, Error, Result, anyhow};
  4use aws_sdk_bedrockruntime as bedrock;
  5pub use aws_sdk_bedrockruntime as bedrock_client;
  6pub use aws_sdk_bedrockruntime::types::{
  7    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
  8    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
  9    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
 10    ToolSpecification as BedrockToolSpec,
 11};
 12pub use aws_smithy_types::Blob as BedrockBlob;
 13use aws_smithy_types::{Document, Number as AwsNumber};
 14pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 15pub use bedrock::types::{
 16    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 17    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 18    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
 19    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 20    ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
 21    ToolResultContentBlock as BedrockToolResultContentBlock,
 22    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 23};
 24use futures::stream::{self, BoxStream};
 25use serde::{Deserialize, Serialize};
 26use serde_json::{Number, Value};
 27use std::collections::HashMap;
 28use thiserror::Error;
 29
 30pub use crate::models::*;
 31
 32pub async fn stream_completion(
 33    client: bedrock::Client,
 34    request: Request,
 35) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 36    let mut response = bedrock::Client::converse_stream(&client)
 37        .model_id(request.model.clone())
 38        .set_messages(request.messages.into());
 39
 40    if let Some(Thinking::Enabled {
 41        budget_tokens: Some(budget_tokens),
 42    }) = request.thinking
 43    {
 44        let thinking_config = HashMap::from([
 45            ("type".to_string(), Document::String("enabled".to_string())),
 46            (
 47                "budget_tokens".to_string(),
 48                Document::Number(AwsNumber::PosInt(budget_tokens)),
 49            ),
 50        ]);
 51        response = response.additional_model_request_fields(Document::Object(HashMap::from([(
 52            "thinking".to_string(),
 53            Document::from(thinking_config),
 54        )])));
 55    }
 56
 57    if request
 58        .tools
 59        .as_ref()
 60        .map_or(false, |t| !t.tools.is_empty())
 61    {
 62        response = response.set_tool_config(request.tools);
 63    }
 64
 65    let output = response
 66        .send()
 67        .await
 68        .context("Failed to send API request to Bedrock");
 69
 70    let stream = Box::pin(stream::unfold(
 71        output?.stream,
 72        move |mut stream| async move {
 73            match stream.recv().await {
 74                Ok(Some(output)) => Some((Ok(output), stream)),
 75                Ok(None) => None,
 76                Err(err) => Some((
 77                    Err(BedrockError::ClientError(anyhow!(
 78                        "{:?}",
 79                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 80                    ))),
 81                    stream,
 82                )),
 83            }
 84        },
 85    ));
 86
 87    Ok(stream)
 88}
 89
 90pub fn aws_document_to_value(document: &Document) -> Value {
 91    match document {
 92        Document::Null => Value::Null,
 93        Document::Bool(value) => Value::Bool(*value),
 94        Document::Number(value) => match *value {
 95            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
 96            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
 97            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
 98        },
 99        Document::String(value) => Value::String(value.clone()),
100        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
101        Document::Object(map) => Value::Object(
102            map.iter()
103                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
104                .collect(),
105        ),
106    }
107}
108
109pub fn value_to_aws_document(value: &Value) -> Document {
110    match value {
111        Value::Null => Document::Null,
112        Value::Bool(value) => Document::Bool(*value),
113        Value::Number(value) => {
114            if let Some(value) = value.as_u64() {
115                Document::Number(AwsNumber::PosInt(value))
116            } else if let Some(value) = value.as_i64() {
117                Document::Number(AwsNumber::NegInt(value))
118            } else if let Some(value) = value.as_f64() {
119                Document::Number(AwsNumber::Float(value))
120            } else {
121                Document::Null
122            }
123        }
124        Value::String(value) => Document::String(value.clone()),
125        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
126        Value::Object(map) => Document::Object(
127            map.iter()
128                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
129                .collect(),
130        ),
131    }
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135pub enum Thinking {
136    Enabled { budget_tokens: Option<u64> },
137}
138
139#[derive(Debug)]
140pub struct Request {
141    pub model: String,
142    pub max_tokens: u64,
143    pub messages: Vec<BedrockMessage>,
144    pub tools: Option<BedrockToolConfig>,
145    pub thinking: Option<Thinking>,
146    pub system: Option<String>,
147    pub metadata: Option<Metadata>,
148    pub stop_sequences: Vec<String>,
149    pub temperature: Option<f32>,
150    pub top_k: Option<u32>,
151    pub top_p: Option<f32>,
152}
153
154#[derive(Debug, Serialize, Deserialize)]
155pub struct Metadata {
156    pub user_id: Option<String>,
157}
158
159#[derive(Error, Debug)]
160pub enum BedrockError {
161    #[error("client error: {0}")]
162    ClientError(anyhow::Error),
163    #[error("extension error: {0}")]
164    ExtensionError(anyhow::Error),
165    #[error(transparent)]
166    Other(#[from] anyhow::Error),
167}