From 4f91fab1900efd8a00582cb8b8a6c629bed015c6 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 25 Sep 2025 23:19:12 -0400 Subject: [PATCH] language_models: Add xAI support to Zed Cloud provider (#38928) This PR adds xAI support to the Zed Cloud provider. Release Notes: - N/A --- .../cloud_llm_client/src/cloud_llm_client.rs | 1 + crates/language_model/src/language_model.rs | 3 + crates/language_models/src/provider/cloud.rs | 64 ++++++++++++++++++- 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index e0cc42af76156466c31ead17d6421f3634d3ad7c..aeace10e6d1e10233484a6bda1aa64db7aa7b79a 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -144,6 +144,7 @@ pub enum LanguageModelProvider { Anthropic, OpenAi, Google, + XAi, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 418a864648e10bf36073f00392a189f45d474534..38f2b0959072599900cb8a13c16f4e2f8e9c55db 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -50,6 +50,9 @@ pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenAI"); +pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); + pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Zed"); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index cf73da02c68b168097c3a4690849a42d5b993bc7..c6567fe01551397807c08e0e907fcab8c1b4c211 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -46,6 +46,7 @@ use util::{ResultExt as _, maybe}; use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic}; use crate::provider::google::{GoogleEventMapper, into_google}; use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai}; +use crate::provider::x_ai::count_xai_tokens; const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME; @@ -579,6 +580,7 @@ impl LanguageModel for CloudLanguageModel { Anthropic => language_model::ANTHROPIC_PROVIDER_ID, OpenAi => language_model::OPEN_AI_PROVIDER_ID, Google => language_model::GOOGLE_PROVIDER_ID, + XAi => language_model::X_AI_PROVIDER_ID, } } @@ -588,6 +590,7 @@ impl LanguageModel for CloudLanguageModel { Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, OpenAi => language_model::OPEN_AI_PROVIDER_NAME, Google => language_model::GOOGLE_PROVIDER_NAME, + XAi => language_model::X_AI_PROVIDER_NAME, } } @@ -618,7 +621,8 @@ impl LanguageModel for CloudLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { match self.model.provider { cloud_llm_client::LanguageModelProvider::Anthropic - | cloud_llm_client::LanguageModelProvider::OpenAi => { + | cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::XAi => { LanguageModelToolSchemaFormat::JsonSchema } cloud_llm_client::LanguageModelProvider::Google => { @@ -648,6 +652,7 @@ impl LanguageModel for CloudLanguageModel { }) } cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::XAi | cloud_llm_client::LanguageModelProvider::Google => None, } } @@ -668,6 +673,13 @@ impl LanguageModel for CloudLanguageModel { }; count_open_ai_tokens(request, model, cx) } + cloud_llm_client::LanguageModelProvider::XAi => { + let model = match x_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + count_xai_tokens(request, model, cx) + } cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); @@ -845,6 +857,56 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } + cloud_llm_client::LanguageModelProvider::XAi => { + let client = self.client.clone(); + let model = match x_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(), + }; + let request = into_open_ai( + request, + model.id(), + model.supports_parallel_tool_calls(), + model.supports_prompt_cache_key(), + None, + None, + ); + let llm_api_token = self.llm_api_token.clone(); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + usage, + includes_status_messages, + tool_use_limit_reached, + } = Self::perform_llm_completion( + client.clone(), + llm_api_token, + app_version, + CompletionBody { + thread_id, + prompt_id, + intent, + mode, + provider: cloud_llm_client::LanguageModelProvider::XAi, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin( + response_lines(response, includes_status_messages) + .chain(usage_updated_event(usage)) + .chain(tool_use_limit_reached_event(tool_use_limit_reached)), + ), + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let request =