From 821b8bb93fcb79d8ef077e022e28a1854f0ecb28 Mon Sep 17 00:00:00 2001 From: Jakub Konka Date: Tue, 31 Mar 2026 22:53:37 +0200 Subject: [PATCH] language_model: Create provider module --- crates/agent/src/native_agent_server.rs | 4 +- crates/agent/src/thread.rs | 2 +- crates/agent/src/tools/web_search_tool.rs | 2 +- crates/agent_ui/src/agent_configuration.rs | 2 +- crates/agent_ui/src/agent_panel.rs | 4 +- .../src/agent_api_keys_onboarding.rs | 2 +- .../src/agent_panel_onboarding_content.rs | 2 +- crates/language_model/src/language_model.rs | 190 +----------------- crates/language_model/src/provider.rs | 12 ++ .../language_model/src/provider/anthropic.rs | 80 ++++++++ crates/language_model/src/provider/google.rs | 5 + crates/language_model/src/provider/open_ai.rs | 28 +++ .../src/provider/open_router.rs | 69 +++++++ crates/language_model/src/provider/x_ai.rs | 4 + crates/language_model/src/provider/zed.rs | 5 + crates/language_model/src/registry.rs | 2 +- .../language_models/src/provider/anthropic.rs | 5 +- .../src/provider/anthropic/telemetry.rs | 2 +- crates/language_models/src/provider/cloud.rs | 35 ++-- crates/language_models/src/provider/google.rs | 5 +- .../language_models/src/provider/open_ai.rs | 5 +- 21 files changed, 250 insertions(+), 215 deletions(-) create mode 100644 crates/language_model/src/provider.rs create mode 100644 crates/language_model/src/provider/anthropic.rs create mode 100644 crates/language_model/src/provider/google.rs create mode 100644 crates/language_model/src/provider/open_ai.rs create mode 100644 crates/language_model/src/provider/open_router.rs create mode 100644 crates/language_model/src/provider/x_ai.rs create mode 100644 crates/language_model/src/provider/zed.rs diff --git a/crates/agent/src/native_agent_server.rs b/crates/agent/src/native_agent_server.rs index 7f19f9005e3ff54e361f57075b7af06508476564..88f41117fefff2d06091e1c0411398ca0e6c87f1 100644 --- a/crates/agent/src/native_agent_server.rs +++ b/crates/agent/src/native_agent_server.rs @@ -112,7 +112,7 @@ mod tests { prompt_store::init(cx); let registry = language_model::LanguageModelRegistry::read_global(cx); let auth = registry - .provider(&language_model::ANTHROPIC_PROVIDER_ID) + .provider(&language_model::provider::ANTHROPIC_PROVIDER_ID) .unwrap() .authenticate(cx); @@ -127,7 +127,7 @@ mod tests { registry.update(cx, |registry, cx| { registry.select_default_model( Some(&language_model::SelectedModel { - provider: language_model::ANTHROPIC_PROVIDER_ID, + provider: language_model::provider::ANTHROPIC_PROVIDER_ID, model: language_model::LanguageModelId("claude-sonnet-4-latest".into()), }), cx, diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index b61df1b8af84d312d7f186fb85e5a1d04ab59dfd..e979c329defa071b382eaf720b435a7ff31990a6 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -39,7 +39,7 @@ use language_model::{ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, Speed, StopReason, - TokenUsage, ZED_CLOUD_PROVIDER_ID, + TokenUsage, provider::ZED_CLOUD_PROVIDER_ID, }; use project::Project; use prompt_store::ProjectContext; diff --git a/crates/agent/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs index c697a5b78f1fe8c84d6ed58db13f651a493ae8c3..d265bee6f9c4140a24c07ef126ddf1ee3388c4b4 100644 --- a/crates/agent/src/tools/web_search_tool.rs +++ b/crates/agent/src/tools/web_search_tool.rs @@ -11,7 +11,7 @@ use cloud_llm_client::WebSearchResponse; use futures::FutureExt as _; use gpui::{App, Task}; use language_model::{ - LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, + LanguageModelProviderId, LanguageModelToolResultContent, provider::ZED_CLOUD_PROVIDER_ID, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index fda3cb9907b2f02cce29ff0ae8c4762e6efa625a..1bf6384cc3ff182645f39bb7a7ebd6dd902b6be4 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -23,7 +23,7 @@ use itertools::Itertools; use language::LanguageRegistry; use language_model::{ IconOrSvg, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, - ZED_CLOUD_PROVIDER_ID, + provider::ZED_CLOUD_PROVIDER_ID, }; use language_models::AllLanguageModelSettings; use notifications::status_toast::{StatusToast, ToastIcon}; diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index e6ef267a95110e745534010bae32b1b1fd6c0f0c..18083595f8acb25522a35593ca375a0c0ea04d22 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -3726,7 +3726,7 @@ impl AgentPanel { .read(cx) .default_model() .is_some_and(|model| { - model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + model.provider.id() != language_model::provider::ZED_CLOUD_PROVIDER_ID }) { return false; @@ -3767,7 +3767,7 @@ impl AgentPanel { .iter() .any(|provider| { provider.is_authenticated(cx) - && provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + && provider.id() != language_model::provider::ZED_CLOUD_PROVIDER_ID }); match &self.active_view { diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index 47197ec2331b97dd4d7561d9f14c91c7f91c9fa0..aa9488fd5ad6fc9085f2f0aa16cea07f65737845 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -1,5 +1,5 @@ use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; -use language_model::{IconOrSvg, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use language_model::{IconOrSvg, LanguageModelRegistry, provider::ZED_CLOUD_PROVIDER_ID}; use ui::{Divider, List, ListBulletItem, prelude::*}; pub struct ApiKeysWithProviders { diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index cc60a35e501329b0ca089e2f218ab1551ca35d93..4b6163d9f69776d448cad0e695bfe0a046a5b9c5 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use client::{Client, UserStore}; use cloud_api_types::Plan; use gpui::{Entity, IntoElement, ParentElement}; -use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use language_model::{LanguageModelRegistry, provider::ZED_CLOUD_PROVIDER_ID}; use ui::prelude::*; use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding}; diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 1fae41094e6a6176eeb18842841f4fd1ebb2a9eb..ae78bd300c5f857d368c3150c07c7a7aee4f6f55 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -1,5 +1,6 @@ mod api_key; mod model; +pub mod provider; mod rate_limiter; mod registry; mod request; @@ -9,7 +10,6 @@ pub mod tool_schema; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anthropic::{AnthropicError, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::Client; use client::UserStore; @@ -19,8 +19,8 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window}; use http_client::{StatusCode, http}; use icons::IconName; -use open_router::OpenRouterError; use parking_lot::Mutex; +use provider::parse_prompt_too_long; use serde::{Deserialize, Serialize}; pub use settings::LanguageModelCacheConfiguration; use std::ops::{Add, Sub}; @@ -40,26 +40,6 @@ pub use crate::role::*; pub use crate::tool_schema::LanguageModelToolSchemaFormat; pub use zed_env_vars::{EnvVar, env_var}; -pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = - LanguageModelProviderId::new("anthropic"); -pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Anthropic"); - -pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google"); -pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Google AI"); - -pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); -pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("OpenAI"); - -pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); -pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); - -pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); -pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Zed"); - pub fn init(user_store: Entity, client: Arc, cx: &mut App) { init_settings(cx); RefreshLlmTokenListener::register(client, user_store, cx); @@ -266,7 +246,12 @@ impl LanguageModelCompletionError { .strip_prefix("http_") .and_then(|code| StatusCode::from_str(code).ok()) { - Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) + Self::from_http_status( + provider::ZED_CLOUD_PROVIDER_NAME, + status_code, + message, + retry_after, + ) } else { anyhow!("completion request failed, code: {code}, message: {message}").into() } @@ -308,165 +293,6 @@ impl LanguageModelCompletionError { } } -impl From for LanguageModelCompletionError { - fn from(error: AnthropicError) -> Self { - let provider = ANTHROPIC_PROVIDER_NAME; - match error { - AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, - AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, - AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, - AnthropicError::DeserializeResponse(error) => { - Self::DeserializeResponse { provider, error } - } - AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, - AnthropicError::HttpResponseError { - status_code, - message, - } => Self::HttpResponseError { - provider, - status_code, - message, - }, - AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { - provider, - retry_after: Some(retry_after), - }, - AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { - provider, - retry_after, - }, - AnthropicError::ApiError(api_error) => api_error.into(), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: anthropic::ApiError) -> Self { - use anthropic::ApiErrorCode::*; - let provider = ANTHROPIC_PROVIDER_NAME; - match error.code() { - Some(code) => match code { - InvalidRequestError => Self::BadRequestFormat { - provider, - message: error.message, - }, - AuthenticationError => Self::AuthenticationError { - provider, - message: error.message, - }, - PermissionError => Self::PermissionError { - provider, - message: error.message, - }, - NotFoundError => Self::ApiEndpointNotFound { provider }, - RequestTooLarge => Self::PromptTooLarge { - tokens: parse_prompt_too_long(&error.message), - }, - RateLimitError => Self::RateLimitExceeded { - provider, - retry_after: None, - }, - ApiError => Self::ApiInternalServerError { - provider, - message: error.message, - }, - OverloadedError => Self::ServerOverloaded { - provider, - retry_after: None, - }, - }, - None => Self::Other(error.into()), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: open_ai::RequestError) -> Self { - match error { - open_ai::RequestError::HttpResponseError { - provider, - status_code, - body, - headers, - } => { - let retry_after = headers - .get(http::header::RETRY_AFTER) - .and_then(|val| val.to_str().ok()?.parse::().ok()) - .map(Duration::from_secs); - - Self::from_http_status(provider.into(), status_code, body, retry_after) - } - open_ai::RequestError::Other(e) => Self::Other(e), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: OpenRouterError) -> Self { - let provider = LanguageModelProviderName::new("OpenRouter"); - match error { - OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, - OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, - OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, - OpenRouterError::DeserializeResponse(error) => { - Self::DeserializeResponse { provider, error } - } - OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, - OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { - provider, - retry_after: Some(retry_after), - }, - OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { - provider, - retry_after, - }, - OpenRouterError::ApiError(api_error) => api_error.into(), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: open_router::ApiError) -> Self { - use open_router::ApiErrorCode::*; - let provider = LanguageModelProviderName::new("OpenRouter"); - match error.code { - InvalidRequestError => Self::BadRequestFormat { - provider, - message: error.message, - }, - AuthenticationError => Self::AuthenticationError { - provider, - message: error.message, - }, - PaymentRequiredError => Self::AuthenticationError { - provider, - message: format!("Payment required: {}", error.message), - }, - PermissionError => Self::PermissionError { - provider, - message: error.message, - }, - RequestTimedOut => Self::HttpResponseError { - provider, - status_code: StatusCode::REQUEST_TIMEOUT, - message: error.message, - }, - RateLimitError => Self::RateLimitExceeded { - provider, - retry_after: None, - }, - ApiError => Self::ApiInternalServerError { - provider, - message: error.message, - }, - OverloadedError => Self::ServerOverloaded { - provider, - retry_after: None, - }, - } - } -} - #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum StopReason { diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..707d8e2d618894e2898e253450dbfbb5e9483bba --- /dev/null +++ b/crates/language_model/src/provider.rs @@ -0,0 +1,12 @@ +pub mod anthropic; +pub mod google; +pub mod open_ai; +pub mod open_router; +pub mod x_ai; +pub mod zed; + +pub use anthropic::*; +pub use google::*; +pub use open_ai::*; +pub use x_ai::*; +pub use zed::*; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs new file mode 100644 index 0000000000000000000000000000000000000000..0878be2070fdbb9e57145684f59c962a32bb9fd2 --- /dev/null +++ b/crates/language_model/src/provider/anthropic.rs @@ -0,0 +1,80 @@ +use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName}; +use anthropic::AnthropicError; +pub use anthropic::parse_prompt_too_long; + +pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = + LanguageModelProviderId::new("anthropic"); +pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Anthropic"); + +impl From for LanguageModelCompletionError { + fn from(error: AnthropicError) -> Self { + let provider = ANTHROPIC_PROVIDER_NAME; + match error { + AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, + AnthropicError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + AnthropicError::HttpResponseError { + status_code, + message, + } => Self::HttpResponseError { + provider, + status_code, + message, + }, + AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + AnthropicError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From for LanguageModelCompletionError { + fn from(error: anthropic::ApiError) -> Self { + use anthropic::ApiErrorCode::*; + let provider = ANTHROPIC_PROVIDER_NAME; + match error.code() { + Some(code) => match code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + NotFoundError => Self::ApiEndpointNotFound { provider }, + RequestTooLarge => Self::PromptTooLarge { + tokens: parse_prompt_too_long(&error.message), + }, + RateLimitError => Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + }, + None => Self::Other(error.into()), + } + } +} diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs new file mode 100644 index 0000000000000000000000000000000000000000..1caee496b519f395dd10744b127bc29ee893849f --- /dev/null +++ b/crates/language_model/src/provider/google.rs @@ -0,0 +1,5 @@ +use crate::{LanguageModelProviderId, LanguageModelProviderName}; + +pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google"); +pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Google AI"); diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs new file mode 100644 index 0000000000000000000000000000000000000000..3796eb9a3aef78628c52d92e92fabb3812249e04 --- /dev/null +++ b/crates/language_model/src/provider/open_ai.rs @@ -0,0 +1,28 @@ +use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName}; +use http_client::http; +use std::time::Duration; + +pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); +pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("OpenAI"); + +impl From for LanguageModelCompletionError { + fn from(error: open_ai::RequestError) -> Self { + match error { + open_ai::RequestError::HttpResponseError { + provider, + status_code, + body, + headers, + } => { + let retry_after = headers + .get(http::header::RETRY_AFTER) + .and_then(|val| val.to_str().ok()?.parse::().ok()) + .map(Duration::from_secs); + + Self::from_http_status(provider.into(), status_code, body, retry_after) + } + open_ai::RequestError::Other(e) => Self::Other(e), + } + } +} diff --git a/crates/language_model/src/provider/open_router.rs b/crates/language_model/src/provider/open_router.rs new file mode 100644 index 0000000000000000000000000000000000000000..809e22f1fec0f2d205caa3ebbcb0baaf129b062c --- /dev/null +++ b/crates/language_model/src/provider/open_router.rs @@ -0,0 +1,69 @@ +use crate::{LanguageModelCompletionError, LanguageModelProviderName}; +use http_client::StatusCode; +use open_router::OpenRouterError; + +impl From for LanguageModelCompletionError { + fn from(error: OpenRouterError) -> Self { + let provider = LanguageModelProviderName::new("OpenRouter"); + match error { + OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, + OpenRouterError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + OpenRouterError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From for LanguageModelCompletionError { + fn from(error: open_router::ApiError) -> Self { + use open_router::ApiErrorCode::*; + let provider = LanguageModelProviderName::new("OpenRouter"); + match error.code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PaymentRequiredError => Self::AuthenticationError { + provider, + message: format!("Payment required: {}", error.message), + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + RequestTimedOut => Self::HttpResponseError { + provider, + status_code: StatusCode::REQUEST_TIMEOUT, + message: error.message, + }, + RateLimitError => Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + } + } +} diff --git a/crates/language_model/src/provider/x_ai.rs b/crates/language_model/src/provider/x_ai.rs new file mode 100644 index 0000000000000000000000000000000000000000..3d0f794fa4087a4beeb4a9b6253d016a9b592f0e --- /dev/null +++ b/crates/language_model/src/provider/x_ai.rs @@ -0,0 +1,4 @@ +use crate::{LanguageModelProviderId, LanguageModelProviderName}; + +pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); diff --git a/crates/language_model/src/provider/zed.rs b/crates/language_model/src/provider/zed.rs new file mode 100644 index 0000000000000000000000000000000000000000..0ba793e99aad1caa25f049a96faf02c16e8970fa --- /dev/null +++ b/crates/language_model/src/provider/zed.rs @@ -0,0 +1,5 @@ +use crate::{LanguageModelProviderId, LanguageModelProviderName}; + +pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); +pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Zed"); diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index cf7718f7b102010cc0c8a981a0425583436176b7..bf14fbb0b5804505b33074e6e4cbcc36ddf21fab 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -101,7 +101,7 @@ impl ConfiguredModel { } pub fn is_provided_by_zed(&self) -> bool { - self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID + self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID } } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index f9ec204e5a65dd31f8b0280e91d4beb4004d29c8..7d05202c4d180c1082f06bc91e2d714f30197371 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -16,6 +16,7 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, env_var, + provider::{ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME}, }; use settings::{Settings, SettingsStore}; use std::pin::Pin; @@ -30,8 +31,8 @@ use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; pub use settings::AnthropicAvailableModel as AvailableModel; -const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; -const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME; +const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = ANTHROPIC_PROVIDER_NAME; #[derive(Default, Clone, Debug, PartialEq)] pub struct AnthropicSettings { diff --git a/crates/language_models/src/provider/anthropic/telemetry.rs b/crates/language_models/src/provider/anthropic/telemetry.rs index 75fb11a81b479635ea02db77a2df8a769e795e01..6c017d7c65a8da47c9344f240aa0c2bf38a4bac5 100644 --- a/crates/language_models/src/provider/anthropic/telemetry.rs +++ b/crates/language_models/src/provider/anthropic/telemetry.rs @@ -2,7 +2,7 @@ use anthropic::ANTHROPIC_API_URL; use anyhow::{Context as _, anyhow}; use gpui::BackgroundExecutor; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; -use language_model::{ANTHROPIC_PROVIDER_ID, LanguageModel}; +use language_model::{LanguageModel, provider::ANTHROPIC_PROVIDER_ID}; use std::env; use std::sync::Arc; use util::ResultExt; diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 161ee6e9abd5283dfbe10c4e7c9dc5597fc4b5b9..a451a5c27973e643c16939f39dd42d6a9a17773b 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -25,6 +25,11 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, + provider::{ + ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, + OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, + ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, + }, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -53,8 +58,8 @@ use crate::provider::open_ai::{ }; use crate::provider::x_ai::count_xai_tokens; -const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID; -const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME; +const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { @@ -568,20 +573,20 @@ impl LanguageModel for CloudLanguageModel { fn upstream_provider_id(&self) -> LanguageModelProviderId { use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { - Anthropic => language_model::ANTHROPIC_PROVIDER_ID, - OpenAi => language_model::OPEN_AI_PROVIDER_ID, - Google => language_model::GOOGLE_PROVIDER_ID, - XAi => language_model::X_AI_PROVIDER_ID, + Anthropic => ANTHROPIC_PROVIDER_ID, + OpenAi => OPEN_AI_PROVIDER_ID, + Google => GOOGLE_PROVIDER_ID, + XAi => X_AI_PROVIDER_ID, } } fn upstream_provider_name(&self) -> LanguageModelProviderName { use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { - Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, - OpenAi => language_model::OPEN_AI_PROVIDER_NAME, - Google => language_model::GOOGLE_PROVIDER_NAME, - XAi => language_model::X_AI_PROVIDER_NAME, + Anthropic => ANTHROPIC_PROVIDER_NAME, + OpenAi => OPEN_AI_PROVIDER_NAME, + Google => GOOGLE_PROVIDER_NAME, + XAi => X_AI_PROVIDER_NAME, } } @@ -1047,12 +1052,10 @@ where fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName { match provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - language_model::ANTHROPIC_PROVIDER_NAME - } - cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME, } } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 334a5cbe64e6cdefbaa7c15c309ca4632109e323..ea2134234995273fa574b08052390c4a9f07eacc 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -16,6 +16,7 @@ use language_model::{ IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + provider::{GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME}, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -33,8 +34,8 @@ use util::ResultExt; use language_model::ApiKeyState; -const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; -const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; +const PROVIDER_ID: LanguageModelProviderId = GOOGLE_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = GOOGLE_PROVIDER_NAME; #[derive(Default, Clone, Debug, PartialEq)] pub struct GoogleSettings { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 8de1eaaf8465cf48838c49f6b24d3eb16d6e3487..266ef69e3d6fd9ebd22ad6997d8679c758509907 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -11,6 +11,7 @@ use language_model::{ LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var, + provider::{OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME}, }; use menu; use open_ai::responses::{ @@ -35,8 +36,8 @@ use util::ResultExt; use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; -const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; -const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; +const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME; const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY"; static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME);