Add `cloud_llm_client` crate (#35307)

Marshall Bowers created

This PR adds a `cloud_llm_client` crate to take the place of the
`zed_llm_client`.

Release Notes:

- N/A

Change summary

Cargo.lock                                      |  13 
Cargo.toml                                      |   3 
crates/cloud_llm_client/Cargo.toml              |  23 +
crates/cloud_llm_client/LICENSE-APACHE          |   1 
crates/cloud_llm_client/src/cloud_llm_client.rs | 370 +++++++++++++++++++
5 files changed, 409 insertions(+), 1 deletion(-)

Detailed changes

Cargo.lock 🔗

@@ -3031,6 +3031,19 @@ dependencies = [
  "workspace-hack",
 ]
 
+[[package]]
+name = "cloud_llm_client"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "pretty_assertions",
+ "serde",
+ "serde_json",
+ "strum 0.27.1",
+ "uuid",
+ "workspace-hack",
+]
+
 [[package]]
 name = "clru"
 version = "0.6.2"

Cargo.toml 🔗

@@ -29,6 +29,7 @@ members = [
     "crates/cli",
     "crates/client",
     "crates/clock",
+    "crates/cloud_llm_client",
     "crates/collab",
     "crates/collab_ui",
     "crates/collections",
@@ -70,7 +71,6 @@ members = [
     "crates/gpui",
     "crates/gpui_macros",
     "crates/gpui_tokio",
-
     "crates/html_to_markdown",
     "crates/http_client",
     "crates/http_client_tls",
@@ -251,6 +251,7 @@ channel = { path = "crates/channel" }
 cli = { path = "crates/cli" }
 client = { path = "crates/client" }
 clock = { path = "crates/clock" }
+cloud_llm_client = { path = "crates/cloud_llm_client" }
 collab = { path = "crates/collab" }
 collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections" }

crates/cloud_llm_client/Cargo.toml 🔗

@@ -0,0 +1,23 @@
+[package]
+name = "cloud_llm_client"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "Apache-2.0"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/cloud_llm_client.rs"
+
+[dependencies]
+anyhow.workspace = true
+serde = { workspace = true, features = ["derive", "rc"] }
+serde_json.workspace = true
+strum = { workspace = true, features = ["derive"] }
+uuid = { workspace = true, features = ["serde"] }
+workspace-hack.workspace = true
+
+[dev-dependencies]
+pretty_assertions.workspace = true

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -0,0 +1,370 @@
+use std::str::FromStr;
+use std::sync::Arc;
+
+use anyhow::Context as _;
+use serde::{Deserialize, Serialize};
+use strum::{Display, EnumIter, EnumString};
+use uuid::Uuid;
+
+/// The name of the header used to indicate which version of Zed the client is running.
+pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
+
+/// The name of the header used to indicate when a request failed due to an
+/// expired LLM token.
+///
+/// The client may use this as a signal to refresh the token.
+pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
+
+/// The name of the header used to indicate what plan the user is currently on.
+pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
+
+/// The name of the header used to indicate the usage limit for model requests.
+pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
+
+/// The name of the header used to indicate the usage amount for model requests.
+pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
+
+/// The name of the header used to indicate the usage limit for edit predictions.
+pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
+
+/// The name of the header used to indicate the usage amount for edit predictions.
+pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
+
+/// The name of the header used to indicate the resource for which the subscription limit has been reached.
+pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
+
+pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
+pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
+
+/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
+pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
+
+/// The name of the header used to indicate the the minimum required Zed version.
+///
+/// This can be used to force a Zed upgrade in order to continue communicating
+/// with the LLM service.
+pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
+
+/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
+pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
+    "x-zed-client-supports-status-messages";
+
+/// The name of the header used by the server to indicate to the client that it supports sending status messages.
+pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
+    "x-zed-server-supports-status-messages";
+
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum UsageLimit {
+    Limited(i32),
+    Unlimited,
+}
+
+impl FromStr for UsageLimit {
+    type Err = anyhow::Error;
+
+    fn from_str(value: &str) -> Result<Self, Self::Err> {
+        match value {
+            "unlimited" => Ok(Self::Unlimited),
+            limit => limit
+                .parse::<i32>()
+                .map(Self::Limited)
+                .context("failed to parse limit"),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum Plan {
+    #[default]
+    #[serde(alias = "Free")]
+    ZedFree,
+    #[serde(alias = "ZedPro")]
+    ZedPro,
+    #[serde(alias = "ZedProTrial")]
+    ZedProTrial,
+}
+
+impl Plan {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Plan::ZedFree => "zed_free",
+            Plan::ZedPro => "zed_pro",
+            Plan::ZedProTrial => "zed_pro_trial",
+        }
+    }
+
+    pub fn model_requests_limit(&self) -> UsageLimit {
+        match self {
+            Plan::ZedPro => UsageLimit::Limited(500),
+            Plan::ZedProTrial => UsageLimit::Limited(150),
+            Plan::ZedFree => UsageLimit::Limited(50),
+        }
+    }
+
+    pub fn edit_predictions_limit(&self) -> UsageLimit {
+        match self {
+            Plan::ZedPro => UsageLimit::Unlimited,
+            Plan::ZedProTrial => UsageLimit::Unlimited,
+            Plan::ZedFree => UsageLimit::Limited(2_000),
+        }
+    }
+}
+
+impl FromStr for Plan {
+    type Err = anyhow::Error;
+
+    fn from_str(value: &str) -> Result<Self, Self::Err> {
+        match value {
+            "zed_free" => Ok(Plan::ZedFree),
+            "zed_pro" => Ok(Plan::ZedPro),
+            "zed_pro_trial" => Ok(Plan::ZedProTrial),
+            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
+        }
+    }
+}
+
+#[derive(
+    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
+)]
+#[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
+pub enum LanguageModelProvider {
+    Anthropic,
+    OpenAi,
+    Google,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsBody {
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub outline: Option<String>,
+    pub input_events: String,
+    pub input_excerpt: String,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub speculated_output: Option<String>,
+    /// Whether the user provided consent for sampling this interaction.
+    #[serde(default, alias = "data_collection_permission")]
+    pub can_collect_data: bool,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsResponse {
+    pub request_id: Uuid,
+    pub output_excerpt: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AcceptEditPredictionBody {
+    pub request_id: Uuid,
+}
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CompletionMode {
+    Normal,
+    Max,
+}
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CompletionIntent {
+    UserPrompt,
+    ToolResults,
+    ThreadSummarization,
+    ThreadContextSummarization,
+    CreateFile,
+    EditFile,
+    InlineAssist,
+    TerminalInlineAssist,
+    GenerateGitCommitMessage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CompletionBody {
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub thread_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub prompt_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub intent: Option<CompletionIntent>,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub mode: Option<CompletionMode>,
+    pub provider: LanguageModelProvider,
+    pub model: String,
+    pub provider_request: serde_json::Value,
+}
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CompletionRequestStatus {
+    Queued {
+        position: usize,
+    },
+    Started,
+    Failed {
+        code: String,
+        message: String,
+        request_id: Uuid,
+        /// Retry duration in seconds.
+        retry_after: Option<f64>,
+    },
+    UsageUpdated {
+        amount: usize,
+        limit: UsageLimit,
+    },
+    ToolUseLimitReached,
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CompletionEvent<T> {
+    Status(CompletionRequestStatus),
+    Event(T),
+}
+
+impl<T> CompletionEvent<T> {
+    pub fn into_status(self) -> Option<CompletionRequestStatus> {
+        match self {
+            Self::Status(status) => Some(status),
+            Self::Event(_) => None,
+        }
+    }
+
+    pub fn into_event(self) -> Option<T> {
+        match self {
+            Self::Event(event) => Some(event),
+            Self::Status(_) => None,
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct WebSearchBody {
+    pub query: String,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct WebSearchResponse {
+    pub results: Vec<WebSearchResult>,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct WebSearchResult {
+    pub title: String,
+    pub url: String,
+    pub text: String,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct CountTokensBody {
+    pub provider: LanguageModelProvider,
+    pub model: String,
+    pub provider_request: serde_json::Value,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct CountTokensResponse {
+    pub tokens: usize,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelId(pub Arc<str>);
+
+impl std::fmt::Display for LanguageModelId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct LanguageModel {
+    pub provider: LanguageModelProvider,
+    pub id: LanguageModelId,
+    pub display_name: String,
+    pub max_token_count: usize,
+    pub max_token_count_in_max_mode: Option<usize>,
+    pub max_output_tokens: usize,
+    pub supports_tools: bool,
+    pub supports_images: bool,
+    pub supports_thinking: bool,
+    pub supports_max_mode: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ListModelsResponse {
+    pub models: Vec<LanguageModel>,
+    pub default_model: LanguageModelId,
+    pub default_fast_model: LanguageModelId,
+    pub recommended_models: Vec<LanguageModelId>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct GetSubscriptionResponse {
+    pub plan: Plan,
+    pub usage: Option<CurrentUsage>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CurrentUsage {
+    pub model_requests: UsageData,
+    pub edit_predictions: UsageData,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct UsageData {
+    pub used: u32,
+    pub limit: UsageLimit,
+}
+
+#[cfg(test)]
+mod tests {
+    use pretty_assertions::assert_eq;
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_plan_deserialize_snake_case() {
+        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
+        assert_eq!(plan, Plan::ZedFree);
+
+        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
+        assert_eq!(plan, Plan::ZedPro);
+
+        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
+        assert_eq!(plan, Plan::ZedProTrial);
+    }
+
+    #[test]
+    fn test_plan_deserialize_aliases() {
+        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
+        assert_eq!(plan, Plan::ZedFree);
+
+        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
+        assert_eq!(plan, Plan::ZedPro);
+
+        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
+        assert_eq!(plan, Plan::ZedProTrial);
+    }
+
+    #[test]
+    fn test_usage_limit_from_str() {
+        let limit = UsageLimit::from_str("unlimited").unwrap();
+        assert!(matches!(limit, UsageLimit::Unlimited));
+
+        let limit = UsageLimit::from_str(&0.to_string()).unwrap();
+        assert!(matches!(limit, UsageLimit::Limited(0)));
+
+        let limit = UsageLimit::from_str(&50.to_string()).unwrap();
+        assert!(matches!(limit, UsageLimit::Limited(50)));
+
+        for value in ["not_a_number", "50xyz"] {
+            let limit = UsageLimit::from_str(value);
+            assert!(limit.is_err());
+        }
+    }
+}