1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
///
/// NOTE(review): the value is presumably one of the `snake_case` plan names
/// accepted by the `FromStr` implementations of [`PlanV1`] / [`PlanV2`] — confirm.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
///
/// NOTE(review): the value is presumably parseable with the `FromStr`
/// implementation of [`UsageLimit`] — confirm.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
///
/// NOTE(review): the value is presumably parseable with the `FromStr`
/// implementation of [`UsageLimit`] — confirm.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
34
/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// Resource identifier for model requests.
///
/// NOTE(review): presumably sent as the value of
/// [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] — confirm against the server.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
/// Resource identifier for edit predictions.
///
/// NOTE(review): presumably sent as the value of
/// [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] — confirm against the server.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
40
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
///
/// Counterpart of [`SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME`].
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
///
/// Counterpart of [`CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME`].
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
///
/// NOTE(review): nothing in this file enforces the cap; presumably callers
/// batching [`EditPredictionRejection`]s apply it — confirm.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
63
/// A usage cap: either a concrete number or no cap at all.
///
/// Variant names are (de)serialized in `snake_case`. The [`FromStr`]
/// implementation separately accepts the text `unlimited` or a decimal `i32`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// Usage is capped at the contained value.
    Limited(i32),
    /// Usage is not capped.
    Unlimited,
}
70
71impl FromStr for UsageLimit {
72 type Err = anyhow::Error;
73
74 fn from_str(value: &str) -> Result<Self, Self::Err> {
75 match value {
76 "unlimited" => Ok(Self::Unlimited),
77 limit => limit
78 .parse::<i32>()
79 .map(Self::Limited)
80 .context("failed to parse limit"),
81 }
82 }
83}
84
/// A plan value tagged with the plan version it belongs to.
///
/// Wraps either a [`PlanV1`] or a [`PlanV2`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Plan {
    /// A plan expressed as a [`PlanV1`].
    V1(PlanV1),
    /// A plan expressed as a [`PlanV2`].
    V2(PlanV2),
}
90
91impl Plan {
92 pub fn is_v2(&self) -> bool {
93 matches!(self, Self::V2(_))
94 }
95}
96
/// Version 1 of the plan enumeration. Defaults to [`PlanV1::ZedFree`].
///
/// Variants are (de)serialized under their `snake_case` names. When
/// deserializing, each variant additionally accepts the alias listed on it.
/// The [`FromStr`] implementation only accepts the `snake_case` names.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanV1 {
    /// Also deserializes from `Free`.
    #[default]
    #[serde(alias = "Free")]
    ZedFree,
    /// Also deserializes from `ZedPro`.
    #[serde(alias = "ZedPro")]
    ZedPro,
    /// Also deserializes from `ZedProTrial`.
    #[serde(alias = "ZedProTrial")]
    ZedProTrial,
}
108
109impl FromStr for PlanV1 {
110 type Err = anyhow::Error;
111
112 fn from_str(value: &str) -> Result<Self, Self::Err> {
113 match value {
114 "zed_free" => Ok(Self::ZedFree),
115 "zed_pro" => Ok(Self::ZedPro),
116 "zed_pro_trial" => Ok(Self::ZedProTrial),
117 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
118 }
119 }
120}
121
/// Version 2 of the plan enumeration. Defaults to [`PlanV2::ZedFree`].
///
/// Variants are (de)serialized under their `snake_case` names; unlike
/// [`PlanV1`], no deserialization aliases are declared.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanV2 {
    #[default]
    ZedFree,
    ZedPro,
    ZedProTrial,
}
130
131impl FromStr for PlanV2 {
132 type Err = anyhow::Error;
133
134 fn from_str(value: &str) -> Result<Self, Self::Err> {
135 match value {
136 "zed_free" => Ok(Self::ZedFree),
137 "zed_pro" => Ok(Self::ZedPro),
138 "zed_pro_trial" => Ok(Self::ZedProTrial),
139 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
140 }
141 }
142}
143
/// The provider of a language model.
///
/// Both the serde representation and the strum-derived string conversions
/// (`Display`, `EnumString`) are configured for `snake_case` names. `EnumIter`
/// allows iterating over all providers.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
155
/// Payload describing an edit prediction request.
///
/// Every `Option` field is left out of the serialized form when it is `None`
/// and falls back to `None` when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    pub input_events: String,
    pub input_excerpt: String,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Defaults to `false` when absent; also deserializes from the key
    /// `data_collection_permission`.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// The trigger for this request.
    ///
    /// Falls back to [`PredictEditsRequestTrigger::Other`] when absent.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
