cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when its claims can't be parsed (e.g., after a new
/// required claim has been added).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`], which indicates that the token's
/// time-based validity has passed. An outdated token is one whose structure is incompatible
/// with what the server currently expects.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
///
/// NOTE(review): presumably its value uses the textual form accepted by [`UsageLimit`]'s
/// `FromStr` impl (`unlimited` or an integer) — confirm against the caller.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// A header *value* (not a header name) identifying the edit predictions resource.
///
/// NOTE(review): the header this value is sent in is not defined in this file — confirm
/// where it is used before relying on it.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports
/// receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports
/// sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 55
 56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 57#[serde(rename_all = "snake_case")]
 58pub enum UsageLimit {
 59    Limited(i32),
 60    Unlimited,
 61}
 62
 63impl FromStr for UsageLimit {
 64    type Err = anyhow::Error;
 65
 66    fn from_str(value: &str) -> Result<Self, Self::Err> {
 67        match value {
 68            "unlimited" => Ok(Self::Unlimited),
 69            limit => limit
 70                .parse::<i32>()
 71                .map(Self::Limited)
 72                .context("failed to parse limit"),
 73        }
 74    }
 75}
 76
 77#[derive(Debug, Clone, Copy, PartialEq)]
 78pub enum Plan {
 79    V2(PlanV2),
 80}
 81
 82impl Plan {
 83    pub fn is_v2(&self) -> bool {
 84        matches!(self, Self::V2(_))
 85    }
 86}
 87
 88#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 89#[serde(rename_all = "snake_case")]
 90pub enum PlanV2 {
 91    #[default]
 92    ZedFree,
 93    ZedPro,
 94    ZedProTrial,
 95}
 96
 97impl FromStr for PlanV2 {
 98    type Err = anyhow::Error;
 99
100    fn from_str(value: &str) -> Result<Self, Self::Err> {
101        match value {
102            "zed_free" => Ok(Self::ZedFree),
103            "zed_pro" => Ok(Self::ZedPro),
104            "zed_pro_trial" => Ok(Self::ZedProTrial),
105            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
106        }
107    }
108}
109
110#[derive(
111    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
112)]
113#[serde(rename_all = "snake_case")]
114#[strum(serialize_all = "snake_case")]
115pub enum LanguageModelProvider {
116    Anthropic,
117    OpenAi,
118    Google,
119    XAi,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct PredictEditsBody {
124    #[serde(skip_serializing_if = "Option::is_none", default)]
125    pub outline: Option<String>,
126    pub input_events: String,
127    pub input_excerpt: String,
128    #[serde(skip_serializing_if = "Option::is_none", default)]
129    pub speculated_output: Option<String>,
130    /// Whether the user provided consent for sampling this interaction.
131    #[serde(default, alias = "data_collection_permission")]
132    pub can_collect_data: bool,
133    #[serde(skip_serializing_if = "Option::is_none", default)]
134    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
135    /// Info about the git repository state, only present when can_collect_data is true.
136    #[serde(skip_serializing_if = "Option::is_none", default)]
137    pub git_info: Option<PredictEditsGitInfo>,
138    /// The trigger for this request.
139    #[serde(default)]
140    pub trigger: PredictEditsRequestTrigger,
141}
142
143#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
144pub enum PredictEditsRequestTrigger {
145    Diagnostics,
146    Cli,
147    #[default]
148    Other,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct PredictEditsGitInfo {
153    /// SHA of git HEAD commit at time of prediction.
154    #[serde(skip_serializing_if = "Option::is_none", default)]
155    pub head_sha: Option<String>,
156    /// URL of the remote called `origin`.
157    #[serde(skip_serializing_if = "Option::is_none", default)]
158    pub remote_origin_url: Option<String>,
159    /// URL of the remote called `upstream`.
160    #[serde(skip_serializing_if = "Option::is_none", default)]
161    pub remote_upstream_url: Option<String>,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct PredictEditsResponse {
166    pub request_id: String,
167    pub output_excerpt: String,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct AcceptEditPredictionBody {
172    pub request_id: String,
173}
174
175#[derive(Debug, Clone, Deserialize)]
176pub struct RejectEditPredictionsBody {
177    pub rejections: Vec<EditPredictionRejection>,
178}
179
180#[derive(Debug, Clone, Serialize)]
181pub struct RejectEditPredictionsBodyRef<'a> {
182    pub rejections: &'a [EditPredictionRejection],
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
186pub struct EditPredictionRejection {
187    pub request_id: String,
188    #[serde(default)]
189    pub reason: EditPredictionRejectReason,
190    pub was_shown: bool,
191}
192
193#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
194pub enum EditPredictionRejectReason {
195    /// New requests were triggered before this one completed
196    Canceled,
197    /// No edits returned
198    Empty,
199    /// Edits returned, but none remained after interpolation
200    InterpolatedEmpty,
201    /// The new prediction was preferred over the current one
202    Replaced,
203    /// The current prediction was preferred over the new one
204    CurrentPreferred,
205    /// The current prediction was discarded
206    #[default]
207    Discarded,
208}
209
210#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
211#[serde(rename_all = "snake_case")]
212pub enum CompletionIntent {
213    UserPrompt,
214    ToolResults,
215    ThreadSummarization,
216    ThreadContextSummarization,
217    CreateFile,
218    EditFile,
219    InlineAssist,
220    TerminalInlineAssist,
221    GenerateGitCommitMessage,
222}
223
224#[derive(Debug, Serialize, Deserialize)]
225pub struct CompletionBody {
226    #[serde(skip_serializing_if = "Option::is_none", default)]
227    pub thread_id: Option<String>,
228    #[serde(skip_serializing_if = "Option::is_none", default)]
229    pub prompt_id: Option<String>,
230    #[serde(skip_serializing_if = "Option::is_none", default)]
231    pub intent: Option<CompletionIntent>,
232    pub provider: LanguageModelProvider,
233    pub model: String,
234    pub provider_request: serde_json::Value,
235}
236
237#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
238#[serde(rename_all = "snake_case")]
239pub enum CompletionRequestStatus {
240    Queued {
241        position: usize,
242    },
243    Started,
244    Failed {
245        code: String,
246        message: String,
247        request_id: Uuid,
248        /// Retry duration in seconds.
249        retry_after: Option<f64>,
250    },
251    UsageUpdated {
252        amount: usize,
253        limit: UsageLimit,
254    },
255    ToolUseLimitReached,
256}
257
258#[derive(Serialize, Deserialize)]
259#[serde(rename_all = "snake_case")]
260pub enum CompletionEvent<T> {
261    Status(CompletionRequestStatus),
262    Event(T),
263}
264
265impl<T> CompletionEvent<T> {
266    pub fn into_status(self) -> Option<CompletionRequestStatus> {
267        match self {
268            Self::Status(status) => Some(status),
269            Self::Event(_) => None,
270        }
271    }
272
273    pub fn into_event(self) -> Option<T> {
274        match self {
275            Self::Event(event) => Some(event),
276            Self::Status(_) => None,
277        }
278    }
279}
280
281#[derive(Serialize, Deserialize)]
282pub struct WebSearchBody {
283    pub query: String,
284}
285
286#[derive(Debug, Serialize, Deserialize, Clone)]
287pub struct WebSearchResponse {
288    pub results: Vec<WebSearchResult>,
289}
290
291#[derive(Debug, Serialize, Deserialize, Clone)]
292pub struct WebSearchResult {
293    pub title: String,
294    pub url: String,
295    pub text: String,
296}
297
298#[derive(Serialize, Deserialize)]
299pub struct CountTokensBody {
300    pub provider: LanguageModelProvider,
301    pub model: String,
302    pub provider_request: serde_json::Value,
303}
304
305#[derive(Serialize, Deserialize)]
306pub struct CountTokensResponse {
307    pub tokens: usize,
308}
309
310#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
311pub struct LanguageModelId(pub Arc<str>);
312
313impl std::fmt::Display for LanguageModelId {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        write!(f, "{}", self.0)
316    }
317}
318
319#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
320pub struct LanguageModel {
321    pub provider: LanguageModelProvider,
322    pub id: LanguageModelId,
323    pub display_name: String,
324    pub max_token_count: usize,
325    pub max_token_count_in_max_mode: Option<usize>,
326    pub max_output_tokens: usize,
327    pub supports_tools: bool,
328    pub supports_images: bool,
329    pub supports_thinking: bool,
330    #[serde(default)]
331    pub supports_streaming_tools: bool,
332    /// Only used by OpenAI and xAI.
333    #[serde(default)]
334    pub supports_parallel_tool_calls: bool,
335}
336
337#[derive(Debug, Serialize, Deserialize)]
338pub struct ListModelsResponse {
339    pub models: Vec<LanguageModel>,
340    pub default_model: Option<LanguageModelId>,
341    pub default_fast_model: Option<LanguageModelId>,
342    pub recommended_models: Vec<LanguageModelId>,
343}
344
345#[derive(Debug, PartialEq, Serialize, Deserialize)]
346pub struct CurrentUsage {
347    pub edit_predictions: UsageData,
348}
349
350#[derive(Debug, PartialEq, Serialize, Deserialize)]
351pub struct UsageData {
352    pub used: u32,
353    pub limit: UsageLimit,
354}
355
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV2>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV2::ZedFree);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV2::ZedPro);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV2::ZedProTrial);
    }

    #[test]
    fn test_plan_v2_from_str_matches_serde() {
        // `FromStr` is hand-written, so make sure it accepts exactly what serde produces.
        for plan in [PlanV2::ZedFree, PlanV2::ZedPro, PlanV2::ZedProTrial] {
            let serialized = serde_json::to_value(plan).unwrap();
            let name = serialized.as_str().unwrap();
            assert_eq!(PlanV2::from_str(name).unwrap(), plan);
        }

        for value in ["", "zed_enterprise", "ZedPro"] {
            assert!(PlanV2::from_str(value).is_err(), "expected {value:?} to be rejected");
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz", "", "Unlimited"] {
            assert!(UsageLimit::from_str(value).is_err(), "expected {value:?} to be rejected");
        }
    }

    #[test]
    fn test_usage_limit_serde_representation() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(50)).unwrap(),
            json!({ "limited": 50 })
        );
        assert_eq!(serde_json::to_value(UsageLimit::Unlimited).unwrap(), json!("unlimited"));

        let limit = serde_json::from_value::<UsageLimit>(json!({ "limited": 7 })).unwrap();
        assert_eq!(limit, UsageLimit::Limited(7));
    }
}