Detailed changes
@@ -1269,6 +1269,30 @@ dependencies = [
"uuid",
]
+[[package]]
+name = "aws-sdk-bedrockruntime"
+version = "1.74.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6938541d1948a543bca23303fec4cff9c36bf0e63b8fa3ae1b337bcb9d5b81af"
+dependencies = [
+ "aws-credential-types",
+ "aws-runtime",
+ "aws-smithy-async",
+ "aws-smithy-eventstream",
+ "aws-smithy-http",
+ "aws-smithy-json",
+ "aws-smithy-runtime",
+ "aws-smithy-runtime-api",
+ "aws-smithy-types",
+ "aws-types",
+ "bytes 1.10.0",
+ "fastrand 2.3.0",
+ "http 0.2.12",
+ "once_cell",
+ "regex-lite",
+ "tracing",
+]
+
[[package]]
name = "aws-sdk-kinesis"
version = "1.61.0"
@@ -1598,6 +1622,17 @@ dependencies = [
"tracing",
]
+[[package]]
+name = "aws_http_client"
+version = "0.1.0"
+dependencies = [
+ "aws-smithy-runtime-api",
+ "aws-smithy-types",
+ "futures 0.3.31",
+ "http_client",
+ "tokio",
+]
+
[[package]]
name = "axum"
version = "0.6.20"
@@ -1727,6 +1762,22 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
+[[package]]
+name = "bedrock"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "aws-sdk-bedrockruntime",
+ "aws-smithy-types",
+ "futures 0.3.31",
+ "schemars",
+ "serde",
+ "serde_json",
+ "strum",
+ "thiserror 1.0.69",
+ "tokio",
+]
+
[[package]]
name = "bigdecimal"
version = "0.4.7"
@@ -15,6 +15,8 @@ members = [
"crates/audio",
"crates/auto_update",
"crates/auto_update_ui",
+ "crates/aws_http_client",
+ "crates/bedrock",
"crates/breadcrumbs",
"crates/buffer_diff",
"crates/call",
@@ -218,6 +220,8 @@ assistant_tools = { path = "crates/assistant_tools" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
auto_update_ui = { path = "crates/auto_update_ui" }
+aws_http_client = { path = "crates/aws_http_client" }
+bedrock = { path = "crates/bedrock" }
breadcrumbs = { path = "crates/breadcrumbs" }
call = { path = "crates/call" }
channel = { path = "crates/channel" }
@@ -382,6 +386,11 @@ async-trait = "0.1"
async-tungstenite = "0.28"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
+aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
+aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
+aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] }
+aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] }
+aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] }
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
@@ -0,0 +1,22 @@
+[package]
+name = "aws_http_client"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/aws_http_client.rs"
+
+[features]
+default = []
+
+[dependencies]
+aws-smithy-runtime-api.workspace = true
+aws-smithy-types.workspace = true
+futures.workspace = true
+http_client.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,118 @@
+use std::fmt;
+use std::sync::Arc;
+
+use aws_smithy_runtime_api::client::http::{
+ HttpClient as AwsClient, HttpConnector as AwsConnector,
+ HttpConnectorFuture as AwsConnectorFuture, HttpConnectorFuture, HttpConnectorSettings,
+ SharedHttpConnector,
+};
+use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsHttpRequest, HttpResponse};
+use aws_smithy_runtime_api::client::result::ConnectorError;
+use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
+use aws_smithy_runtime_api::http::StatusCode;
+use aws_smithy_types::body::SdkBody;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, Inner};
+use http_client::{HttpClient, Request};
+use tokio::runtime::Handle;
+
+struct AwsHttpConnector {
+ client: Arc<dyn HttpClient>,
+ handle: Handle,
+}
+
+impl std::fmt::Debug for AwsHttpConnector {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("AwsHttpConnector").finish()
+ }
+}
+
+impl AwsConnector for AwsHttpConnector {
+ fn call(&self, request: AwsHttpRequest) -> AwsConnectorFuture {
+ let req = match request.try_into_http1x() {
+ Ok(req) => req,
+ Err(err) => {
+ return HttpConnectorFuture::ready(Err(ConnectorError::other(err.into(), None)))
+ }
+ };
+
+ let (parts, body) = req.into_parts();
+
+ let response = self
+ .client
+ .send(Request::from_parts(parts, convert_to_async_body(body)));
+
+ let handle = self.handle.clone();
+
+ HttpConnectorFuture::new(async move {
+ let response = match response.await {
+ Ok(response) => response,
+ Err(err) => return Err(ConnectorError::other(err.into(), None)),
+ };
+ let (parts, body) = response.into_parts();
+ let body = convert_to_sdk_body(body, handle).await;
+
+ Ok(HttpResponse::new(
+ StatusCode::try_from(parts.status.as_u16()).unwrap(),
+ body,
+ ))
+ })
+ }
+}
+
+#[derive(Clone)]
+pub struct AwsHttpClient {
+ client: Arc<dyn HttpClient>,
+ handler: Handle,
+}
+
+impl std::fmt::Debug for AwsHttpClient {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("AwsHttpClient").finish()
+ }
+}
+
+impl AwsHttpClient {
+ pub fn new(client: Arc<dyn HttpClient>, handle: Handle) -> Self {
+ Self {
+ client,
+ handler: handle,
+ }
+ }
+}
+
+impl AwsClient for AwsHttpClient {
+ fn http_connector(
+ &self,
+ _settings: &HttpConnectorSettings,
+ _components: &RuntimeComponents,
+ ) -> SharedHttpConnector {
+ SharedHttpConnector::new(AwsHttpConnector {
+ client: self.client.clone(),
+ handle: self.handler.clone(),
+ })
+ }
+}
+
+pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody {
+ match body.0 {
+ Inner::Empty => SdkBody::empty(),
+ Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()),
+ Inner::AsyncReader(mut reader) => {
+ let buffer = handle.spawn(async move {
+ let mut buffer = Vec::new();
+ let _ = reader.read_to_end(&mut buffer).await;
+ buffer
+ });
+
+ SdkBody::from(buffer.await.unwrap_or_default())
+ }
+ }
+}
+
+pub fn convert_to_async_body(body: SdkBody) -> AsyncBody {
+ match body.bytes() {
+ Some(bytes) => AsyncBody::from((*bytes).to_vec()),
+ None => AsyncBody::empty(),
+ }
+}
@@ -0,0 +1,28 @@
+[package]
+name = "bedrock"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/bedrock.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] }
+aws-smithy-types = {workspace = true}
+futures.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+strum.workspace = true
+thiserror.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,166 @@
+mod models;
+
+use std::pin::Pin;
+
+use anyhow::{anyhow, Context, Error, Result};
+use aws_sdk_bedrockruntime as bedrock;
+pub use aws_sdk_bedrockruntime as bedrock_client;
+pub use aws_sdk_bedrockruntime::types::{
+ ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
+ ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
+ ToolSpecification as BedrockTool,
+};
+use aws_smithy_types::{Document, Number as AwsNumber};
+pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
+pub use bedrock::types::{
+ ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
+ ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
+ Message as BedrockMessage, ResponseStream as BedrockResponseStream,
+};
+use futures::stream::{self, BoxStream, Stream};
+use serde::{Deserialize, Serialize};
+use serde_json::{Number, Value};
+use thiserror::Error;
+
+pub use crate::models::*;
+
+pub async fn complete(
+ client: &bedrock::Client,
+ request: Request,
+) -> Result<BedrockResponse, BedrockError> {
+ let response = bedrock::Client::converse(client)
+ .model_id(request.model.clone())
+ .set_messages(request.messages.into())
+ .send()
+ .await
+ .context("failed to send request to Bedrock");
+
+ match response {
+ Ok(output) => output
+ .output
+ .ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
+ Err(err) => Err(BedrockError::Other(err)),
+ }
+}
+
+pub async fn stream_completion(
+ client: bedrock::Client,
+ request: Request,
+ handle: tokio::runtime::Handle,
+) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
+ handle
+ .spawn(async move {
+ let response = bedrock::Client::converse_stream(&client)
+ .model_id(request.model.clone())
+ .set_messages(request.messages.into())
+ .send()
+ .await;
+
+ match response {
+ Ok(output) => {
+ let stream: Pin<
+ Box<
+ dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
+ + Send,
+ >,
+ > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
+ match stream.recv().await {
+ Ok(Some(output)) => Some((Ok(output), stream)),
+ Ok(None) => None,
+ Err(err) => {
+ Some((
+ // TODO: Figure out how we can capture Throttling Exceptions
+ Err(BedrockError::ClientError(anyhow!(
+ "{:?}",
+ aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+ ))),
+ stream,
+ ))
+ }
+ }
+ }));
+ Ok(stream)
+ }
+ Err(err) => Err(anyhow!(
+ "{:?}",
+ aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
+ )),
+ }
+ })
+ .await
+ .map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
+}
+
+pub fn aws_document_to_value(document: &Document) -> Value {
+ match document {
+ Document::Null => Value::Null,
+ Document::Bool(value) => Value::Bool(*value),
+ Document::Number(value) => match *value {
+ AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
+ AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
+ AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
+ },
+ Document::String(value) => Value::String(value.clone()),
+ Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
+ Document::Object(map) => Value::Object(
+ map.iter()
+ .map(|(key, value)| (key.clone(), aws_document_to_value(value)))
+ .collect(),
+ ),
+ }
+}
+
+pub fn value_to_aws_document(value: &Value) -> Document {
+ match value {
+ Value::Null => Document::Null,
+ Value::Bool(value) => Document::Bool(*value),
+ Value::Number(value) => {
+ if let Some(value) = value.as_u64() {
+ Document::Number(AwsNumber::PosInt(value))
+ } else if let Some(value) = value.as_i64() {
+ Document::Number(AwsNumber::NegInt(value))
+ } else if let Some(value) = value.as_f64() {
+ Document::Number(AwsNumber::Float(value))
+ } else {
+ Document::Null
+ }
+ }
+ Value::String(value) => Document::String(value.clone()),
+ Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
+ Value::Object(map) => Document::Object(
+ map.iter()
+ .map(|(key, value)| (key.clone(), value_to_aws_document(value)))
+ .collect(),
+ ),
+ }
+}
+
+#[derive(Debug)]
+pub struct Request {
+ pub model: String,
+ pub max_tokens: u32,
+ pub messages: Vec<BedrockMessage>,
+ pub tools: Vec<BedrockTool>,
+ pub tool_choice: Option<BedrockToolChoice>,
+ pub system: Option<String>,
+ pub metadata: Option<Metadata>,
+ pub stop_sequences: Vec<String>,
+ pub temperature: Option<f32>,
+ pub top_k: Option<u32>,
+ pub top_p: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+ pub user_id: Option<String>,
+}
+
+#[derive(Error, Debug)]
+pub enum BedrockError {
+ #[error("client error: {0}")]
+ ClientError(anyhow::Error),
+ #[error("extension error: {0}")]
+ ExtensionError(anyhow::Error),
+ #[error(transparent)]
+ Other(#[from] anyhow::Error),
+}
@@ -0,0 +1,199 @@
+use anyhow::anyhow;
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
+pub enum Model {
+ // Anthropic models (already included)
+ #[default]
+ #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
+ Claude3_5Sonnet,
+ #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
+ Claude3Opus,
+ #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
+ Claude3Sonnet,
+ #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
+ Claude3_5Haiku,
+ // Amazon Nova Models
+ AmazonNovaLite,
+ AmazonNovaMicro,
+ AmazonNovaPro,
+ // AI21 models
+ AI21J2GrandeInstruct,
+ AI21J2JumboInstruct,
+ AI21J2Mid,
+ AI21J2MidV1,
+ AI21J2Ultra,
+ AI21J2UltraV1_8k,
+ AI21J2UltraV1,
+ AI21JambaInstructV1,
+ AI21Jamba15LargeV1,
+ AI21Jamba15MiniV1,
+ // Cohere models
+ CohereCommandTextV14_4k,
+ CohereCommandRV1,
+ CohereCommandRPlusV1,
+ CohereCommandLightTextV14_4k,
+ // Meta models
+ MetaLlama38BInstructV1,
+ MetaLlama370BInstructV1,
+ MetaLlama318BInstructV1_128k,
+ MetaLlama318BInstructV1,
+ MetaLlama3170BInstructV1_128k,
+ MetaLlama3170BInstructV1,
+ MetaLlama3211BInstructV1,
+ MetaLlama3290BInstructV1,
+ MetaLlama321BInstructV1,
+ MetaLlama323BInstructV1,
+ // Mistral models
+ MistralMistral7BInstructV0,
+ MistralMixtral8x7BInstructV0,
+ MistralMistralLarge2402V1,
+ MistralMistralSmall2402V1,
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ max_tokens: usize,
+ /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+ display_name: Option<String>,
+ max_output_tokens: Option<u32>,
+ default_temperature: Option<f32>,
+ },
+}
+
+impl Model {
+ pub fn from_id(id: &str) -> anyhow::Result<Self> {
+ if id.starts_with("claude-3-5-sonnet") {
+ Ok(Self::Claude3_5Sonnet)
+ } else if id.starts_with("claude-3-opus") {
+ Ok(Self::Claude3Opus)
+ } else if id.starts_with("claude-3-sonnet") {
+ Ok(Self::Claude3Sonnet)
+ } else if id.starts_with("claude-3-5-haiku") {
+ Ok(Self::Claude3_5Haiku)
+ } else {
+ Err(anyhow!("invalid model id"))
+ }
+ }
+
+ pub fn id(&self) -> &str {
+ match self {
+ Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+ Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
+ Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
+ Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
+ Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
+ Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
+ Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
+ Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
+ Model::AI21J2Mid => "ai21.j2-mid",
+ Model::AI21J2MidV1 => "ai21.j2-mid-v1",
+ Model::AI21J2Ultra => "ai21.j2-ultra",
+ Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
+ Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
+ Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
+ Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
+ Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
+ Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
+ Model::CohereCommandRV1 => "cohere.command-r-v1:0",
+ Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
+ Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
+ Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
+ Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
+ Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
+ Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
+ Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
+ Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
+ Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
+ Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
+ Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
+ Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
+ Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
+ Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
+ Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
+ Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
+ Self::Custom { name, .. } => name,
+ }
+ }
+
+ pub fn display_name(&self) -> &str {
+ match self {
+ Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
+ Self::Claude3Opus => "Claude 3 Opus",
+ Self::Claude3Sonnet => "Claude 3 Sonnet",
+ Self::Claude3_5Haiku => "Claude 3.5 Haiku",
+ Self::AmazonNovaLite => "Amazon Nova Lite",
+ Self::AmazonNovaMicro => "Amazon Nova Micro",
+ Self::AmazonNovaPro => "Amazon Nova Pro",
+ Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
+ Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
+ Self::AI21J2Mid => "AI21 Jurassic2 Mid",
+ Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
+ Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
+ Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
+ Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
+ Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
+ Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
+ Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
+ Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
+ Self::CohereCommandRV1 => "Cohere Command R V1",
+ Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
+ Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
+ Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
+ Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
+ Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
+ Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
+ Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
+ Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
+ Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
+ Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
+ Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
+ Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
+ Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
+ Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
+ Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
+ Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
+ Self::Custom {
+ display_name, name, ..
+ } => display_name.as_deref().unwrap_or(name),
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ match self {
+ Self::Claude3_5Sonnet
+ | Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3_5Haiku => 200_000,
+ Self::Custom { max_tokens, .. } => *max_tokens,
+ _ => 200_000,
+ }
+ }
+
+ pub fn max_output_tokens(&self) -> u32 {
+ match self {
+ Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
+ Self::Claude3_5Sonnet => 8_192,
+ Self::Custom {
+ max_output_tokens, ..
+ } => max_output_tokens.unwrap_or(4_096),
+ _ => 4_096,
+ }
+ }
+
+ pub fn default_temperature(&self) -> f32 {
+ match self {
+ Self::Claude3_5Sonnet
+ | Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3_5Haiku => 1.0,
+ Self::Custom {
+ default_temperature,
+ ..
+ } => default_temperature.unwrap_or(1.0),
+ _ => 1.0,
+ }
+ }
+}