@@ -1,8 +1,6 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls};
-use collections::BTreeMap;
-use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
@@ -11,7 +9,7 @@ use gpui::{
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
- AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
@@ -26,45 +24,30 @@ use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
-use settings::{Settings, SettingsStore};
+use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
-use std::{
- sync::{Arc, LazyLock},
- time::Duration,
-};
-use strum::IntoEnumIterator;
+use std::sync::Arc;
+use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
+use util::{ResultExt as _, maybe};
use zed_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
- MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
- SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
- ZED_VERSION_HEADER_NAME,
+ ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+ SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+ TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
-use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_NAME: &str = "Zed";
-const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
- option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
-
-fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
- static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
- ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
- .map(|json| serde_json::from_str(json).unwrap())
- .unwrap_or_default()
- });
- ADDITIONAL_MODELS.as_slice()
-}
-
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
pub available_models: Vec<AvailableModel>,
@@ -137,6 +120,11 @@ pub struct State {
user_store: Entity<UserStore>,
status: client::Status,
accept_terms: Option<Task<Result<()>>>,
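+    /// Models fetched from the Zed LLM API, including the "-thinking" variants synthesized on the client.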
+ models: Vec<Arc<zed_llm_client::LanguageModel>>,
+ default_model: Option<Arc<zed_llm_client::LanguageModel>>,
+ default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
+ recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
+ _fetch_models_task: Task<()>,
_settings_subscription: Subscription,
_llm_token_subscription: Subscription,
}
@@ -156,6 +144,72 @@ impl State {
user_store,
status,
accept_terms: None,
+ models: Vec::new(),
+ default_model: None,
+ default_fast_model: None,
+ recommended_models: Vec::new(),
+ _fetch_models_task: cx.spawn(async move |this, cx| {
+ maybe!(async move {
+ let (client, llm_api_token) = this
+ .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
+
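+                    // Wait until the client is connected before fetching models.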
+ loop {
+ let status = this.read_with(cx, |this, _cx| this.status)?;
+ if matches!(status, client::Status::Connected { .. }) {
+ break;
+ }
+
+ cx.background_executor()
+ .timer(Duration::from_millis(100))
+ .await;
+ }
+
+ let response = Self::fetch_models(client, llm_api_token).await?;
+ cx.update(|cx| {
+ this.update(cx, |this, cx| {
+ let mut models = Vec::new();
+
+ for model in response.models {
+ models.push(Arc::new(model.clone()));
+
+ // Right now we represent thinking variants of models as separate models on the client,
+ // so we need to insert variants for any model that supports thinking.
+ if model.supports_thinking {
+ models.push(Arc::new(zed_llm_client::LanguageModel {
+ id: zed_llm_client::LanguageModelId(
+ format!("{}-thinking", model.id).into(),
+ ),
+ display_name: format!("{} Thinking", model.display_name),
+ ..model
+                                    }));
+                                }
+                            }
+
+                            this.default_model = models
+                                .iter()
+                                .find(|model| model.id == response.default_model)
+                                .cloned();
+                            this.default_fast_model = models
+                                .iter()
+                                .find(|model| model.id == response.default_fast_model)
+                                .cloned();
+                            this.recommended_models = response
+                                .recommended_models
+                                .iter()
+                                .filter_map(|id| models.iter().find(|model| &model.id == id))
+                                .cloned()
+                                .collect();
+                            this.models = models;
+                            cx.notify();
+                        })
+                    })??;
+
+ anyhow::Ok(())
+ })
+ .await
+ .context("failed to fetch Zed models")
+ .log_err();
+ }),
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
@@ -208,6 +262,37 @@ impl State {
})
}));
}
+
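+    /// Fetches the list of available language models from the Zed LLM API.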
+ async fn fetch_models(
+ client: Arc<Client>,
+ llm_api_token: LlmApiToken,
+ ) -> Result<ListModelsResponse> {
+ let http_client = &client.http_client();
+ let token = llm_api_token.acquire(&client).await?;
+
+ let request = http_client::Request::builder()
+ .method(Method::GET)
+ .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
+ .header("Authorization", format!("Bearer {token}"))
+ .body(AsyncBody::empty())?;
+ let mut response = http_client
+ .send(request)
+ .await
+ .context("failed to send list models request")?;
+
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        if response.status().is_success() {
+            Ok(serde_json::from_str(&body)?)
+        } else {
+            anyhow::bail!(
+                "error listing models.\nStatus: {:?}\nBody: {body}",
+                response.status(),
+            );
+        }
+ }
}
impl CloudLanguageModelProvider {
@@ -242,11 +327,11 @@ impl CloudLanguageModelProvider {
fn create_language_model(
&self,
- model: CloudModel,
+ model: Arc<zed_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
) -> Arc<dyn LanguageModel> {
Arc::new(CloudLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
+ id: LanguageModelId(SharedString::from(model.id.0.clone())),
model,
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
@@ -277,121 +362,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ let default_model = self.state.read(cx).default_model.clone()?;
let llm_api_token = self.state.read(cx).llm_api_token.clone();
- let model = CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4);
- Some(self.create_language_model(model, llm_api_token))
+ Some(self.create_language_model(default_model, llm_api_token))
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
let llm_api_token = self.state.read(cx).llm_api_token.clone();
- let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
- Some(self.create_language_model(model, llm_api_token))
+ Some(self.create_language_model(default_fast_model, llm_api_token))
}
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let llm_api_token = self.state.read(cx).llm_api_token.clone();
- [
- CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
- CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
- ]
- .into_iter()
- .map(|model| self.create_language_model(model, llm_api_token.clone()))
- .collect()
+ self.state
+ .read(cx)
+ .recommended_models
+ .iter()
+ .cloned()
+ .map(|model| self.create_language_model(model, llm_api_token.clone()))
+ .collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let mut models = BTreeMap::default();
-
- if cx.is_staff() {
- for model in anthropic::Model::iter() {
- if !matches!(model, anthropic::Model::Custom { .. }) {
- models.insert(model.id().to_string(), CloudModel::Anthropic(model));
- }
- }
- for model in open_ai::Model::iter() {
- if !matches!(model, open_ai::Model::Custom { .. }) {
- models.insert(model.id().to_string(), CloudModel::OpenAi(model));
- }
- }
- for model in google_ai::Model::iter() {
- if !matches!(model, google_ai::Model::Custom { .. }) {
- models.insert(model.id().to_string(), CloudModel::Google(model));
- }
- }
- } else {
- models.insert(
- anthropic::Model::Claude3_5Sonnet.id().to_string(),
- CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
- );
- models.insert(
- anthropic::Model::Claude3_7Sonnet.id().to_string(),
- CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
- );
- models.insert(
- anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
- CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
- );
- models.insert(
- anthropic::Model::ClaudeSonnet4.id().to_string(),
- CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
- );
- models.insert(
- anthropic::Model::ClaudeSonnet4Thinking.id().to_string(),
- CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
- );
- }
-
- let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
- zed_cloud_provider_additional_models()
- } else {
- &[]
- };
-
- // Override with available models from settings
- for model in AllLanguageModelSettings::get_global(cx)
- .zed_dot_dev
- .available_models
+ let llm_api_token = self.state.read(cx).llm_api_token.clone();
+ self.state
+ .read(cx)
+ .models
.iter()
- .chain(llm_closed_beta_models)
.cloned()
- {
- let model = match model.provider {
- AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
- name: model.name.clone(),
- display_name: model.display_name.clone(),
- max_tokens: model.max_tokens,
- tool_override: model.tool_override.clone(),
- cache_configuration: model.cache_configuration.as_ref().map(|config| {
- anthropic::AnthropicModelCacheConfiguration {
- max_cache_anchors: config.max_cache_anchors,
- should_speculate: config.should_speculate,
- min_total_token: config.min_total_token,
- }
- }),
- default_temperature: model.default_temperature,
- max_output_tokens: model.max_output_tokens,
- extra_beta_headers: model.extra_beta_headers.clone(),
- mode: model.mode.unwrap_or_default().into(),
- }),
- AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
- name: model.name.clone(),
- display_name: model.display_name.clone(),
- max_tokens: model.max_tokens,
- max_output_tokens: model.max_output_tokens,
- max_completion_tokens: model.max_completion_tokens,
- }),
- AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
- name: model.name.clone(),
- display_name: model.display_name.clone(),
- max_tokens: model.max_tokens,
- }),
- };
- models.insert(model.id().to_string(), model.clone());
- }
-
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- models
- .into_values()
.map(|model| self.create_language_model(model, llm_api_token.clone()))
.collect()
}
@@ -522,7 +521,7 @@ fn render_accept_terms(
pub struct CloudLanguageModel {
id: LanguageModelId,
- model: CloudModel,
+ model: Arc<zed_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
client: Arc<Client>,
request_limiter: RateLimiter,
@@ -668,7 +667,7 @@ impl LanguageModel for CloudLanguageModel {
}
fn name(&self) -> LanguageModelName {
- LanguageModelName::from(self.model.display_name().to_string())
+ LanguageModelName::from(self.model.display_name.clone())
}
fn provider_id(&self) -> LanguageModelProviderId {
@@ -680,19 +679,11 @@ impl LanguageModel for CloudLanguageModel {
}
fn supports_tools(&self) -> bool {
- match self.model {
- CloudModel::Anthropic(_) => true,
- CloudModel::Google(_) => true,
- CloudModel::OpenAi(_) => true,
- }
+ self.model.supports_tools
}
fn supports_images(&self) -> bool {
- match self.model {
- CloudModel::Anthropic(_) => true,
- CloudModel::Google(_) => true,
- CloudModel::OpenAi(_) => false,
- }
+ self.model.supports_images
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -703,30 +694,41 @@ impl LanguageModel for CloudLanguageModel {
}
}
+ fn supports_max_mode(&self) -> bool {
+ self.model.supports_max_mode
+ }
+
fn telemetry_id(&self) -> String {
- format!("zed.dev/{}", self.model.id())
+ format!("zed.dev/{}", self.model.id)
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
- self.model.tool_input_format()
+ match self.model.provider {
+ zed_llm_client::LanguageModelProvider::Anthropic
+ | zed_llm_client::LanguageModelProvider::OpenAi => {
+ LanguageModelToolSchemaFormat::JsonSchema
+ }
+ zed_llm_client::LanguageModelProvider::Google => {
+ LanguageModelToolSchemaFormat::JsonSchemaSubset
+ }
+ }
}
fn max_token_count(&self) -> usize {
- self.model.max_token_count()
+ self.model.max_token_count
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
- match &self.model {
- CloudModel::Anthropic(model) => {
- model
- .cache_configuration()
- .map(|cache| LanguageModelCacheConfiguration {
- max_cache_anchors: cache.max_cache_anchors,
- should_speculate: cache.should_speculate,
- min_total_token: cache.min_total_token,
- })
+ match &self.model.provider {
+ zed_llm_client::LanguageModelProvider::Anthropic => {
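+                // The same prompt-caching parameters are currently hardcoded for every Anthropic model.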
+ Some(LanguageModelCacheConfiguration {
+ min_total_token: 2_048,
+ should_speculate: true,
+ max_cache_anchors: 4,
+ })
}
- CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
+ zed_llm_client::LanguageModelProvider::OpenAi
+ | zed_llm_client::LanguageModelProvider::Google => None,
}
}
@@ -735,13 +737,19 @@ impl LanguageModel for CloudLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
- match self.model.clone() {
- CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
- CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
- CloudModel::Google(model) => {
+ match self.model.provider {
+ zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
+ zed_llm_client::LanguageModelProvider::OpenAi => {
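+                // Tokens are counted client-side, so the server's model id must map to a
+                // known `open_ai::Model` to select the matching tokenizer.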
+ let model = match open_ai::Model::from_id(&self.model.id.0) {
+ Ok(model) => model,
+ Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+ };
+ count_open_ai_tokens(request, model, cx)
+ }
+ zed_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
- let model_id = model.id().to_string();
+ let model_id = self.model.id.to_string();
let generate_content_request = into_google(request, model_id.clone());
async move {
let http_client = &client.http_client();
@@ -803,14 +811,20 @@ impl LanguageModel for CloudLanguageModel {
let prompt_id = request.prompt_id.clone();
let mode = request.mode;
let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
- match &self.model {
- CloudModel::Anthropic(model) => {
+ match self.model.provider {
+ zed_llm_client::LanguageModelProvider::Anthropic => {
let request = into_anthropic(
request,
- model.request_id().into(),
- model.default_temperature(),
- model.max_output_tokens(),
- model.mode(),
+ self.model.id.to_string(),
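+                    // default temperature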
+ 1.0,
+ self.model.max_output_tokens as u32,
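+                    // Thinking variants are synthesized on the client with a "-thinking" id suffix.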
+ if self.model.id.0.ends_with("-thinking") {
+ AnthropicModelMode::Thinking {
+ budget_tokens: Some(4_096),
+ }
+ } else {
+ AnthropicModelMode::Default
+ },
);
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
@@ -862,9 +876,13 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
- CloudModel::OpenAi(model) => {
+ zed_llm_client::LanguageModelProvider::OpenAi => {
let client = self.client.clone();
- let request = into_open_ai(request, model, model.max_output_tokens());
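+                // The server only sends a model id, so resolve it to a concrete
+                // open_ai::Model before building the request.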
+ let model = match open_ai::Model::from_id(&self.model.id.0) {
+ Ok(model) => model,
+ Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+ };
+ let request = into_open_ai(request, &model, None);
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
@@ -899,9 +917,9 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
- CloudModel::Google(model) => {
+ zed_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
- let request = into_google(request, model.id().into());
+ let request = into_google(request, self.model.id.to_string());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {