cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
 11/// The name of the header used to indicate which version of Zed the client is running.
 12pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 13
 14/// The name of the header used to indicate when a request failed due to an
 15/// expired LLM token.
 16///
 17/// The client may use this as a signal to refresh the token.
 18pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 19
 20/// The name of the header used to indicate what plan the user is currently on.
 21pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 22
 23/// The name of the header used to indicate the usage limit for model requests.
 24pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 25
 26/// The name of the header used to indicate the usage amount for model requests.
 27pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 28
 29/// The name of the header used to indicate the usage limit for edit predictions.
 30pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 31
 32/// The name of the header used to indicate the usage amount for edit predictions.
 33pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 34
 35/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 36pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 37
 38pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 39pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 40
 41/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 42pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 43
 44/// The name of the header used to indicate the minimum required Zed version.
 45///
 46/// This can be used to force a Zed upgrade in order to continue communicating
 47/// with the LLM service.
 48pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 49
 50/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 51pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 52    "x-zed-client-supports-status-messages";
 53
 54/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 55pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 56    "x-zed-server-supports-status-messages";
 57
 58/// The name of the header used by the client to indicate that it supports receiving xAI models.
 59pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
 60
 61/// The maximum number of edit predictions that can be rejected per request.
 62pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 63
 64#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 65#[serde(rename_all = "snake_case")]
 66pub enum UsageLimit {
 67    Limited(i32),
 68    Unlimited,
 69}
 70
 71impl FromStr for UsageLimit {
 72    type Err = anyhow::Error;
 73
 74    fn from_str(value: &str) -> Result<Self, Self::Err> {
 75        match value {
 76            "unlimited" => Ok(Self::Unlimited),
 77            limit => limit
 78                .parse::<i32>()
 79                .map(Self::Limited)
 80                .context("failed to parse limit"),
 81        }
 82    }
 83}
 84
 85#[derive(Debug, Clone, Copy, PartialEq)]
 86pub enum Plan {
 87    V1(PlanV1),
 88    V2(PlanV2),
 89}
 90
 91impl Plan {
 92    pub fn is_v2(&self) -> bool {
 93        matches!(self, Self::V2(_))
 94    }
 95}
 96
 97#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 98#[serde(rename_all = "snake_case")]
 99pub enum PlanV1 {
100    #[default]
101    #[serde(alias = "Free")]
102    ZedFree,
103    #[serde(alias = "ZedPro")]
104    ZedPro,
105    #[serde(alias = "ZedProTrial")]
106    ZedProTrial,
107}
108
109impl FromStr for PlanV1 {
110    type Err = anyhow::Error;
111
112    fn from_str(value: &str) -> Result<Self, Self::Err> {
113        match value {
114            "zed_free" => Ok(Self::ZedFree),
115            "zed_pro" => Ok(Self::ZedPro),
116            "zed_pro_trial" => Ok(Self::ZedProTrial),
117            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
118        }
119    }
120}
121
122#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
123#[serde(rename_all = "snake_case")]
124pub enum PlanV2 {
125    #[default]
126    ZedFree,
127    ZedPro,
128    ZedProTrial,
129}
130
131impl FromStr for PlanV2 {
132    type Err = anyhow::Error;
133
134    fn from_str(value: &str) -> Result<Self, Self::Err> {
135        match value {
136            "zed_free" => Ok(Self::ZedFree),
137            "zed_pro" => Ok(Self::ZedPro),
138            "zed_pro_trial" => Ok(Self::ZedProTrial),
139            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
140        }
141    }
142}
143
144#[derive(
145    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
146)]
147#[serde(rename_all = "snake_case")]
148#[strum(serialize_all = "snake_case")]
149pub enum LanguageModelProvider {
150    Anthropic,
151    OpenAi,
152    Google,
153    XAi,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct PredictEditsBody {
158    #[serde(skip_serializing_if = "Option::is_none", default)]
159    pub outline: Option<String>,
160    pub input_events: String,
161    pub input_excerpt: String,
162    #[serde(skip_serializing_if = "Option::is_none", default)]
163    pub speculated_output: Option<String>,
164    /// Whether the user provided consent for sampling this interaction.
165    #[serde(default, alias = "data_collection_permission")]
166    pub can_collect_data: bool,
167    #[serde(skip_serializing_if = "Option::is_none", default)]
168    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
169    /// Info about the git repository state, only present when can_collect_data is true.
170    #[serde(skip_serializing_if = "Option::is_none", default)]
171    pub git_info: Option<PredictEditsGitInfo>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct PredictEditsGitInfo {
176    /// SHA of git HEAD commit at time of prediction.
177    #[serde(skip_serializing_if = "Option::is_none", default)]
178    pub head_sha: Option<String>,
179    /// URL of the remote called `origin`.
180    #[serde(skip_serializing_if = "Option::is_none", default)]
181    pub remote_origin_url: Option<String>,
182    /// URL of the remote called `upstream`.
183    #[serde(skip_serializing_if = "Option::is_none", default)]
184    pub remote_upstream_url: Option<String>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct PredictEditsResponse {
189    pub request_id: String,
190    pub output_excerpt: String,
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct AcceptEditPredictionBody {
195    pub request_id: String,
196}
197
198#[derive(Debug, Clone, Deserialize)]
199pub struct RejectEditPredictionsBody {
200    pub rejections: Vec<EditPredictionRejection>,
201}
202
203#[derive(Debug, Clone, Serialize)]
204pub struct RejectEditPredictionsBodyRef<'a> {
205    pub rejections: &'a [EditPredictionRejection],
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
209pub struct EditPredictionRejection {
210    pub request_id: String,
211    #[serde(default)]
212    pub reason: EditPredictionRejectReason,
213    pub was_shown: bool,
214}
215
216#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
217pub enum EditPredictionRejectReason {
218    /// New requests were triggered before this one completed
219    Canceled,
220    /// No edits returned
221    Empty,
222    /// Edits returned, but none remained after interpolation
223    InterpolatedEmpty,
224    /// The new prediction was preferred over the current one
225    Replaced,
226    /// The current prediction was preferred over the new one
227    CurrentPreferred,
228    /// The current prediction was discarded
229    #[default]
230    Discarded,
231}
232
233#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
234#[serde(rename_all = "snake_case")]
235pub enum CompletionMode {
236    Normal,
237    Max,
238}
239
240#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
241#[serde(rename_all = "snake_case")]
242pub enum CompletionIntent {
243    UserPrompt,
244    ToolResults,
245    ThreadSummarization,
246    ThreadContextSummarization,
247    CreateFile,
248    EditFile,
249    InlineAssist,
250    TerminalInlineAssist,
251    GenerateGitCommitMessage,
252}
253
254#[derive(Debug, Serialize, Deserialize)]
255pub struct CompletionBody {
256    #[serde(skip_serializing_if = "Option::is_none", default)]
257    pub thread_id: Option<String>,
258    #[serde(skip_serializing_if = "Option::is_none", default)]
259    pub prompt_id: Option<String>,
260    #[serde(skip_serializing_if = "Option::is_none", default)]
261    pub intent: Option<CompletionIntent>,
262    #[serde(skip_serializing_if = "Option::is_none", default)]
263    pub mode: Option<CompletionMode>,
264    pub provider: LanguageModelProvider,
265    pub model: String,
266    pub provider_request: serde_json::Value,
267}
268
269#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
270#[serde(rename_all = "snake_case")]
271pub enum CompletionRequestStatus {
272    Queued {
273        position: usize,
274    },
275    Started,
276    Failed {
277        code: String,
278        message: String,
279        request_id: Uuid,
280        /// Retry duration in seconds.
281        retry_after: Option<f64>,
282    },
283    UsageUpdated {
284        amount: usize,
285        limit: UsageLimit,
286    },
287    ToolUseLimitReached,
288}
289
290#[derive(Serialize, Deserialize)]
291#[serde(rename_all = "snake_case")]
292pub enum CompletionEvent<T> {
293    Status(CompletionRequestStatus),
294    Event(T),
295}
296
297impl<T> CompletionEvent<T> {
298    pub fn into_status(self) -> Option<CompletionRequestStatus> {
299        match self {
300            Self::Status(status) => Some(status),
301            Self::Event(_) => None,
302        }
303    }
304
305    pub fn into_event(self) -> Option<T> {
306        match self {
307            Self::Event(event) => Some(event),
308            Self::Status(_) => None,
309        }
310    }
311}
312
313#[derive(Serialize, Deserialize)]
314pub struct WebSearchBody {
315    pub query: String,
316}
317
318#[derive(Debug, Serialize, Deserialize, Clone)]
319pub struct WebSearchResponse {
320    pub results: Vec<WebSearchResult>,
321}
322
323#[derive(Debug, Serialize, Deserialize, Clone)]
324pub struct WebSearchResult {
325    pub title: String,
326    pub url: String,
327    pub text: String,
328}
329
330#[derive(Serialize, Deserialize)]
331pub struct CountTokensBody {
332    pub provider: LanguageModelProvider,
333    pub model: String,
334    pub provider_request: serde_json::Value,
335}
336
337#[derive(Serialize, Deserialize)]
338pub struct CountTokensResponse {
339    pub tokens: usize,
340}
341
342#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
343pub struct LanguageModelId(pub Arc<str>);
344
345impl std::fmt::Display for LanguageModelId {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        write!(f, "{}", self.0)
348    }
349}
350
351#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
352pub struct LanguageModel {
353    pub provider: LanguageModelProvider,
354    pub id: LanguageModelId,
355    pub display_name: String,
356    pub max_token_count: usize,
357    pub max_token_count_in_max_mode: Option<usize>,
358    pub max_output_tokens: usize,
359    pub supports_tools: bool,
360    pub supports_images: bool,
361    pub supports_thinking: bool,
362    pub supports_max_mode: bool,
363    // only used by OpenAI and xAI
364    #[serde(default)]
365    pub supports_parallel_tool_calls: bool,
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369pub struct ListModelsResponse {
370    pub models: Vec<LanguageModel>,
371    pub default_model: Option<LanguageModelId>,
372    pub default_fast_model: Option<LanguageModelId>,
373    pub recommended_models: Vec<LanguageModelId>,
374}
375
376#[derive(Debug, Serialize, Deserialize)]
377pub struct GetSubscriptionResponse {
378    pub plan: PlanV1,
379    pub usage: Option<CurrentUsage>,
380}
381
382#[derive(Debug, PartialEq, Serialize, Deserialize)]
383pub struct CurrentUsage {
384    pub model_requests: UsageData,
385    pub edit_predictions: UsageData,
386}
387
388#[derive(Debug, PartialEq, Serialize, Deserialize)]
389pub struct UsageData {
390    pub used: u32,
391    pub limit: UsageLimit,
392}
393
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV1::ZedFree),
            ("zed_pro", PlanV1::ZedPro),
            ("zed_pro_trial", PlanV1::ZedProTrial),
        ];
        for (raw, expected) in cases {
            let parsed: PlanV1 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_plan_v1_deserialize_aliases() {
        let cases = [
            ("Free", PlanV1::ZedFree),
            ("ZedPro", PlanV1::ZedPro),
            ("ZedProTrial", PlanV1::ZedProTrial),
        ];
        for (raw, expected) in cases {
            let parsed: PlanV1 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];
        for (raw, expected) in cases {
            let parsed: PlanV2 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);

        for count in [0, 50] {
            let parsed = UsageLimit::from_str(&count.to_string()).unwrap();
            assert_eq!(parsed, UsageLimit::Limited(count));
        }

        for bad in ["not_a_number", "50xyz"] {
            assert!(UsageLimit::from_str(bad).is_err());
        }
    }
}