cloud_llm_client.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3
  4use anyhow::Context as _;
  5use serde::{Deserialize, Serialize};
  6use strum::{Display, EnumIter, EnumString};
  7use uuid::Uuid;
  8
  9/// The name of the header used to indicate which version of Zed the client is running.
 10pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 11
 12/// The name of the header used to indicate when a request failed due to an
 13/// expired LLM token.
 14///
 15/// The client may use this as a signal to refresh the token.
 16pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 17
 18/// The name of the header used to indicate what plan the user is currently on.
 19pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 20
 21/// The name of the header used to indicate the usage limit for model requests.
 22pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 23
 24/// The name of the header used to indicate the usage amount for model requests.
 25pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 26
 27/// The name of the header used to indicate the usage limit for edit predictions.
 28pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 29
 30/// The name of the header used to indicate the usage amount for edit predictions.
 31pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 32
 33/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 34pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 35
 36pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 37pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 38
 39/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 40pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 41
 42/// The name of the header used to indicate the the minimum required Zed version.
 43///
 44/// This can be used to force a Zed upgrade in order to continue communicating
 45/// with the LLM service.
 46pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 47
 48/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 49pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 50    "x-zed-client-supports-status-messages";
 51
 52/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 53pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 54    "x-zed-server-supports-status-messages";
 55
 56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 57#[serde(rename_all = "snake_case")]
 58pub enum UsageLimit {
 59    Limited(i32),
 60    Unlimited,
 61}
 62
 63impl FromStr for UsageLimit {
 64    type Err = anyhow::Error;
 65
 66    fn from_str(value: &str) -> Result<Self, Self::Err> {
 67        match value {
 68            "unlimited" => Ok(Self::Unlimited),
 69            limit => limit
 70                .parse::<i32>()
 71                .map(Self::Limited)
 72                .context("failed to parse limit"),
 73        }
 74    }
 75}
 76
 77#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 78#[serde(rename_all = "snake_case")]
 79pub enum Plan {
 80    #[default]
 81    #[serde(alias = "Free")]
 82    ZedFree,
 83    #[serde(alias = "ZedPro")]
 84    ZedPro,
 85    #[serde(alias = "ZedProTrial")]
 86    ZedProTrial,
 87}
 88
 89impl Plan {
 90    pub fn as_str(&self) -> &'static str {
 91        match self {
 92            Plan::ZedFree => "zed_free",
 93            Plan::ZedPro => "zed_pro",
 94            Plan::ZedProTrial => "zed_pro_trial",
 95        }
 96    }
 97
 98    pub fn model_requests_limit(&self) -> UsageLimit {
 99        match self {
100            Plan::ZedPro => UsageLimit::Limited(500),
101            Plan::ZedProTrial => UsageLimit::Limited(150),
102            Plan::ZedFree => UsageLimit::Limited(50),
103        }
104    }
105
106    pub fn edit_predictions_limit(&self) -> UsageLimit {
107        match self {
108            Plan::ZedPro => UsageLimit::Unlimited,
109            Plan::ZedProTrial => UsageLimit::Unlimited,
110            Plan::ZedFree => UsageLimit::Limited(2_000),
111        }
112    }
113}
114
115impl FromStr for Plan {
116    type Err = anyhow::Error;
117
118    fn from_str(value: &str) -> Result<Self, Self::Err> {
119        match value {
120            "zed_free" => Ok(Plan::ZedFree),
121            "zed_pro" => Ok(Plan::ZedPro),
122            "zed_pro_trial" => Ok(Plan::ZedProTrial),
123            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
124        }
125    }
126}
127
128#[derive(
129    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
130)]
131#[serde(rename_all = "snake_case")]
132#[strum(serialize_all = "snake_case")]
133pub enum LanguageModelProvider {
134    Anthropic,
135    OpenAi,
136    Google,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct PredictEditsBody {
141    pub input_events: String,
142    pub input_excerpt: String,
143    /// Whether the user provided consent for sampling this interaction.
144    #[serde(default, alias = "data_collection_permission")]
145    pub can_collect_data: bool,
146    /// Note that this is no longer sent, in favor of `PredictEditsAdditionalContext`.
147    #[serde(skip_serializing_if = "Option::is_none", default)]
148    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
149    /// Info about the git repository state, only present when can_collect_data is true. Note that
150    /// this is no longer sent, in favor of `PredictEditsAdditionalContext`.
151    #[serde(skip_serializing_if = "Option::is_none", default)]
152    pub git_info: Option<PredictEditsGitInfo>,
153}
154
155/// Additional context only stored when can_collect_data is true for the corresponding edit
156/// predictions request.
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct PredictEditsAdditionalContext {
159    /// Path to the file in the repository that contains the input excerpt.
160    pub input_path: String,
161    /// Cursor position within the file that contains the input excerpt.
162    pub cursor_point: Point,
163    /// Cursor offset in bytes within the file that contains the input excerpt.
164    pub cursor_offset: usize,
165    #[serde(flatten)]
166    pub git_info: PredictEditsGitInfo,
167    /// Diagnostic near the cursor position.
168    #[serde(skip_serializing_if = "Vec::is_empty", default)]
169    pub diagnostic_groups: Vec<(String, Box<serde_json::value::RawValue>)>,
170    /// True if the diagnostics were truncated.
171    pub diagnostic_groups_truncated: bool,
172    /// Recently active files that may be within this repository.
173    #[serde(skip_serializing_if = "Vec::is_empty", default)]
174    pub recent_files: Vec<PredictEditsRecentFile>,
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct PredictEditsGitInfo {
179    /// SHA of git HEAD commit at time of prediction.
180    #[serde(skip_serializing_if = "Option::is_none", default)]
181    pub head_sha: Option<String>,
182    /// URL of the remote called `origin`.
183    #[serde(skip_serializing_if = "Option::is_none", default)]
184    pub remote_origin_url: Option<String>,
185    /// URL of the remote called `upstream`.
186    #[serde(skip_serializing_if = "Option::is_none", default)]
187    pub remote_upstream_url: Option<String>,
188}
189
190/// A zero-indexed point in a text buffer consisting of a row and column.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct Point {
193    pub row: u32,
194    pub column: u32,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct PredictEditsRecentFile {
199    /// Path to a file within the repository.
200    pub path: String,
201    /// Most recent cursor position with the file.
202    pub cursor_point: Point,
203    /// Milliseconds between the editor for this file being active and the request time.
204    pub active_to_now_ms: u32,
205    /// Number of times the editor for this file was activated.
206    pub activation_count: u32,
207    /// Rough estimate of milliseconds the user was editing the file.
208    pub cumulative_time_editing_ms: u32,
209    /// Rough estimate of milliseconds the user was navigating within the file.
210    pub cumulative_time_navigating_ms: u32,
211    /// Whether the file is a multibuffer.
212    #[serde(skip_serializing_if = "is_default", default)]
213    pub is_multibuffer: bool,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct PredictEditsResponse {
218    pub request_id: Uuid,
219    pub output_excerpt: String,
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct AcceptEditPredictionBody {
224    pub request_id: Uuid,
225}
226
227#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
228#[serde(rename_all = "snake_case")]
229pub enum CompletionMode {
230    Normal,
231    Max,
232}
233
234#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
235#[serde(rename_all = "snake_case")]
236pub enum CompletionIntent {
237    UserPrompt,
238    ToolResults,
239    ThreadSummarization,
240    ThreadContextSummarization,
241    CreateFile,
242    EditFile,
243    InlineAssist,
244    TerminalInlineAssist,
245    GenerateGitCommitMessage,
246}
247
248#[derive(Debug, Serialize, Deserialize)]
249pub struct CompletionBody {
250    #[serde(skip_serializing_if = "Option::is_none", default)]
251    pub thread_id: Option<String>,
252    #[serde(skip_serializing_if = "Option::is_none", default)]
253    pub prompt_id: Option<String>,
254    #[serde(skip_serializing_if = "Option::is_none", default)]
255    pub intent: Option<CompletionIntent>,
256    #[serde(skip_serializing_if = "Option::is_none", default)]
257    pub mode: Option<CompletionMode>,
258    pub provider: LanguageModelProvider,
259    pub model: String,
260    pub provider_request: serde_json::Value,
261}
262
263#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
264#[serde(rename_all = "snake_case")]
265pub enum CompletionRequestStatus {
266    Queued {
267        position: usize,
268    },
269    Started,
270    Failed {
271        code: String,
272        message: String,
273        request_id: Uuid,
274        /// Retry duration in seconds.
275        retry_after: Option<f64>,
276    },
277    UsageUpdated {
278        amount: usize,
279        limit: UsageLimit,
280    },
281    ToolUseLimitReached,
282}
283
284#[derive(Serialize, Deserialize)]
285#[serde(rename_all = "snake_case")]
286pub enum CompletionEvent<T> {
287    Status(CompletionRequestStatus),
288    Event(T),
289}
290
291impl<T> CompletionEvent<T> {
292    pub fn into_status(self) -> Option<CompletionRequestStatus> {
293        match self {
294            Self::Status(status) => Some(status),
295            Self::Event(_) => None,
296        }
297    }
298
299    pub fn into_event(self) -> Option<T> {
300        match self {
301            Self::Event(event) => Some(event),
302            Self::Status(_) => None,
303        }
304    }
305}
306
307#[derive(Serialize, Deserialize)]
308pub struct WebSearchBody {
309    pub query: String,
310}
311
312#[derive(Debug, Serialize, Deserialize, Clone)]
313pub struct WebSearchResponse {
314    pub results: Vec<WebSearchResult>,
315}
316
317#[derive(Debug, Serialize, Deserialize, Clone)]
318pub struct WebSearchResult {
319    pub title: String,
320    pub url: String,
321    pub text: String,
322}
323
324#[derive(Serialize, Deserialize)]
325pub struct CountTokensBody {
326    pub provider: LanguageModelProvider,
327    pub model: String,
328    pub provider_request: serde_json::Value,
329}
330
331#[derive(Serialize, Deserialize)]
332pub struct CountTokensResponse {
333    pub tokens: usize,
334}
335
336#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
337pub struct LanguageModelId(pub Arc<str>);
338
339impl std::fmt::Display for LanguageModelId {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(f, "{}", self.0)
342    }
343}
344
345#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
346pub struct LanguageModel {
347    pub provider: LanguageModelProvider,
348    pub id: LanguageModelId,
349    pub display_name: String,
350    pub max_token_count: usize,
351    pub max_token_count_in_max_mode: Option<usize>,
352    pub max_output_tokens: usize,
353    pub supports_tools: bool,
354    pub supports_images: bool,
355    pub supports_thinking: bool,
356    pub supports_max_mode: bool,
357}
358
359#[derive(Debug, Serialize, Deserialize)]
360pub struct ListModelsResponse {
361    pub models: Vec<LanguageModel>,
362    pub default_model: LanguageModelId,
363    pub default_fast_model: LanguageModelId,
364    pub recommended_models: Vec<LanguageModelId>,
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368pub struct GetSubscriptionResponse {
369    pub plan: Plan,
370    pub usage: Option<CurrentUsage>,
371}
372
373#[derive(Debug, PartialEq, Serialize, Deserialize)]
374pub struct CurrentUsage {
375    pub model_requests: UsageData,
376    pub edit_predictions: UsageData,
377}
378
379#[derive(Debug, PartialEq, Serialize, Deserialize)]
380pub struct UsageData {
381    pub used: u32,
382    pub limit: UsageLimit,
383}
384
/// Returns `true` when `value` equals `T::default()`.
///
/// Intended for `#[serde(skip_serializing_if = "is_default")]`, so that
/// default-valued fields are left out of the serialized form.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
388
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_deserialize_snake_case() {
        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    #[test]
    fn test_plan_deserialize_aliases() {
        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// `Plan::as_str`, `Plan::from_str` and serde must all agree on each plan's name.
    #[test]
    fn test_plan_as_str_round_trips() {
        for plan in [Plan::ZedFree, Plan::ZedPro, Plan::ZedProTrial] {
            assert_eq!(Plan::from_str(plan.as_str()).unwrap(), plan);
            assert_eq!(serde_json::to_value(plan).unwrap(), json!(plan.as_str()));
        }

        assert!(Plan::from_str("not_a_plan").is_err());
    }

    #[test]
    fn test_plan_default_is_free() {
        assert_eq!(Plan::default(), Plan::ZedFree);
    }

    #[test]
    fn test_usage_limit_from_str() {
        let limit = UsageLimit::from_str("unlimited").unwrap();
        assert!(matches!(limit, UsageLimit::Unlimited));

        let limit = UsageLimit::from_str(&0.to_string()).unwrap();
        assert!(matches!(limit, UsageLimit::Limited(0)));

        let limit = UsageLimit::from_str(&50.to_string()).unwrap();
        assert!(matches!(limit, UsageLimit::Limited(50)));

        for value in ["not_a_number", "50xyz"] {
            let limit = UsageLimit::from_str(value);
            assert!(limit.is_err());
        }
    }

    /// Pins the externally tagged, `snake_case` wire format of `UsageLimit`.
    #[test]
    fn test_usage_limit_serde() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Unlimited).unwrap(),
            json!("unlimited")
        );
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(50)).unwrap(),
            json!({ "limited": 50 })
        );

        let limit = serde_json::from_value::<UsageLimit>(json!({ "limited": 50 })).unwrap();
        assert_eq!(limit, UsageLimit::Limited(50));
    }
}