1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
 11/// The name of the header used to indicate which version of Zed the client is running.
 12pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 13
 14/// The name of the header used to indicate when a request failed due to an
 15/// expired LLM token.
 16///
 17/// The client may use this as a signal to refresh the token.
 18pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 19
 20/// The name of the header used to indicate what plan the user is currently on.
 21pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 22
 23/// The name of the header used to indicate the usage limit for model requests.
 24pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 25
 26/// The name of the header used to indicate the usage amount for model requests.
 27pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 28
 29/// The name of the header used to indicate the usage limit for edit predictions.
 30pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 31
 32/// The name of the header used to indicate the usage amount for edit predictions.
 33pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 34
 35/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 36pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 37
 38pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 39pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 40
 41/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 42pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 43
 44/// The name of the header used to indicate the minimum required Zed version.
 45///
 46/// This can be used to force a Zed upgrade in order to continue communicating
 47/// with the LLM service.
 48pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 49
 50/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 51pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 52    "x-zed-client-supports-status-messages";
 53
 54/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 55pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 56    "x-zed-server-supports-status-messages";
 57
 58/// The name of the header used by the client to indicate that it supports receiving xAI models.
 59pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
 60
/// A usage limit: either a fixed numeric cap or no cap at all.
///
/// With serde (externally tagged, snake_case), `Unlimited` is represented as
/// `"unlimited"` and `Limited(n)` as `{"limited": n}`. The [`FromStr`] impl parses a
/// different, flat form: the string `"unlimited"` or a bare integer such as `"50"`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// A finite cap. NOTE(review): `i32` admits negative values; confirm whether any
    /// producer ever sends them.
    Limited(i32),
    /// No cap.
    Unlimited,
}
 67
 68impl FromStr for UsageLimit {
 69    type Err = anyhow::Error;
 70
 71    fn from_str(value: &str) -> Result<Self, Self::Err> {
 72        match value {
 73            "unlimited" => Ok(Self::Unlimited),
 74            limit => limit
 75                .parse::<i32>()
 76                .map(Self::Limited)
 77                .context("failed to parse limit"),
 78        }
 79    }
 80}
 81
/// A subscription plan, tagged with which plan enumeration ([`PlanV1`] or [`PlanV2`])
/// it belongs to.
///
/// Unlike its payload types, `Plan` itself has no serde derives.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Plan {
    /// A plan from the [`PlanV1`] enumeration.
    V1(PlanV1),
    /// A plan from the [`PlanV2`] enumeration.
    V2(PlanV2),
}
 87
 88impl Plan {
 89    pub fn is_v2(&self) -> bool {
 90        matches!(self, Self::V2(_))
 91    }
 92}
 93
/// Subscription plans in the first plan enumeration.
///
/// Serialized in snake_case (`"zed_free"`, `"zed_pro"`, `"zed_pro_trial"`). Deserialization
/// additionally accepts the aliases `"Free"`, `"ZedPro"`, and `"ZedProTrial"`.
/// The [`Default`] is [`PlanV1::ZedFree`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanV1 {
    #[default]
    #[serde(alias = "Free")]
    ZedFree,
    #[serde(alias = "ZedPro")]
    ZedPro,
    #[serde(alias = "ZedProTrial")]
    ZedProTrial,
}
105
106impl FromStr for PlanV1 {
107    type Err = anyhow::Error;
108
109    fn from_str(value: &str) -> Result<Self, Self::Err> {
110        match value {
111            "zed_free" => Ok(Self::ZedFree),
112            "zed_pro" => Ok(Self::ZedPro),
113            "zed_pro_trial" => Ok(Self::ZedProTrial),
114            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
115        }
116    }
117}
118
/// Subscription plans in the second plan enumeration.
///
/// Serialized in snake_case (`"zed_free"`, `"zed_pro"`, `"zed_pro_trial"`); unlike
/// [`PlanV1`], no alternate spellings are accepted on deserialization.
/// The [`Default`] is [`PlanV2::ZedFree`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanV2 {
    #[default]
    ZedFree,
    ZedPro,
    ZedProTrial,
}
127
128impl FromStr for PlanV2 {
129    type Err = anyhow::Error;
130
131    fn from_str(value: &str) -> Result<Self, Self::Err> {
132        match value {
133            "zed_free" => Ok(Self::ZedFree),
134            "zed_pro" => Ok(Self::ZedPro),
135            "zed_pro_trial" => Ok(Self::ZedProTrial),
136            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
137        }
138    }
139}
140
/// A provider of language models.
///
/// Both the serde representation and the strum `Display`/`FromStr` representation use
/// snake_case variant names (e.g. `OpenAi` is `"open_ai"`). `EnumIter` allows iterating
/// over every provider.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
152
/// Request body for an edit prediction.
///
/// `Option` fields are omitted from the serialized JSON when `None` and default to `None`
/// when absent on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    pub input_events: String,
    pub input_excerpt: String,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Defaults to `false` when absent; also deserialized from the field name
    /// `data_collection_permission`.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Diagnostic groups as `(String, JSON value)` pairs. NOTE(review): the meaning of
    /// the string key and the value schema are not defined here — confirm with producers.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state. Intended to be present only when
    /// `can_collect_data` is true; this type does not enforce that.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
}
170
/// Git repository state attached to an edit-prediction request.
///
/// Every field is optional: omitted from serialized JSON when `None`, and `None` when
/// absent on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
183
/// Response to an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier for this prediction. Presumably the id later echoed back in
    /// [`AcceptEditPredictionBody::request_id`] — confirm against the server.
    pub request_id: Uuid,
    pub output_excerpt: String,
}

