language_models: Add xAI support to Zed Cloud provider (#38928)

Created by Marshall Bowers

This PR adds xAI support to the Zed Cloud provider. It registers xAI provider ID and name constants in the language_model crate and routes xAI models through the cloud provider's existing OpenAI-compatible token-counting and streaming-completion paths.

Release Notes:

- N/A

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |  1 +
crates/language_model/src/language_model.rs     |  3 +++
crates/language_models/src/provider/cloud.rs    | 64 ++++++++++++++++++-
3 files changed, 67 insertions(+), 1 deletion(-)

Detailed changes

crates/language_model/src/language_model.rs

@@ -50,6 +50,9 @@ pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId
 pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("OpenAI");
 
+pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
+pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
+
 pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
 pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("Zed");
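For context, each Zed Cloud provider pairs a stable wire identifier with a human-readable display name; the new constants register xAI under the id "x_ai" and the label "xAI". A minimal usage sketch (the helper below is hypothetical, assuming LanguageModelProviderId compares by value as the existing constants suggest):

    use language_model::{LanguageModelProviderId, X_AI_PROVIDER_ID};

    // Hypothetical helper, not from this PR: route on the new constant
    // instead of comparing raw provider strings.
    fn is_xai(provider_id: &LanguageModelProviderId) -> bool {
        *provider_id == X_AI_PROVIDER_ID
    }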

crates/language_models/src/provider/cloud.rs

@@ -46,6 +46,7 @@ use util::{ResultExt as _, maybe};
 use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
+use crate::provider::x_ai::count_xai_tokens;
 
 const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
 const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
@@ -579,6 +580,7 @@ impl LanguageModel for CloudLanguageModel {
             Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
             OpenAi => language_model::OPEN_AI_PROVIDER_ID,
             Google => language_model::GOOGLE_PROVIDER_ID,
+            XAi => language_model::X_AI_PROVIDER_ID,
         }
     }
 
@@ -588,6 +590,7 @@ impl LanguageModel for CloudLanguageModel {
             Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
             OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
             Google => language_model::GOOGLE_PROVIDER_NAME,
+            XAi => language_model::X_AI_PROVIDER_NAME,
         }
     }
 
@@ -618,7 +621,8 @@ impl LanguageModel for CloudLanguageModel {
     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
         match self.model.provider {
             cloud_llm_client::LanguageModelProvider::Anthropic
-            | cloud_llm_client::LanguageModelProvider::OpenAi => {
+            | cloud_llm_client::LanguageModelProvider::OpenAi
+            | cloud_llm_client::LanguageModelProvider::XAi => {
                 LanguageModelToolSchemaFormat::JsonSchema
             }
             cloud_llm_client::LanguageModelProvider::Google => {
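xAI lands in the same arm as Anthropic and OpenAI because its API accepts full JSON Schema for tool definitions; the Google arm (continuing past the context shown) takes the other variant. A rough sketch of the distinction this method encodes, with the variant names assumed from the match structure:

    // Hypothetical simplification; the real enum lives in the
    // language_model crate.
    pub enum LanguageModelToolSchemaFormat {
        // Full JSON Schema, accepted by OpenAI-compatible APIs (including xAI).
        JsonSchema,
        // A restricted subset for APIs that reject some JSON Schema keywords.
        JsonSchemaSubset,
    }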
@@ -648,6 +652,7 @@ impl LanguageModel for CloudLanguageModel {
                 })
             }
             cloud_llm_client::LanguageModelProvider::OpenAi
+            | cloud_llm_client::LanguageModelProvider::XAi
             | cloud_llm_client::LanguageModelProvider::Google => None,
         }
     }
@@ -668,6 +673,13 @@ impl LanguageModel for CloudLanguageModel {
                 };
                 count_open_ai_tokens(request, model, cx)
             }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let model = match x_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+                };
+                count_xai_tokens(request, model, cx)
+            }
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
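The body of count_xai_tokens (imported near the top of this diff) is not shown here. Since xAI's chat API is OpenAI-compatible, a plausible shape for such a counter is to reuse an OpenAI tokenizer; the sketch below is an illustration built on the tiktoken-rs crate, not the code from this PR:

    use tiktoken_rs::{ChatCompletionRequestMessage, num_tokens_from_messages};

    // Hypothetical sketch: estimate tokens for (role, content) pairs with an
    // OpenAI tokenizer table, a reasonable proxy for Grok models given xAI's
    // OpenAI-compatible API.
    fn estimate_xai_tokens(messages: &[(String, String)]) -> anyhow::Result<u64> {
        let messages: Vec<ChatCompletionRequestMessage> = messages
            .iter()
            .map(|(role, content)| ChatCompletionRequestMessage {
                role: role.clone(),
                content: Some(content.clone()),
                name: None,
                function_call: None,
            })
            .collect();
        let count = num_tokens_from_messages("gpt-4", &messages)?;
        Ok(count as u64)
    }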
@@ -845,6 +857,56 @@ impl LanguageModel for CloudLanguageModel {
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let client = self.client.clone();
+                let model = match x_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
+                };
+                let request = into_open_ai(
+                    request,
+                    model.id(),
+                    model.supports_parallel_tool_calls(),
+                    model.supports_prompt_cache_key(),
+                    None,
+                    None,
+                );
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_status_messages,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        app_version,
+                        CompletionBody {
+                            thread_id,
+                            prompt_id,
+                            intent,
+                            mode,
+                            provider: cloud_llm_client::LanguageModelProvider::XAi,
+                            model: request.model.clone(),
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
+                        },
+                    )
+                    .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let request =
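Worth noting about the XAi streaming arm above: rather than adding an xAI-specific request builder and event mapper, it reuses into_open_ai and OpenAiEventMapper, since xAI exposes an OpenAI-compatible chat completions API; only the provider tag in CompletionBody and the model ID differ. An annotated restatement of the conversion call, where the comments are inferred parameter meanings rather than documentation from this PR:

    // Sketch of the call made in the XAi arm; check the comments against
    // into_open_ai's actual signature in provider/open_ai.rs.
    let request = into_open_ai(
        request,
        model.id(),                           // the xAI model identifier
        model.supports_parallel_tool_calls(), // gate OpenAI-style parallel tool calls
        model.supports_prompt_cache_key(),    // gate the prompt_cache_key field
        None,                                 // presumably max output tokens (server default)
        None,                                 // presumably reasoning effort (unset for xAI)
    );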