From fcb1efdf21ecda17a2029542d8959a11ef67aa23 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 16 Apr 2025 16:22:44 -0400 Subject: [PATCH] rpc: Remove `llm` module in favor of `zed_llm_client` (#28900) This PR removes the `llm` module of the `rpc` crate in favor of using the types from the `zed_llm_client`. Release Notes: - N/A --- Cargo.lock | 43 ++++++++++--------- Cargo.toml | 2 +- crates/collab/Cargo.toml | 1 + crates/collab/src/api/billing.rs | 6 ++- crates/collab/src/llm/db.rs | 2 +- .../collab/src/llm/db/tests/provider_tests.rs | 2 +- crates/language_models/src/provider/cloud.rs | 37 +++++++--------- crates/rpc/src/llm.rs | 35 --------------- crates/rpc/src/rpc.rs | 2 - 9 files changed, 45 insertions(+), 85 deletions(-) delete mode 100644 crates/rpc/src/llm.rs diff --git a/Cargo.lock b/Cargo.lock index c49a71c5eb4dd6ebf3c6b27d66f3bbd15dd70fec..52398176d7b2d4c241784646fe184171c22a9e63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -324,7 +324,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "thiserror 2.0.12", "workspace-hack", ] @@ -567,7 +567,7 @@ dependencies = [ "settings", "smallvec", "smol", - "strum 0.26.3", + "strum 0.27.1", "telemetry_events", "text", "theme", @@ -1884,7 +1884,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "thiserror 2.0.12", "tokio", "workspace-hack", @@ -3031,7 +3031,7 @@ dependencies = [ "settings", "sha2", "sqlx", - "strum 0.26.3", + "strum 0.27.1", "subtle", "supermaven_api", "telemetry_events", @@ -3051,6 +3051,7 @@ dependencies = [ "workspace", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -3363,7 +3364,7 @@ dependencies = [ "serde", "serde_json", "settings", - "strum 0.26.3", + "strum 0.27.1", "task", "theme", "ui", @@ -5125,7 +5126,7 @@ dependencies = [ "serde", "settings", "smallvec", - "strum 0.26.3", + "strum 0.27.1", "telemetry", "theme", "ui", @@ -5976,7 +5977,7 @@ dependencies = [ "serde_derive", "serde_json", "settings", - "strum 0.26.3", + "strum 0.27.1", "telemetry", "theme", "time", @@ -6069,7 +6070,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -6175,7 +6176,7 @@ dependencies = [ "slotmap", "smallvec", "smol", - "strum 0.26.3", + "strum 0.27.1", "sum_tree", "taffy", "thiserror 2.0.12", @@ -6823,7 +6824,7 @@ name = "icons" version = "0.1.0" dependencies = [ "serde", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -7091,7 +7092,7 @@ dependencies = [ "paths", "pretty_assertions", "serde", - "strum 0.26.3", + "strum 0.27.1", "util", "workspace-hack", ] @@ -7677,7 +7678,7 @@ dependencies = [ "serde", "serde_json", "smol", - "strum 0.26.3", + "strum 0.27.1", "telemetry_events", "thiserror 2.0.12", "util", @@ -7737,7 +7738,7 @@ dependencies = [ "serde_json", "settings", "smol", - "strum 0.26.3", + "strum 0.27.1", "theme", "thiserror 2.0.12", "tiktoken-rs", @@ -8710,7 +8711,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -9557,7 +9558,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -12136,7 +12137,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "strum 0.26.3", + "strum 0.27.1", "tracing", "util", "workspace-hack", @@ -13709,7 +13710,7 @@ dependencies = [ "settings", "simplelog", "story", - "strum 0.26.3", + "strum 0.27.1", "theme", "title_bar", "ui", @@ -14444,7 +14445,7 @@ dependencies = [ "serde_json_lenient", "serde_repr", "settings", - "strum 0.26.3", + "strum 0.27.1", "thiserror 2.0.12", "util", "uuid", @@ -14478,7 +14479,7 @@ dependencies = [ "serde_json", "serde_json_lenient", "simplelog", - "strum 0.26.3", + "strum 0.27.1", "theme", "vscode_theme", "workspace-hack", @@ -15479,7 +15480,7 @@ dependencies = [ "settings", "smallvec", "story", - "strum 0.26.3", + "strum 0.27.1", "theme", "ui_macros", "util", @@ -17680,7 +17681,7 @@ dependencies = [ "settings", "smallvec", "sqlez", - "strum 0.26.3", + "strum 0.27.1", "task", "telemetry", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index b0028553379853faa910589034993f2e8fe7dfff..099ea3ddad35f740e7f1437ccf3d81747ce56192 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -540,7 +540,7 @@ smol = "2.0" sqlformat = "0.2" streaming-iterator = "0.1" strsim = "0.11" -strum = { version = "0.26.0", features = ["derive"] } +strum = { version = "0.27.0", features = ["derive"] } subtle = "2.5.0" syn = { version = "1.0.72", features = ["full", "extra-traits"] } sys-locale = "0.3.1" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index c4aa90e2c2041ee402286aa662bd561cf7d66afe..4d6787f8577ff7a5599794095a710784b69d6e32 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -75,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re util.workspace = true uuid.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] assistant = { workspace = true, features = ["test-support"] } diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 36843ced568a5f1fb0139e6f4ef669136f28142d..bc2b508e844023ab2548016f9f24482cfe5cf757 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -330,8 +330,10 @@ async fn create_billing_subscription( .await? } None => { - let default_model = - llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; + let default_model = llm_db.model( + zed_llm_client::LanguageModelProvider::Anthropic, + "claude-3-7-sonnet", + )?; let stripe_model = stripe_billing.register_model(default_model).await?; stripe_billing .checkout(customer_id, &user.github_login, &stripe_model, &success_url) diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index f56e9e61e3fcd4c4b06f4a2ba1342c1ae86517d0..e445450ff42336edfb644b949f71ff539df8a3ea 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -8,9 +8,9 @@ mod tests; use collections::HashMap; pub use ids::*; -use rpc::LanguageModelProvider; pub use seed::*; pub use tables::*; +use zed_llm_client::LanguageModelProvider; #[cfg(test)] pub use tests::TestLlmDb; diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index 0bb55ee4b69a6c571582e9b9b0bc0f4ea5161993..7d52964b939e7b17ca8ec9f986756c00bd0dad55 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,5 +1,5 @@ use pretty_assertions::assert_eq; -use rpc::LanguageModelProvider; +use zed_llm_client::LanguageModelProvider; use crate::llm::db::LlmDatabase; use crate::test_llm_db; diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 32dc5f3f99384f4314d1d95206413b5643051eec..ee4c8540d8d7d0c460bbc2d77f74f86fe25dd4f5 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,9 +1,6 @@ use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Result, anyhow}; -use client::{ - Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, - PerformCompletionParams, UserStore, zed_urls, -}; +use client::{Client, UserStore, zed_urls}; use collections::BTreeMap; use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro}; use futures::{ @@ -26,7 +23,6 @@ use language_model::{ use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use serde_json::value::RawValue; use settings::{Settings, SettingsStore}; use smol::Timer; use smol::io::{AsyncReadExt, BufReader}; @@ -38,7 +34,10 @@ use std::{ use strum::IntoEnumIterator; use thiserror::Error; use ui::{TintColor, prelude::*}; -use zed_llm_client::{CURRENT_PLAN_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME}; +use zed_llm_client::{ + CURRENT_PLAN_HEADER_NAME, CompletionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, + MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, +}; use crate::AllLanguageModelSettings; use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic}; @@ -517,7 +516,7 @@ impl CloudLanguageModel { async fn perform_llm_completion( client: Arc, llm_api_token: LlmApiToken, - body: PerformCompletionParams, + body: CompletionBody, ) -> Result> { let http_client = &client.http_client(); @@ -724,12 +723,10 @@ impl LanguageModel for CloudLanguageModel { let response = Self::perform_llm_completion( client.clone(), llm_api_token, - PerformCompletionParams { - provider: client::LanguageModelProvider::Anthropic, + CompletionBody { + provider: zed_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), - provider_request: RawValue::from_string(serde_json::to_string( - &request, - )?)?, + provider_request: serde_json::to_value(&request)?, }, ) .await @@ -765,12 +762,10 @@ impl LanguageModel for CloudLanguageModel { let response = Self::perform_llm_completion( client.clone(), llm_api_token, - PerformCompletionParams { - provider: client::LanguageModelProvider::OpenAi, + CompletionBody { + provider: zed_llm_client::LanguageModelProvider::OpenAi, model: request.model.clone(), - provider_request: RawValue::from_string(serde_json::to_string( - &request, - )?)?, + provider_request: serde_json::to_value(&request)?, }, ) .await?; @@ -790,12 +785,10 @@ impl LanguageModel for CloudLanguageModel { let response = Self::perform_llm_completion( client.clone(), llm_api_token, - PerformCompletionParams { - provider: client::LanguageModelProvider::Google, + CompletionBody { + provider: zed_llm_client::LanguageModelProvider::Google, model: request.model.clone(), - provider_request: RawValue::from_string(serde_json::to_string( - &request, - )?)?, + provider_request: serde_json::to_value(&request)?, }, ) .await?; diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs deleted file mode 100644 index 0a7510d891d3522c8794fb106fe168df10fc5aab..0000000000000000000000000000000000000000 --- a/crates/rpc/src/llm.rs +++ /dev/null @@ -1,35 +0,0 @@ -use serde::{Deserialize, Serialize}; -use strum::{Display, EnumIter, EnumString}; - -pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; - -pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached"; - -#[derive( - Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, -)] -#[serde(rename_all = "snake_case")] -#[strum(serialize_all = "snake_case")] -pub enum LanguageModelProvider { - Anthropic, - OpenAi, - Google, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct LanguageModel { - pub provider: LanguageModelProvider, - pub name: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ListModelsResponse { - pub models: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct PerformCompletionParams { - pub provider: LanguageModelProvider, - pub model: String, - pub provider_request: Box, -} diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index c10ee9d6c85904be50b3dd72e5483d1dd192f36d..ad1ebb757c22783658d0ece51156b3864234aba1 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -1,14 +1,12 @@ pub mod auth; mod conn; mod extension; -mod llm; mod message_stream; mod notification; mod peer; pub use conn::Connection; pub use extension::*; -pub use llm::*; pub use notification::*; pub use peer::*; pub use proto;