/// Request body reporting that an edit prediction was accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    /// Identifier of the accepted prediction.
    pub request_id: Uuid,
}
194
/// Completion mode requested for a completion (see `CompletionBody::mode`).
///
/// Serialized in snake_case (`"normal"`, `"max"`). Variant order is significant: the
/// derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    Normal,
    Max,
}

/// What a completion request was issued for (see `CompletionBody::intent`).
///
/// Serialized in snake_case (e.g. `"user_prompt"`). Variant order is significant: the
/// derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
215
/// Request body for a completion.
///
/// `Option` fields are omitted from serialized JSON when `None` and default to `None`
/// when absent on deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub mode: Option<CompletionMode>,
    pub provider: LanguageModelProvider,
    pub model: String,
    /// Opaque JSON request payload. NOTE(review): presumably in the format of
    /// `provider`'s own API — confirm.
    pub provider_request: serde_json::Value,
}
230
/// Status update for a completion request, carried by the `Status` variant of
/// `CompletionEvent`.
///
/// Serialized externally tagged in snake_case, e.g. `"started"` or
/// `{"queued": {"position": 3}}`.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    Queued {
        position: usize,
    },
    Started,
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
}
251
252#[derive(Serialize, Deserialize)]
253#[serde(rename_all = "snake_case")]
254pub enum CompletionEvent<T> {
255    Status(CompletionRequestStatus),
256    Event(T),
257}
258
259impl<T> CompletionEvent<T> {
260    pub fn into_status(self) -> Option<CompletionRequestStatus> {
261        match self {
262            Self::Status(status) => Some(status),
263            Self::Event(_) => None,
264        }
265    }
266
267    pub fn into_event(self) -> Option<T> {
268        match self {
269            Self::Event(event) => Some(event),
270            Self::Status(_) => None,
271        }
272    }
273}
274
275#[derive(Serialize, Deserialize)]
276pub struct WebSearchBody {
277    pub query: String,
278}
279
280#[derive(Debug, Serialize, Deserialize, Clone)]
281pub struct WebSearchResponse {
282    pub results: Vec<WebSearchResult>,
283}
284
285#[derive(Debug, Serialize, Deserialize, Clone)]
286pub struct WebSearchResult {
287    pub title: String,
288    pub url: String,
289    pub text: String,
290}
291
292#[derive(Serialize, Deserialize)]
293pub struct CountTokensBody {
294    pub provider: LanguageModelProvider,
295    pub model: String,
296    pub provider_request: serde_json::Value,
297}
298
299#[derive(Serialize, Deserialize)]
300pub struct CountTokensResponse {
301    pub tokens: usize,
302}
303
/// Identifier of a language model.
///
/// A newtype over `Arc<str>`, so clones share the underlying string. As a serde newtype
/// struct it is represented by the inner string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
306
307impl std::fmt::Display for LanguageModelId {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        write!(f, "{}", self.0)
310    }
311}
312
/// Description of a language model and its capabilities.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    pub provider: LanguageModelProvider,
    pub id: LanguageModelId,
    pub display_name: String,
    pub max_token_count: usize,
    /// Token limit in max mode, if it differs. NOTE(review): semantics of `None` vs
    /// `supports_max_mode` are not defined here — confirm.
    pub max_token_count_in_max_mode: Option<usize>,
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    pub supports_max_mode: bool,
    // only used by OpenAI and xAI; defaults to `false` when absent from the payload
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
329
/// Response listing available language models.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    pub default_model: Option<LanguageModelId>,
    pub default_fast_model: Option<LanguageModelId>,
    pub recommended_models: Vec<LanguageModelId>,
}

