Add `aws_http_client` and `bedrock` crates (#25490)

Marshall Bowers , Shardul Vaidya , and Anthony Eid created

This PR adds new `aws_http_client` and `bedrock` crates for supporting
AWS Bedrock.

Pulling out of https://github.com/zed-industries/zed/pull/21092 to make
it easier to land.

Release Notes:

- N/A

---------

Co-authored-by: Shardul Vaidya <cam.v737@gmail.com>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>

Change summary

Cargo.lock                                    |  51 +++++
Cargo.toml                                    |   9 
crates/aws_http_client/Cargo.toml             |  22 ++
crates/aws_http_client/LICENSE-GPL            |   1 
crates/aws_http_client/src/aws_http_client.rs | 118 ++++++++++++
crates/bedrock/Cargo.toml                     |  28 ++
crates/bedrock/LICENSE-GPL                    |   1 
crates/bedrock/src/bedrock.rs                 | 166 +++++++++++++++++
crates/bedrock/src/models.rs                  | 199 +++++++++++++++++++++
9 files changed, 595 insertions(+)

Detailed changes

Cargo.lock 🔗

@@ -1269,6 +1269,30 @@ dependencies = [
  "uuid",
 ]
 
+[[package]]
+name = "aws-sdk-bedrockruntime"
+version = "1.74.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6938541d1948a543bca23303fec4cff9c36bf0e63b8fa3ae1b337bcb9d5b81af"
+dependencies = [
+ "aws-credential-types",
+ "aws-runtime",
+ "aws-smithy-async",
+ "aws-smithy-eventstream",
+ "aws-smithy-http",
+ "aws-smithy-json",
+ "aws-smithy-runtime",
+ "aws-smithy-runtime-api",
+ "aws-smithy-types",
+ "aws-types",
+ "bytes 1.10.0",
+ "fastrand 2.3.0",
+ "http 0.2.12",
+ "once_cell",
+ "regex-lite",
+ "tracing",
+]
+
 [[package]]
 name = "aws-sdk-kinesis"
 version = "1.61.0"
@@ -1598,6 +1622,17 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "aws_http_client"
+version = "0.1.0"
+dependencies = [
+ "aws-smithy-runtime-api",
+ "aws-smithy-types",
+ "futures 0.3.31",
+ "http_client",
+ "tokio",
+]
+
 [[package]]
 name = "axum"
 version = "0.6.20"
@@ -1727,6 +1762,22 @@ version = "1.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
 
+[[package]]
+name = "bedrock"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "aws-sdk-bedrockruntime",
+ "aws-smithy-types",
+ "futures 0.3.31",
+ "schemars",
+ "serde",
+ "serde_json",
+ "strum",
+ "thiserror 1.0.69",
+ "tokio",
+]
+
 [[package]]
 name = "bigdecimal"
 version = "0.4.7"

Cargo.toml 🔗

@@ -15,6 +15,8 @@ members = [
     "crates/audio",
     "crates/auto_update",
     "crates/auto_update_ui",
+    "crates/aws_http_client",
+    "crates/bedrock",
     "crates/breadcrumbs",
     "crates/buffer_diff",
     "crates/call",
@@ -218,6 +220,8 @@ assistant_tools = { path = "crates/assistant_tools" }
 audio = { path = "crates/audio" }
 auto_update = { path = "crates/auto_update" }
 auto_update_ui = { path = "crates/auto_update_ui" }
+aws_http_client = { path = "crates/aws_http_client" }
+bedrock = { path = "crates/bedrock" }
 breadcrumbs = { path = "crates/breadcrumbs" }
 call = { path = "crates/call" }
 channel = { path = "crates/channel" }
@@ -382,6 +386,11 @@ async-trait = "0.1"
 async-tungstenite = "0.28"
 async-watch = "0.3.1"
 async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
+aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
+aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
+aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] }
+aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] }
+aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] }
 base64 = "0.22"
 bitflags = "2.6.0"
 blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }

crates/aws_http_client/Cargo.toml 🔗

@@ -0,0 +1,22 @@
+[package]
+name = "aws_http_client"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/aws_http_client.rs"
+
+[features]
+default = []
+
+[dependencies]
+aws-smithy-runtime-api.workspace = true
+aws-smithy-types.workspace = true
+futures.workspace = true
+http_client.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }

crates/aws_http_client/src/aws_http_client.rs 🔗

