bedrock.rs

  1mod models;
  2
  3use std::pin::Pin;
  4
  5use anyhow::{Context, Error, Result, anyhow};
  6use aws_sdk_bedrockruntime as bedrock;
  7pub use aws_sdk_bedrockruntime as bedrock_client;
  8pub use aws_sdk_bedrockruntime::types::{
  9    ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
 10    ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
 11    ToolSpecification as BedrockTool,
 12};
 13use aws_smithy_types::{Document, Number as AwsNumber};
 14pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 15pub use bedrock::types::{
 16    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
 17    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
 18    Message as BedrockMessage, ResponseStream as BedrockResponseStream,
 19};
 20use futures::stream::{self, BoxStream, Stream};
 21use serde::{Deserialize, Serialize};
 22use serde_json::{Number, Value};
 23use thiserror::Error;
 24
 25pub use crate::models::*;
 26
 27pub async fn complete(
 28    client: &bedrock::Client,
 29    request: Request,
 30) -> Result<BedrockResponse, BedrockError> {
 31    let response = bedrock::Client::converse(client)
 32        .model_id(request.model.clone())
 33        .set_messages(request.messages.into())
 34        .send()
 35        .await
 36        .context("failed to send request to Bedrock");
 37
 38    match response {
 39        Ok(output) => output
 40            .output
 41            .ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
 42        Err(err) => Err(BedrockError::Other(err)),
 43    }
 44}
 45
 46pub async fn stream_completion(
 47    client: bedrock::Client,
 48    request: Request,
 49    handle: tokio::runtime::Handle,
 50) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
 51    handle
 52        .spawn(async move {
 53            let response = bedrock::Client::converse_stream(&client)
 54                .model_id(request.model.clone())
 55                .set_messages(request.messages.into())
 56                .send()
 57                .await;
 58
 59            match response {
 60                Ok(output) => {
 61                    let stream: Pin<
 62                        Box<
 63                            dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
 64                                + Send,
 65                        >,
 66                    > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
 67                        match stream.recv().await {
 68                            Ok(Some(output)) => Some((Ok(output), stream)),
 69                            Ok(None) => None,
 70                            Err(err) => {
 71                                Some((
 72                                    // TODO: Figure out how we can capture Throttling Exceptions
 73                                    Err(BedrockError::ClientError(anyhow!(
 74                                        "{:?}",
 75                                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 76                                    ))),
 77                                    stream,
 78                                ))
 79                            }
 80                        }
 81                    }));
 82                    Ok(stream)
 83                }
 84                Err(err) => Err(anyhow!(
 85                    "{:?}",
 86                    aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
 87                )),
 88            }
 89        })
 90        .await
 91        .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
 92}
 93
 94pub fn aws_document_to_value(document: &Document) -> Value {
 95    match document {
 96        Document::Null => Value::Null,
 97        Document::Bool(value) => Value::Bool(*value),
 98        Document::Number(value) => match *value {
 99            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
100            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
101            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
102        },
103        Document::String(value) => Value::String(value.clone()),
104        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
105        Document::Object(map) => Value::Object(
106            map.iter()
107                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
108                .collect(),
109        ),
110    }
111}
112
113pub fn value_to_aws_document(value: &Value) -> Document {
114    match value {
115        Value::Null => Document::Null,
116        Value::Bool(value) => Document::Bool(*value),
117        Value::Number(value) => {
118            if let Some(value) = value.as_u64() {
119                Document::Number(AwsNumber::PosInt(value))
120            } else if let Some(value) = value.as_i64() {
121                Document::Number(AwsNumber::NegInt(value))
122            } else if let Some(value) = value.as_f64() {
123                Document::Number(AwsNumber::Float(value))
124            } else {
125                Document::Null
126            }
127        }
128        Value::String(value) => Document::String(value.clone()),
129        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
130        Value::Object(map) => Document::Object(
131            map.iter()
132                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
133                .collect(),
134        ),
135    }
136}
137
138#[derive(Debug)]
139pub struct Request {
140    pub model: String,
141    pub max_tokens: u32,
142    pub messages: Vec<BedrockMessage>,
143    pub tools: Vec<BedrockTool>,
144    pub tool_choice: Option<BedrockToolChoice>,
145    pub system: Option<String>,
146    pub metadata: Option<Metadata>,
147    pub stop_sequences: Vec<String>,
148    pub temperature: Option<f32>,
149    pub top_k: Option<u32>,
150    pub top_p: Option<f32>,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154pub struct Metadata {
155    pub user_id: Option<String>,
156}
157
158#[derive(Error, Debug)]
159pub enum BedrockError {
160    #[error("client error: {0}")]
161    ClientError(anyhow::Error),
162    #[error("extension error: {0}")]
163    ExtensionError(anyhow::Error),
164    #[error(transparent)]
165    Other(#[from] anyhow::Error),
166}