/// Response describing the user's subscription plan and, optionally, current usage.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
    pub plan: PlanV1,
    pub usage: Option<CurrentUsage>,
}

/// Current usage for each metered resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub model_requests: UsageData,
    pub edit_predictions: UsageData,
}

/// Usage amount and limit for one resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    /// Amount used so far.
    pub used: u32,
    pub limit: UsageLimit,
}
355
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use strum::IntoEnumIterator as _;

    use super::*;

    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV1>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV1::ZedFree);

        let plan = serde_json::from_value::<PlanV1>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV1::ZedPro);

        let plan = serde_json::from_value::<PlanV1>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV1::ZedProTrial);
    }

    #[test]
    fn test_plan_v1_deserialize_aliases() {
        let plan = serde_json::from_value::<PlanV1>(json!("Free")).unwrap();
        assert_eq!(plan, PlanV1::ZedFree);

        let plan = serde_json::from_value::<PlanV1>(json!("ZedPro")).unwrap();
        assert_eq!(plan, PlanV1::ZedPro);

        let plan = serde_json::from_value::<PlanV1>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, PlanV1::ZedProTrial);
    }

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV2>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV2::ZedFree);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV2::ZedPro);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV2::ZedProTrial);
    }

    /// The `FromStr` impls (used for header values) accept the snake_case names and
    /// reject anything unknown.
    #[test]
    fn test_plan_from_str() {
        assert_eq!(PlanV1::from_str("zed_free").unwrap(), PlanV1::ZedFree);
        assert_eq!(PlanV1::from_str("zed_pro").unwrap(), PlanV1::ZedPro);
        assert_eq!(PlanV1::from_str("zed_pro_trial").unwrap(), PlanV1::ZedProTrial);
        assert!(PlanV1::from_str("zed_enterprise").is_err());
        assert!(PlanV1::from_str("").is_err());

        assert_eq!(PlanV2::from_str("zed_free").unwrap(), PlanV2::ZedFree);
        assert_eq!(PlanV2::from_str("zed_pro").unwrap(), PlanV2::ZedPro);
        assert_eq!(PlanV2::from_str("zed_pro_trial").unwrap(), PlanV2::ZedProTrial);
        assert!(PlanV2::from_str("zed_enterprise").is_err());
        assert!(PlanV2::from_str("").is_err());
    }

    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz", ""] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    /// Pins the serde wire shape of `UsageLimit` (externally tagged, snake_case).
    #[test]
    fn test_usage_limit_serde_shape() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Unlimited).unwrap(),
            json!("unlimited")
        );
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(50)).unwrap(),
            json!({ "limited": 50 })
        );
        assert_eq!(
            serde_json::from_value::<UsageLimit>(json!({ "limited": 50 })).unwrap(),
            UsageLimit::Limited(50)
        );
    }

    /// The strum `Display`/`FromStr` form and the serde form of every provider must
    /// agree and round-trip.
    #[test]
    fn test_language_model_provider_string_forms_agree() {
        for provider in LanguageModelProvider::iter() {
            let name = provider.to_string();
            assert_eq!(LanguageModelProvider::from_str(&name).unwrap(), provider);
            assert_eq!(serde_json::to_value(provider).unwrap(), json!(name));
            assert_eq!(
                serde_json::from_value::<LanguageModelProvider>(json!(name)).unwrap(),
                provider
            );
        }
    }
}