@@ -0,0 +1,118 @@
+use std::fmt;
+use std::sync::Arc;
+
+use aws_smithy_runtime_api::client::http::{
+    HttpClient as AwsClient, HttpConnector as AwsConnector,
+    HttpConnectorFuture as AwsConnectorFuture, HttpConnectorFuture, HttpConnectorSettings,
+    SharedHttpConnector,
+};
+use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsHttpRequest, HttpResponse};
+use aws_smithy_runtime_api::client::result::ConnectorError;
+use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
+use aws_smithy_runtime_api::http::StatusCode;
+use aws_smithy_types::body::SdkBody;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, Inner};
+use http_client::{HttpClient, Request};
+use tokio::runtime::Handle;
+
+struct AwsHttpConnector {
+    client: Arc<dyn HttpClient>,
+    handle: Handle,
+}
+
+impl std::fmt::Debug for AwsHttpConnector {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("AwsHttpConnector").finish()
+    }
+}
+
+impl AwsConnector for AwsHttpConnector {
+    fn call(&self, request: AwsHttpRequest) -> AwsConnectorFuture {
+        let req = match request.try_into_http1x() {
+            Ok(req) => req,
+            Err(err) => {
+                return HttpConnectorFuture::ready(Err(ConnectorError::other(err.into(), None)))
+            }
+        };
+
+        let (parts, body) = req.into_parts();
+
+        let response = self
+            .client
+            .send(Request::from_parts(parts, convert_to_async_body(body)));
+
+        let handle = self.handle.clone();
+
+        HttpConnectorFuture::new(async move {
+            let response = match response.await {
+                Ok(response) => response,
+                Err(err) => return Err(ConnectorError::other(err.into(), None)),
+            };
+            let (parts, body) = response.into_parts();
+            let body = convert_to_sdk_body(body, handle).await;
+
+            Ok(HttpResponse::new(
+                StatusCode::try_from(parts.status.as_u16()).unwrap(),
+                body,
+            ))
+        })
+    }
+}
+
+#[derive(Clone)]
+pub struct AwsHttpClient {
+    client: Arc<dyn HttpClient>,
+    handler: Handle,
+}
+
+impl std::fmt::Debug for AwsHttpClient {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("AwsHttpClient").finish()
+    }
+}
+
+impl AwsHttpClient {
+    pub fn new(client: Arc<dyn HttpClient>, handle: Handle) -> Self {
+        Self {
+            client,
+            handler: handle,
+        }
+    }
+}
+
+impl AwsClient for AwsHttpClient {
+    fn http_connector(
+        &self,
+        _settings: &HttpConnectorSettings,
+        _components: &RuntimeComponents,
+    ) -> SharedHttpConnector {
+        SharedHttpConnector::new(AwsHttpConnector {
+            client: self.client.clone(),
+            handle: self.handler.clone(),
+        })
+    }
+}
+
+pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody {
+    match body.0 {
+        Inner::Empty => SdkBody::empty(),
+        Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()),
+        Inner::AsyncReader(mut reader) => {
+            let buffer = handle.spawn(async move {
+                let mut buffer = Vec::new();
+                let _ = reader.read_to_end(&mut buffer).await;
+                buffer
+            });
+
+            SdkBody::from(buffer.await.unwrap_or_default())
+        }
+    }
+}
+
+pub fn convert_to_async_body(body: SdkBody) -> AsyncBody {
+    match body.bytes() {
+        Some(bytes) => AsyncBody::from((*bytes).to_vec()),
+        None => AsyncBody::empty(),
+    }
+}

crates/bedrock/Cargo.toml 🔗

@@ -0,0 +1,28 @@
+[package]
+name = "bedrock"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/bedrock.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] }
+aws-smithy-types = {workspace = true}
+futures.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+strum.workspace = true
+thiserror.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }

crates/bedrock/src/bedrock.rs 🔗

