Marshall Bowers created this pull request.

This PR adds xAI support to the Zed Cloud provider.

Release Notes:

- N/A
crates/cloud_llm_client/src/cloud_llm_client.rs | 1
crates/language_model/src/language_model.rs | 3
crates/language_models/src/provider/cloud.rs | 64 ++++++++++++++++++
3 files changed, 67 insertions(+), 1 deletion(-)
crates/cloud_llm_client/src/cloud_llm_client.rs

@@ -144,6 +144,7 @@ pub enum LanguageModelProvider {
     Anthropic,
     OpenAi,
     Google,
+    XAi,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
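Assuming the enum keeps a `#[serde(rename_all = "snake_case")]` attribute (not shown in this hunk), the new variant travels over the wire as `"x_ai"`, which lines up with the `X_AI_PROVIDER_ID` constant added below. A minimal, self-contained sketch (needs `serde` with the derive feature and `serde_json`):

```rust
use serde::{Deserialize, Serialize};

// Sketch only: mirrors the shape of the enum above under the assumed
// snake_case rename; the real definition lives in cloud_llm_client.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}

fn main() {
    // With snake_case renaming, `XAi` maps to the "x_ai" wire tag.
    let json = serde_json::to_string(&LanguageModelProvider::XAi).unwrap();
    assert_eq!(json, r#""x_ai""#);
}
```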
crates/language_model/src/language_model.rs

@@ -50,6 +50,9 @@ pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId
 pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("OpenAI");
 
+pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
+pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
+
 pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
 pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("Zed");
crates/language_models/src/provider/cloud.rs

@@ -46,6 +46,7 @@ use util::{ResultExt as _, maybe};
 use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
+use crate::provider::x_ai::count_xai_tokens;
 
 const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
 const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
@@ -579,6 +580,7 @@ impl LanguageModel for CloudLanguageModel {
             Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
             OpenAi => language_model::OPEN_AI_PROVIDER_ID,
             Google => language_model::GOOGLE_PROVIDER_ID,
+            XAi => language_model::X_AI_PROVIDER_ID,
         }
     }
@@ -588,6 +590,7 @@ impl LanguageModel for CloudLanguageModel {
             Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
             OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
             Google => language_model::GOOGLE_PROVIDER_NAME,
+            XAi => language_model::X_AI_PROVIDER_NAME,
         }
     }
@@ -618,7 +621,8 @@ impl LanguageModel for CloudLanguageModel {
     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
         match self.model.provider {
             cloud_llm_client::LanguageModelProvider::Anthropic
-            | cloud_llm_client::LanguageModelProvider::OpenAi => {
+            | cloud_llm_client::LanguageModelProvider::OpenAi
+            | cloud_llm_client::LanguageModelProvider::XAi => {
                 LanguageModelToolSchemaFormat::JsonSchema
             }
             cloud_llm_client::LanguageModelProvider::Google => {
@@ -648,6 +652,7 @@ impl LanguageModel for CloudLanguageModel {
                 })
             }
             cloud_llm_client::LanguageModelProvider::OpenAi
+            | cloud_llm_client::LanguageModelProvider::XAi
             | cloud_llm_client::LanguageModelProvider::Google => None,
         }
     }
@@ -668,6 +673,13 @@ impl LanguageModel for CloudLanguageModel {
                 };
                 count_open_ai_tokens(request, model, cx)
             }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let model = match x_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+                };
+                count_xai_tokens(request, model, cx)
+            }
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
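The XAi arm mirrors the OpenAi one: resolve the model id up front and bail out synchronously with an already-failed boxed future when the id is unknown. A self-contained sketch of that pattern (the model ids and the byte-per-token estimate are illustrative assumptions, not Zed's API; needs the `futures` crate):

```rust
use futures::{FutureExt, future::BoxFuture};

#[derive(Debug, Clone, Copy)]
enum Model {
    Grok3,
    Grok3Mini,
}

// Fallible id lookup, analogous in shape to x_ai::Model::from_id;
// the id strings here are assumptions for the sketch.
fn parse_model(id: &str) -> Result<Model, String> {
    match id {
        "grok-3" => Ok(Model::Grok3),
        "grok-3-mini" => Ok(Model::Grok3Mini),
        other => Err(format!("unknown xAI model id: {other}")),
    }
}

fn count_tokens(model_id: &str, text: String) -> BoxFuture<'static, Result<usize, String>> {
    // Early return: hand back an already-failed future without doing any work.
    let _model = match parse_model(model_id) {
        Ok(model) => model,
        Err(err) => return async move { Err(err) }.boxed(),
    };
    // Crude stand-in estimate (~4 bytes per token), not a real tokenizer.
    async move { Ok(text.len() / 4) }.boxed()
}
```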
@@ -845,6 +857,56 @@ impl LanguageModel for CloudLanguageModel {
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let client = self.client.clone();
+                let model = match x_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
+                };
+                let request = into_open_ai(
+                    request,
+                    model.id(),
+                    model.supports_parallel_tool_calls(),
+                    model.supports_prompt_cache_key(),
+                    None,
+                    None,
+                );
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_status_messages,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        app_version,
+                        CompletionBody {
+                            thread_id,
+                            prompt_id,
+                            intent,
+                            mode,
+                            provider: cloud_llm_client::LanguageModelProvider::XAi,
+                            model: request.model.clone(),
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
+                        },
+                    )
+                    .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let request =
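Worth noting: the XAi completion arm builds its request with `into_open_ai` and decodes the stream with `OpenAiEventMapper`, so xAI is treated as speaking the OpenAI chat-completions dialect, and only the provider tag and model id differ on the wire. A toy illustration of that routing decision (stand-in types and tags, not Zed's):

```rust
// Stand-in types: routing several providers through one wire dialect,
// as the XAi arm above shares the OpenAi request/event path.
#[derive(Debug, PartialEq)]
enum WireDialect {
    OpenAiCompatible,
    Anthropic,
    Google,
}

fn dialect(provider: &str) -> WireDialect {
    match provider {
        // xAI reuses the OpenAI request builder and event mapper.
        "open_ai" | "x_ai" => WireDialect::OpenAiCompatible,
        "anthropic" => WireDialect::Anthropic,
        _ => WireDialect::Google,
    }
}

fn main() {
    assert_eq!(dialect("x_ai"), WireDialect::OpenAiCompatible);
}
```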