bedrock.rs

  1mod models;
  2
  3use anyhow::{Context, Error, Result, anyhow};
  4use aws_sdk_bedrockruntime as bedrock;
  5pub use aws_sdk_bedrockruntime as bedrock_client;
  6use aws_sdk_bedrockruntime::types::InferenceConfiguration;
  7pub use aws_sdk_bedrockruntime::types::{
  8    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
  9    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
 10    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
 11    ToolSpecification as BedrockToolSpec,
 12};
 13pub use aws_smithy_types::Blob as BedrockBlob;
 14use aws_smithy_types::{Document, Number as AwsNumber};
 15pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 16pub use bedrock::types::{
 17    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 18    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 19    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
 20    ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
 21    ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock,
 22    ToolResultBlock as BedrockToolResultBlock,
 23    ToolResultContentBlock as BedrockToolResultContentBlock,
 24    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 25};
 26use futures::stream::{self, BoxStream};
 27use serde::{Deserialize, Serialize};
 28use serde_json::{Number, Value};
 29use std::collections::HashMap;
 30use thiserror::Error;
 31
 32pub use crate::models::*;
 33
 34pub async fn stream_completion(
 35    client: bedrock::Client,
 36    request: Request,
 37) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 38    let mut response = bedrock::Client::converse_stream(&client)
 39        .model_id(request.model.clone())
 40        .set_messages(request.messages.into());
 41
 42    match request.thinking {
 43        Some(Thinking::Enabled {
 44            budget_tokens: Some(budget_tokens),
 45        }) => {
 46            let thinking_config = HashMap::from([
 47                ("type".to_string(), Document::String("enabled".to_string())),
 48                (
 49                    "budget_tokens".to_string(),
 50                    Document::Number(AwsNumber::PosInt(budget_tokens)),
 51                ),
 52            ]);
 53            response =
 54                response.additional_model_request_fields(Document::Object(HashMap::from([(
 55                    "thinking".to_string(),
 56                    Document::from(thinking_config),
 57                )])));
 58        }
 59        Some(Thinking::Adaptive { effort: _ }) => {
 60            let thinking_config =
 61                HashMap::from([("type".to_string(), Document::String("adaptive".to_string()))]);
 62            response =
 63                response.additional_model_request_fields(Document::Object(HashMap::from([(
 64                    "thinking".to_string(),
 65                    Document::from(thinking_config),
 66                )])));
 67        }
 68        _ => {}
 69    }
 70
 71    if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
 72        response = response.set_tool_config(request.tools);
 73    }
 74
 75    let inference_config = InferenceConfiguration::builder()
 76        .max_tokens(request.max_tokens as i32)
 77        .set_temperature(request.temperature)
 78        .set_top_p(request.top_p)
 79        .build();
 80
 81    response = response.inference_config(inference_config);
 82
 83    if let Some(system) = request.system {
 84        if !system.is_empty() {
 85            response = response.system(BedrockSystemContentBlock::Text(system));
 86        }
 87    }
 88
 89    let output = response
 90        .send()
 91        .await
 92        .context("Failed to send API request to Bedrock");
 93
 94    let stream = Box::pin(stream::unfold(
 95        output?.stream,
 96        move |mut stream| async move {
 97            match stream.recv().await {
 98                Ok(Some(output)) => Some((Ok(output), stream)),
 99                Ok(None) => None,
100                Err(err) => Some((
101                    Err(BedrockError::ClientError(anyhow!(
102                        "{}",
103                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
104                    ))),
105                    stream,
106                )),
107            }
108        },
109    ));
110
111    Ok(stream)
112}
113
114pub fn aws_document_to_value(document: &Document) -> Value {
115    match document {
116        Document::Null => Value::Null,
117        Document::Bool(value) => Value::Bool(*value),
118        Document::Number(value) => match *value {
119            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
120            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
121            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
122        },
123        Document::String(value) => Value::String(value.clone()),
124        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
125        Document::Object(map) => Value::Object(
126            map.iter()
127                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
128                .collect(),
129        ),
130    }
131}
132
133pub fn value_to_aws_document(value: &Value) -> Document {
134    match value {
135        Value::Null => Document::Null,
136        Value::Bool(value) => Document::Bool(*value),
137        Value::Number(value) => {
138            if let Some(value) = value.as_u64() {
139                Document::Number(AwsNumber::PosInt(value))
140            } else if let Some(value) = value.as_i64() {
141                Document::Number(AwsNumber::NegInt(value))
142            } else if let Some(value) = value.as_f64() {
143                Document::Number(AwsNumber::Float(value))
144            } else {
145                Document::Null
146            }
147        }
148        Value::String(value) => Document::String(value.clone()),
149        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
150        Value::Object(map) => Document::Object(
151            map.iter()
152                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
153                .collect(),
154        ),
155    }
156}
157
158#[derive(Debug, Serialize, Deserialize)]
159pub enum Thinking {
160    Enabled {
161        budget_tokens: Option<u64>,
162    },
163    Adaptive {
164        effort: BedrockAdaptiveThinkingEffort,
165    },
166}
167
168#[derive(Debug)]
169pub struct Request {
170    pub model: String,
171    pub max_tokens: u64,
172    pub messages: Vec<BedrockMessage>,
173    pub tools: Option<BedrockToolConfig>,
174    pub thinking: Option<Thinking>,
175    pub system: Option<String>,
176    pub metadata: Option<Metadata>,
177    pub stop_sequences: Vec<String>,
178    pub temperature: Option<f32>,
179    pub top_k: Option<u32>,
180    pub top_p: Option<f32>,
181}
182
183#[derive(Debug, Serialize, Deserialize)]
184pub struct Metadata {
185    pub user_id: Option<String>,
186}
187
188#[derive(Error, Debug)]
189pub enum BedrockError {
190    #[error("client error: {0}")]
191    ClientError(anyhow::Error),
192    #[error("extension error: {0}")]
193    ExtensionError(anyhow::Error),
194    #[error(transparent)]
195    Other(#[from] anyhow::Error),
196}