176
/// What triggered an edit prediction request.
///
/// Defaults to [`PredictEditsRequestTrigger::Other`]. No serde renaming is
/// configured, so variants are (de)serialized under their names as written.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PredictEditsRequestTrigger {
    Testing,
    Diagnostics,
    Cli,
    #[default]
    Other,
}
185
/// Git repository details attached to an edit prediction request.
///
/// All fields are optional: `None` values are left out of the serialized form
/// and absent keys deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
198
/// Payload pairing a request id with the predicted output excerpt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request.
    // NOTE(review): presumably the id later echoed in `AcceptEditPredictionBody`
    // and `EditPredictionRejection` — confirm.
    pub request_id: String,
    pub output_excerpt: String,
}
204
/// Payload identifying, by request id, an edit prediction being accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    pub request_id: String,
}
209
/// Owned, deserialize-only payload listing edit prediction rejections.
///
/// [`RejectEditPredictionsBodyRef`] is the borrowing, serialize-only form with
/// the same field name.
#[derive(Debug, Clone, Deserialize)]
pub struct RejectEditPredictionsBody {
    pub rejections: Vec<EditPredictionRejection>,
}
214
/// Borrowing, serialize-only payload listing edit prediction rejections.
///
/// Carries the same `rejections` field as [`RejectEditPredictionsBody`], but as
/// a slice, so it can be serialized without taking ownership of the entries.
#[derive(Debug, Clone, Serialize)]
pub struct RejectEditPredictionsBodyRef<'a> {
    pub rejections: &'a [EditPredictionRejection],
}
219
/// A single rejected edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EditPredictionRejection {
    /// Identifier of the request whose prediction was rejected.
    pub request_id: String,
    /// Why the prediction was rejected; falls back to
    /// [`EditPredictionRejectReason::Discarded`] when absent from the input.
    #[serde(default)]
    pub reason: EditPredictionRejectReason,
    // NOTE(review): presumably whether the prediction was displayed to the
    // user before being rejected — confirm.
    pub was_shown: bool,
}
227
/// The reason an edit prediction was rejected.
///
/// Defaults to [`EditPredictionRejectReason::Discarded`]. No serde renaming is
/// configured, so variants are (de)serialized under their names as written.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum EditPredictionRejectReason {
    /// New requests were triggered before this one completed
    Canceled,
    /// No edits returned
    Empty,
    /// Edits returned, but none remained after interpolation
    InterpolatedEmpty,
    /// The new prediction was preferred over the current one
    Replaced,
    /// The current prediction was preferred over the new one
    CurrentPreferred,
    /// The current prediction was discarded
    #[default]
    Discarded,
}
244
/// The mode of a completion request, (de)serialized in `snake_case`.
///
/// The derived ordering follows declaration order (`Normal` before `Max`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    Normal,
    Max,
}
251
/// The intent behind a completion request, (de)serialized in `snake_case`.
///
/// The derived ordering follows declaration order, so reordering variants
/// changes comparison results.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
265
/// Payload describing a completion request.
///
/// The four `Option` fields are left out of the serialized form when `None`
/// and fall back to `None` when absent from the input; `provider`, `model` and
/// `provider_request` are always required.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub mode: Option<CompletionMode>,
    pub provider: LanguageModelProvider,
    pub model: String,
    /// Arbitrary JSON value.
    // NOTE(review): presumably the request in the provider's own format,
    // forwarded as-is — confirm.
    pub provider_request: serde_json::Value,
}
280
/// A status update about a completion request, (de)serialized with
/// `snake_case` variant names.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is queued at the given position.
    Queued {
        position: usize,
    },
    Started,
    /// The request failed; carries an error code, a message, the request id
    /// and an optional retry delay.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// Usage changed; carries the current amount and the applicable limit.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
}
301
302#[derive(Serialize, Deserialize)]
303#[serde(rename_all = "snake_case")]
304pub enum CompletionEvent<T> {
305 Status(CompletionRequestStatus),
306 Event(T),
307}
308
309impl<T> CompletionEvent<T> {
310 pub fn into_status(self) -> Option<CompletionRequestStatus> {
311 match self {
312 Self::Status(status) => Some(status),
313 Self::Event(_) => None,
314 }
315 }
316
317 pub fn into_event(self) -> Option<T> {
318 match self {
319 Self::Event(event) => Some(event),
320 Self::Status(_) => None,
321 }
322 }
323}
324
325#[derive(Serialize, Deserialize)]
326pub struct WebSearchBody {
327 pub query: String,
328}
329
/// Payload carrying the list of web search results.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    pub results: Vec<WebSearchResult>,
}
334
/// A single web search result: a title, a URL and a text body.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    pub title: String,
    pub url: String,
    pub text: String,
}
341
342#[derive(Serialize, Deserialize)]
343pub struct CountTokensBody {
344 pub provider: LanguageModelProvider,
345 pub model: String,
346 pub provider_request: serde_json::Value,
347}
348
349#[derive(Serialize, Deserialize)]
350pub struct CountTokensResponse {
351 pub tokens: usize,
352}
353
/// Identifier of a language model: a newtype over a shared string.
///
/// The inner `Arc<str>` makes clones share the same allocation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
356
357impl std::fmt::Display for LanguageModelId {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 write!(f, "{}", self.0)
360 }
361}
362
/// Description of a language model: provider, identifier, display name,
/// token limits and capability flags.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    pub provider: LanguageModelProvider,
    pub id: LanguageModelId,
    pub display_name: String,
    pub max_token_count: usize,
    pub max_token_count_in_max_mode: Option<usize>,
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    pub supports_max_mode: bool,
    /// Falls back to `false` when absent from the input.
    #[serde(default)]
    pub supports_streaming_tools: bool,
    // only used by OpenAI and xAI
    /// Falls back to `false` when absent from the input.
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
381
/// Payload listing language models, together with optional default model ids
/// and a list of recommended model ids.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    pub default_model: Option<LanguageModelId>,
    pub default_fast_model: Option<LanguageModelId>,
    pub recommended_models: Vec<LanguageModelId>,
}
389
/// Payload describing a subscription: a [`PlanV1`] plan and, optionally, the
/// current usage.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
    pub plan: PlanV1,
    pub usage: Option<CurrentUsage>,
}
395
/// Usage figures for the two tracked resources: model requests and edit
/// predictions.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub model_requests: UsageData,
    pub edit_predictions: UsageData,
}
401
/// Usage of a single resource: the amount used and the applicable limit.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    pub used: u32,
    pub limit: UsageLimit,
}
407
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// `PlanV1` deserializes from its snake_case names.
    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV1>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV1::ZedFree);

        let plan = serde_json::from_value::<PlanV1>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV1::ZedPro);

        let plan = serde_json::from_value::<PlanV1>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV1::ZedProTrial);
    }

    /// `PlanV1` also deserializes from the aliases declared on its variants.
    #[test]
    fn test_plan_v1_deserialize_aliases() {
        let plan = serde_json::from_value::<PlanV1>(json!("Free")).unwrap();
        assert_eq!(plan, PlanV1::ZedFree);

        let plan = serde_json::from_value::<PlanV1>(json!("ZedPro")).unwrap();
        assert_eq!(plan, PlanV1::ZedPro);

        let plan = serde_json::from_value::<PlanV1>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, PlanV1::ZedProTrial);
    }

    /// `PlanV2` deserializes from its snake_case names.
    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV2>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV2::ZedFree);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV2::ZedPro);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV2::ZedProTrial);
    }

    /// `PlanV1::from_str` accepts the snake_case names and rejects anything else.
    #[test]
    fn test_plan_v1_from_str() {
        assert_eq!(PlanV1::from_str("zed_free").unwrap(), PlanV1::ZedFree);
        assert_eq!(PlanV1::from_str("zed_pro").unwrap(), PlanV1::ZedPro);
        assert_eq!(
            PlanV1::from_str("zed_pro_trial").unwrap(),
            PlanV1::ZedProTrial
        );

        for value in ["", "enterprise"] {
            assert!(
                PlanV1::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    /// `PlanV2::from_str` accepts the snake_case names and rejects anything else.
    #[test]
    fn test_plan_v2_from_str() {
        assert_eq!(PlanV2::from_str("zed_free").unwrap(), PlanV2::ZedFree);
        assert_eq!(PlanV2::from_str("zed_pro").unwrap(), PlanV2::ZedPro);
        assert_eq!(
            PlanV2::from_str("zed_pro_trial").unwrap(),
            PlanV2::ZedProTrial
        );

        for value in ["", "enterprise"] {
            assert!(
                PlanV2::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    /// `UsageLimit::from_str` accepts `unlimited` or a decimal integer.
    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` reports the parsed value on failure, unlike `matches!`.
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["", "not_a_number", "50xyz"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }
}