1mod models;
2
3use std::pin::Pin;
4
5use anyhow::{Context, Error, Result, anyhow};
6use aws_sdk_bedrockruntime as bedrock;
7pub use aws_sdk_bedrockruntime as bedrock_client;
8pub use aws_sdk_bedrockruntime::types::{
9 ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
10 ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
11 ToolSpecification as BedrockTool,
12};
13use aws_smithy_types::{Document, Number as AwsNumber};
14pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
15pub use bedrock::types::{
16 ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
17 ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
18 Message as BedrockMessage, ResponseStream as BedrockResponseStream,
19};
20use futures::stream::{self, BoxStream, Stream};
21use serde::{Deserialize, Serialize};
22use serde_json::{Number, Value};
23use thiserror::Error;
24
25pub use crate::models::*;
26
27pub async fn complete(
28 client: &bedrock::Client,
29 request: Request,
30) -> Result<BedrockResponse, BedrockError> {
31 let response = bedrock::Client::converse(client)
32 .model_id(request.model.clone())
33 .set_messages(request.messages.into())
34 .send()
35 .await
36 .context("failed to send request to Bedrock");
37
38 match response {
39 Ok(output) => output
40 .output
41 .ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
42 Err(err) => Err(BedrockError::Other(err)),
43 }
44}
45
46pub async fn stream_completion(
47 client: bedrock::Client,
48 request: Request,
49 handle: tokio::runtime::Handle,
50) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
51 handle
52 .spawn(async move {
53 let response = bedrock::Client::converse_stream(&client)
54 .model_id(request.model.clone())
55 .set_messages(request.messages.into())
56 .send()
57 .await;
58
59 match response {
60 Ok(output) => {
61 let stream: Pin<
62 Box<
63 dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
64 + Send,
65 >,
66 > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
67 match stream.recv().await {
68 Ok(Some(output)) => Some((Ok(output), stream)),
69 Ok(None) => None,
70 Err(err) => {
71 Some((
72 // TODO: Figure out how we can capture Throttling Exceptions
73 Err(BedrockError::ClientError(anyhow!(
74 "{:?}",
75 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
76 ))),
77 stream,
78 ))
79 }
80 }
81 }));
82 Ok(stream)
83 }
84 Err(err) => Err(anyhow!(
85 "{:?}",
86 aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
87 )),
88 }
89 })
90 .await
91 .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
92}
93
94pub fn aws_document_to_value(document: &Document) -> Value {
95 match document {
96 Document::Null => Value::Null,
97 Document::Bool(value) => Value::Bool(*value),
98 Document::Number(value) => match *value {
99 AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
100 AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
101 AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
102 },
103 Document::String(value) => Value::String(value.clone()),
104 Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
105 Document::Object(map) => Value::Object(
106 map.iter()
107 .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
108 .collect(),
109 ),
110 }
111}
112
113pub fn value_to_aws_document(value: &Value) -> Document {
114 match value {
115 Value::Null => Document::Null,
116 Value::Bool(value) => Document::Bool(*value),
117 Value::Number(value) => {
118 if let Some(value) = value.as_u64() {
119 Document::Number(AwsNumber::PosInt(value))
120 } else if let Some(value) = value.as_i64() {
121 Document::Number(AwsNumber::NegInt(value))
122 } else if let Some(value) = value.as_f64() {
123 Document::Number(AwsNumber::Float(value))
124 } else {
125 Document::Null
126 }
127 }
128 Value::String(value) => Document::String(value.clone()),
129 Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
130 Value::Object(map) => Document::Object(
131 map.iter()
132 .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
133 .collect(),
134 ),
135 }
136}
137
138#[derive(Debug)]
139pub struct Request {
140 pub model: String,
141 pub max_tokens: u32,
142 pub messages: Vec<BedrockMessage>,
143 pub tools: Vec<BedrockTool>,
144 pub tool_choice: Option<BedrockToolChoice>,
145 pub system: Option<String>,
146 pub metadata: Option<Metadata>,
147 pub stop_sequences: Vec<String>,
148 pub temperature: Option<f32>,
149 pub top_k: Option<u32>,
150 pub top_p: Option<f32>,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154pub struct Metadata {
155 pub user_id: Option<String>,
156}
157
158#[derive(Error, Debug)]
159pub enum BedrockError {
160 #[error("client error: {0}")]
161 ClientError(anyhow::Error),
162 #[error("extension error: {0}")]
163 ExtensionError(anyhow::Error),
164 #[error(transparent)]
165 Other(#[from] anyhow::Error),
166}