cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2pub mod udiff;
  3
  4use std::str::FromStr;
  5use std::sync::Arc;
  6
  7use anyhow::Context as _;
  8use serde::{Deserialize, Serialize};
  9use strum::{Display, EnumIter, EnumString};
 10use uuid::Uuid;
 11
 12/// The name of the header used to indicate which version of Zed the client is running.
 13pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 14
 15/// The name of the header used to indicate when a request failed due to an
 16/// expired LLM token.
 17///
 18/// The client may use this as a signal to refresh the token.
 19pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 20
 21/// The name of the header used to indicate what plan the user is currently on.
 22pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 23
 24/// The name of the header used to indicate the usage limit for model requests.
 25pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 26
 27/// The name of the header used to indicate the usage amount for model requests.
 28pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 29
 30/// The name of the header used to indicate the usage limit for edit predictions.
 31pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 32
 33/// The name of the header used to indicate the usage amount for edit predictions.
 34pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 35
 36/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 37pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 38
 39pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 40pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 41
 42/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 43pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 44
 45/// The name of the header used to indicate the minimum required Zed version.
 46///
 47/// This can be used to force a Zed upgrade in order to continue communicating
 48/// with the LLM service.
 49pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 50
 51/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 52pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 53    "x-zed-client-supports-status-messages";
 54
 55/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 56pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 57    "x-zed-server-supports-status-messages";
 58
 59/// The name of the header used by the client to indicate that it supports receiving xAI models.
 60pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
 61
 62#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 63#[serde(rename_all = "snake_case")]
 64pub enum UsageLimit {
 65    Limited(i32),
 66    Unlimited,
 67}
 68
 69impl FromStr for UsageLimit {
 70    type Err = anyhow::Error;
 71
 72    fn from_str(value: &str) -> Result<Self, Self::Err> {
 73        match value {
 74            "unlimited" => Ok(Self::Unlimited),
 75            limit => limit
 76                .parse::<i32>()
 77                .map(Self::Limited)
 78                .context("failed to parse limit"),
 79        }
 80    }
 81}
 82
 83#[derive(Debug, Clone, Copy, PartialEq)]
 84pub enum Plan {
 85    V1(PlanV1),
 86    V2(PlanV2),
 87}
 88
 89impl Plan {
 90    pub fn is_v2(&self) -> bool {
 91        matches!(self, Self::V2(_))
 92    }
 93}
 94
 95#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 96#[serde(rename_all = "snake_case")]
 97pub enum PlanV1 {
 98    #[default]
 99    #[serde(alias = "Free")]
100    ZedFree,
101    #[serde(alias = "ZedPro")]
102    ZedPro,
103    #[serde(alias = "ZedProTrial")]
104    ZedProTrial,
105}
106
107impl FromStr for PlanV1 {
108    type Err = anyhow::Error;
109
110    fn from_str(value: &str) -> Result<Self, Self::Err> {
111        match value {
112            "zed_free" => Ok(Self::ZedFree),
113            "zed_pro" => Ok(Self::ZedPro),
114            "zed_pro_trial" => Ok(Self::ZedProTrial),
115            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
116        }
117    }
118}
119
120#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
121#[serde(rename_all = "snake_case")]
122pub enum PlanV2 {
123    #[default]
124    ZedFree,
125    ZedPro,
126    ZedProTrial,
127}
128
129impl FromStr for PlanV2 {
130    type Err = anyhow::Error;
131
132    fn from_str(value: &str) -> Result<Self, Self::Err> {
133        match value {
134            "zed_free" => Ok(Self::ZedFree),
135            "zed_pro" => Ok(Self::ZedPro),
136            "zed_pro_trial" => Ok(Self::ZedProTrial),
137            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
138        }
139    }
140}
141
142#[derive(
143    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
144)]
145#[serde(rename_all = "snake_case")]
146#[strum(serialize_all = "snake_case")]
147pub enum LanguageModelProvider {
148    Anthropic,
149    OpenAi,
150    Google,
151    XAi,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct PredictEditsBody {
156    #[serde(skip_serializing_if = "Option::is_none", default)]
157    pub outline: Option<String>,
158    pub input_events: String,
159    pub input_excerpt: String,
160    #[serde(skip_serializing_if = "Option::is_none", default)]
161    pub speculated_output: Option<String>,
162    /// Whether the user provided consent for sampling this interaction.
163    #[serde(default, alias = "data_collection_permission")]
164    pub can_collect_data: bool,
165    #[serde(skip_serializing_if = "Option::is_none", default)]
166    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
167    /// Info about the git repository state, only present when can_collect_data is true.
168    #[serde(skip_serializing_if = "Option::is_none", default)]
169    pub git_info: Option<PredictEditsGitInfo>,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct PredictEditsGitInfo {
174    /// SHA of git HEAD commit at time of prediction.
175    #[serde(skip_serializing_if = "Option::is_none", default)]
176    pub head_sha: Option<String>,
177    /// URL of the remote called `origin`.
178    #[serde(skip_serializing_if = "Option::is_none", default)]
179    pub remote_origin_url: Option<String>,
180    /// URL of the remote called `upstream`.
181    #[serde(skip_serializing_if = "Option::is_none", default)]
182    pub remote_upstream_url: Option<String>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct PredictEditsResponse {
187    pub request_id: Uuid,
188    pub output_excerpt: String,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct AcceptEditPredictionBody {
193    pub request_id: Uuid,
194}
195
196#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
197#[serde(rename_all = "snake_case")]
198pub enum CompletionMode {
199    Normal,
200    Max,
201}
202
203#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
204#[serde(rename_all = "snake_case")]
205pub enum CompletionIntent {
206    UserPrompt,
207    ToolResults,
208    ThreadSummarization,
209    ThreadContextSummarization,
210    CreateFile,
211    EditFile,
212    InlineAssist,
213    TerminalInlineAssist,
214    GenerateGitCommitMessage,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218pub struct CompletionBody {
219    #[serde(skip_serializing_if = "Option::is_none", default)]
220    pub thread_id: Option<String>,
221    #[serde(skip_serializing_if = "Option::is_none", default)]
222    pub prompt_id: Option<String>,
223    #[serde(skip_serializing_if = "Option::is_none", default)]
224    pub intent: Option<CompletionIntent>,
225    #[serde(skip_serializing_if = "Option::is_none", default)]
226    pub mode: Option<CompletionMode>,
227    pub provider: LanguageModelProvider,
228    pub model: String,
229    pub provider_request: serde_json::Value,
230}
231
232#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
233#[serde(rename_all = "snake_case")]
234pub enum CompletionRequestStatus {
235    Queued {
236        position: usize,
237    },
238    Started,
239    Failed {
240        code: String,
241        message: String,
242        request_id: Uuid,
243        /// Retry duration in seconds.
244        retry_after: Option<f64>,
245    },
246    UsageUpdated {
247        amount: usize,
248        limit: UsageLimit,
249    },
250    ToolUseLimitReached,
251}
252
253#[derive(Serialize, Deserialize)]
254#[serde(rename_all = "snake_case")]
255pub enum CompletionEvent<T> {
256    Status(CompletionRequestStatus),
257    Event(T),
258}
259
260impl<T> CompletionEvent<T> {
261    pub fn into_status(self) -> Option<CompletionRequestStatus> {
262        match self {
263            Self::Status(status) => Some(status),
264            Self::Event(_) => None,
265        }
266    }
267
268    pub fn into_event(self) -> Option<T> {
269        match self {
270            Self::Event(event) => Some(event),
271            Self::Status(_) => None,
272        }
273    }
274}
275
276#[derive(Serialize, Deserialize)]
277pub struct WebSearchBody {
278    pub query: String,
279}
280
281#[derive(Debug, Serialize, Deserialize, Clone)]
282pub struct WebSearchResponse {
283    pub results: Vec<WebSearchResult>,
284}
285
286#[derive(Debug, Serialize, Deserialize, Clone)]
287pub struct WebSearchResult {
288    pub title: String,
289    pub url: String,
290    pub text: String,
291}
292
293#[derive(Serialize, Deserialize)]
294pub struct CountTokensBody {
295    pub provider: LanguageModelProvider,
296    pub model: String,
297    pub provider_request: serde_json::Value,
298}
299
300#[derive(Serialize, Deserialize)]
301pub struct CountTokensResponse {
302    pub tokens: usize,
303}
304
305#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
306pub struct LanguageModelId(pub Arc<str>);
307
308impl std::fmt::Display for LanguageModelId {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        write!(f, "{}", self.0)
311    }
312}
313
314#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
315pub struct LanguageModel {
316    pub provider: LanguageModelProvider,
317    pub id: LanguageModelId,
318    pub display_name: String,
319    pub max_token_count: usize,
320    pub max_token_count_in_max_mode: Option<usize>,
321    pub max_output_tokens: usize,
322    pub supports_tools: bool,
323    pub supports_images: bool,
324    pub supports_thinking: bool,
325    pub supports_max_mode: bool,
326    // only used by OpenAI and xAI
327    #[serde(default)]
328    pub supports_parallel_tool_calls: bool,
329}
330
331#[derive(Debug, Serialize, Deserialize)]
332pub struct ListModelsResponse {
333    pub models: Vec<LanguageModel>,
334    pub default_model: Option<LanguageModelId>,
335    pub default_fast_model: Option<LanguageModelId>,
336    pub recommended_models: Vec<LanguageModelId>,
337}
338
339#[derive(Debug, Serialize, Deserialize)]
340pub struct GetSubscriptionResponse {
341    pub plan: PlanV1,
342    pub usage: Option<CurrentUsage>,
343}
344
345#[derive(Debug, PartialEq, Serialize, Deserialize)]
346pub struct CurrentUsage {
347    pub model_requests: UsageData,
348    pub edit_predictions: UsageData,
349}
350
351#[derive(Debug, PartialEq, Serialize, Deserialize)]
352pub struct UsageData {
353    pub used: u32,
354    pub limit: UsageLimit,
355}
356
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV1>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV1::ZedFree);

        let plan = serde_json::from_value::<PlanV1>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV1::ZedPro);

        let plan = serde_json::from_value::<PlanV1>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV1::ZedProTrial);
    }

    #[test]
    fn test_plan_v1_deserialize_aliases() {
        let plan = serde_json::from_value::<PlanV1>(json!("Free")).unwrap();
        assert_eq!(plan, PlanV1::ZedFree);

        let plan = serde_json::from_value::<PlanV1>(json!("ZedPro")).unwrap();
        assert_eq!(plan, PlanV1::ZedPro);

        let plan = serde_json::from_value::<PlanV1>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, PlanV1::ZedProTrial);
    }

    #[test]
    fn test_plan_v1_serialize_snake_case() {
        // Aliases only affect deserialization; output must always be the snake_case name.
        assert_eq!(serde_json::to_value(PlanV1::ZedFree).unwrap(), json!("zed_free"));
        assert_eq!(serde_json::to_value(PlanV1::ZedPro).unwrap(), json!("zed_pro"));
        assert_eq!(
            serde_json::to_value(PlanV1::ZedProTrial).unwrap(),
            json!("zed_pro_trial")
        );
    }

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV2>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV2::ZedFree);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV2::ZedPro);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV2::ZedProTrial);
    }

    #[test]
    fn test_plan_from_str() {
        for (name, v1, v2) in [
            ("zed_free", PlanV1::ZedFree, PlanV2::ZedFree),
            ("zed_pro", PlanV1::ZedPro, PlanV2::ZedPro),
            ("zed_pro_trial", PlanV1::ZedProTrial, PlanV2::ZedProTrial),
        ] {
            assert_eq!(PlanV1::from_str(name).unwrap(), v1);
            assert_eq!(PlanV2::from_str(name).unwrap(), v2);
        }

        for name in ["", "not_a_plan"] {
            assert!(PlanV1::from_str(name).is_err());
            assert!(PlanV2::from_str(name).is_err());
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        let limit = UsageLimit::from_str("unlimited").unwrap();
        assert!(matches!(limit, UsageLimit::Unlimited));

        let limit = UsageLimit::from_str(&0.to_string()).unwrap();
        assert!(matches!(limit, UsageLimit::Limited(0)));

        let limit = UsageLimit::from_str(&50.to_string()).unwrap();
        assert!(matches!(limit, UsageLimit::Limited(50)));

        for value in ["", "not_a_number", "50xyz"] {
            let limit = UsageLimit::from_str(value);
            assert!(limit.is_err());
        }
    }

    #[test]
    fn test_usage_limit_serde_round_trip() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Unlimited).unwrap(),
            json!("unlimited")
        );
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(50)).unwrap(),
            json!({ "limited": 50 })
        );
        assert_eq!(
            serde_json::from_value::<UsageLimit>(json!({ "limited": 50 })).unwrap(),
            UsageLimit::Limited(50)
        );
        assert_eq!(
            serde_json::from_value::<UsageLimit>(json!("unlimited")).unwrap(),
            UsageLimit::Unlimited
        );
    }
}