@@ -0,0 +1,166 @@
+mod models;
+
+use std::pin::Pin;
+
+use anyhow::{anyhow, Context, Error, Result};
+use aws_sdk_bedrockruntime as bedrock;
+pub use aws_sdk_bedrockruntime as bedrock_client;
+pub use aws_sdk_bedrockruntime::types::{
+    ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
+    ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
+    ToolSpecification as BedrockTool,
+};
+use aws_smithy_types::{Document, Number as AwsNumber};
+pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
+pub use bedrock::types::{
+    ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
+    ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
+    Message as BedrockMessage, ResponseStream as BedrockResponseStream,
+};
+use futures::stream::{self, BoxStream, Stream};
+use serde::{Deserialize, Serialize};
+use serde_json::{Number, Value};
+use thiserror::Error;
+
+pub use crate::models::*;
+
+pub async fn complete(
+    client: &bedrock::Client,
+    request: Request,
+) -> Result<BedrockResponse, BedrockError> {
+    let response = bedrock::Client::converse(client)
+        .model_id(request.model.clone())
+        .set_messages(request.messages.into())
+        .send()
+        .await
+        .context("failed to send request to Bedrock");
+
+    match response {
+        Ok(output) => output
+            .output
+            .ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
+        Err(err) => Err(BedrockError::Other(err)),
+    }
+}
+
+pub async fn stream_completion(
+    client: bedrock::Client,
+    request: Request,
+    handle: tokio::runtime::Handle,
+) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
+    handle
+        .spawn(async move {
+            let response = bedrock::Client::converse_stream(&client)
+                .model_id(request.model.clone())
+                .set_messages(request.messages.into())
+                .send()
+                .await;
+
+            match response {
+                Ok(output) => {
+                    let stream: Pin<
+                        Box<
+                            dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
+                                + Send,
+                        >,
+                    > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
+                        match stream.recv().await {
+                            Ok(Some(output)) => Some((Ok(output), stream)),
+                            Ok(None) => None,
+                            Err(err) => {
+                                Some((
+                                    // TODO: Figure out how we can capture Throttling Exceptions
+                                    Err(BedrockError::ClientError(anyhow!(
+                                        "{:?}",
+                                        aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+                                    ))),
+                                    stream,
+                                ))
+                            }
+                        }
+                    }));
+                    Ok(stream)
+                }
+                Err(err) => Err(anyhow!(
+                    "{:?}",
+                    aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+                )),
+            }
+        })
+        .await
+        .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
+}
+
+pub fn aws_document_to_value(document: &Document) -> Value {
+    match document {
+        Document::Null => Value::Null,
+        Document::Bool(value) => Value::Bool(*value),
+        Document::Number(value) => match *value {
+            AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
+            AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
+            AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
+        },
+        Document::String(value) => Value::String(value.clone()),
+        Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
+        Document::Object(map) => Value::Object(
+            map.iter()
+                .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
+                .collect(),
+        ),
+    }
+}
+
+pub fn value_to_aws_document(value: &Value) -> Document {
+    match value {
+        Value::Null => Document::Null,
+        Value::Bool(value) => Document::Bool(*value),
+        Value::Number(value) => {
+            if let Some(value) = value.as_u64() {
+                Document::Number(AwsNumber::PosInt(value))
+            } else if let Some(value) = value.as_i64() {
+                Document::Number(AwsNumber::NegInt(value))
+            } else if let Some(value) = value.as_f64() {
+                Document::Number(AwsNumber::Float(value))
+            } else {
+                Document::Null
+            }
+        }
+        Value::String(value) => Document::String(value.clone()),
+        Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
+        Value::Object(map) => Document::Object(
+            map.iter()
+                .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
+                .collect(),
+        ),
+    }
+}
+
+#[derive(Debug)]
+pub struct Request {
+    pub model: String,
+    pub max_tokens: u32,
+    pub messages: Vec<BedrockMessage>,
+    pub tools: Vec<BedrockTool>,
+    pub tool_choice: Option<BedrockToolChoice>,
+    pub system: Option<String>,
+    pub metadata: Option<Metadata>,
+    pub stop_sequences: Vec<String>,
+    pub temperature: Option<f32>,
+    pub top_k: Option<u32>,
+    pub top_p: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+    pub user_id: Option<String>,
+}
+
+#[derive(Error, Debug)]
+pub enum BedrockError {
+    #[error("client error: {0}")]
+    ClientError(anyhow::Error),
+    #[error("extension error: {0}")]
+    ExtensionError(anyhow::Error),
+    #[error(transparent)]
+    Other(#[from] anyhow::Error),
+}

crates/bedrock/src/models.rs 🔗

@@ -0,0 +1,199 @@
+use anyhow::anyhow;
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
+pub enum Model {
+    // Anthropic models (already included)
+    #[default]
+    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
+    Claude3_5Sonnet,
+    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
+    Claude3Opus,
+    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
+    Claude3Sonnet,
+    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
+    Claude3_5Haiku,
+    // Amazon Nova Models
+    AmazonNovaLite,
+    AmazonNovaMicro,
+    AmazonNovaPro,
+    // AI21 models
+    AI21J2GrandeInstruct,
+    AI21J2JumboInstruct,
+    AI21J2Mid,
+    AI21J2MidV1,
+    AI21J2Ultra,
+    AI21J2UltraV1_8k,
+    AI21J2UltraV1,
+    AI21JambaInstructV1,
+    AI21Jamba15LargeV1,
+    AI21Jamba15MiniV1,
+    // Cohere models
+    CohereCommandTextV14_4k,
+    CohereCommandRV1,
+    CohereCommandRPlusV1,
+    CohereCommandLightTextV14_4k,
+    // Meta models
+    MetaLlama38BInstructV1,
+    MetaLlama370BInstructV1,
+    MetaLlama318BInstructV1_128k,
+    MetaLlama318BInstructV1,
+    MetaLlama3170BInstructV1_128k,
+    MetaLlama3170BInstructV1,
+    MetaLlama3211BInstructV1,
+    MetaLlama3290BInstructV1,
+    MetaLlama321BInstructV1,
+    MetaLlama323BInstructV1,
+    // Mistral models
+    MistralMistral7BInstructV0,
+    MistralMixtral8x7BInstructV0,
+    MistralMistralLarge2402V1,
+    MistralMistralSmall2402V1,
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        max_tokens: usize,
+        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+        display_name: Option<String>,
+        max_output_tokens: Option<u32>,
+        default_temperature: Option<f32>,
+    },
+}
+
+impl Model {
+    pub fn from_id(id: &str) -> anyhow::Result<Self> {
+        if id.starts_with("claude-3-5-sonnet") {
+            Ok(Self::Claude3_5Sonnet)
+        } else if id.starts_with("claude-3-opus") {
+            Ok(Self::Claude3Opus)
+        } else if id.starts_with("claude-3-sonnet") {
+            Ok(Self::Claude3Sonnet)
+        } else if id.starts_with("claude-3-5-haiku") {
+            Ok(Self::Claude3_5Haiku)
+        } else {
+            Err(anyhow!("invalid model id"))
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        match self {
+            Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+            Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
+            Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
+            Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+            Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
+            Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
+            Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
+            Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
+            Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
+            Model::AI21J2Mid => "ai21.j2-mid",
+            Model::AI21J2MidV1 => "ai21.j2-mid-v1",
+            Model::AI21J2Ultra => "ai21.j2-ultra",
+            Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
+            Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
+            Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
+            Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
+            Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
+            Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
+            Model::CohereCommandRV1 => "cohere.command-r-v1:0",
+            Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
+            Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
+            Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
+            Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
+            Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
+            Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
+            Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
+            Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
+            Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
+            Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
+            Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
+            Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
+            Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
+            Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
+            Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
+            Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
+            Self::Custom { name, .. } => name,
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        match self {
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
+            Self::Claude3Opus => "Claude 3 Opus",
+            Self::Claude3Sonnet => "Claude 3 Sonnet",
+            Self::Claude3_5Haiku => "Claude 3.5 Haiku",
+            Self::AmazonNovaLite => "Amazon Nova Lite",
+            Self::AmazonNovaMicro => "Amazon Nova Micro",
+            Self::AmazonNovaPro => "Amazon Nova Pro",
+            Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
+            Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
+            Self::AI21J2Mid => "AI21 Jurassic2 Mid",
+            Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
+            Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
+            Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
+            Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
+            Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
+            Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
+            Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
+            Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
+            Self::CohereCommandRV1 => "Cohere Command R V1",
+            Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
+            Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
+            Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
+            Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
+            Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
+            Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
+            Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
+            Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
+            Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
+            Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
+            Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
+            Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
+            Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
+            Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
+            Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
+            Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
+            Self::Custom {
+                display_name, name, ..
+            } => display_name.as_deref().unwrap_or(name),
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3_5Haiku => 200_000,
+            Self::Custom { max_tokens, .. } => *max_tokens,
+            _ => 200_000,
+        }
+    }
+
+    pub fn max_output_tokens(&self) -> u32 {
+        match self {
+            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
+            Self::Claude3_5Sonnet => 8_192,
+            Self::Custom {
+                max_output_tokens, ..
+            } => max_output_tokens.unwrap_or(4_096),
+            _ => 4_096,
+        }
+    }
+
+    pub fn default_temperature(&self) -> f32 {
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3_5Haiku => 1.0,
+            Self::Custom {
+                default_temperature,
+                ..
+            } => default_temperature.unwrap_or(1.0),
+            _ => 1.0,
+        }
+    }
+}