1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// Header carrying the version of Zed that the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// Header signalling that a request failed because the LLM token has expired.
///
/// Clients may treat this as a cue to refresh their token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// Header signalling that a request failed because the LLM token is outdated.
///
/// "Outdated" means the token's claims could not be parsed (for example, after a new required
/// claim was added), i.e. the token's structure no longer matches what the server expects.
///
/// Contrast with [`EXPIRED_LLM_TOKEN_HEADER_NAME`], which reports that the token's time-based
/// validity has lapsed.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// Header carrying the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// Header carrying the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// Header value naming the edit predictions resource.
///
/// NOTE(review): the header this value is sent in is not visible in this file — confirm at the
/// call sites.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// Header carrying the minimum Zed version required by the server.
///
/// It can be used to force a Zed upgrade before the client may keep talking to the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// Header sent by the client to tell the server that it can receive status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// Header sent by the server to tell the client that it can send status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// Header sent by the client to tell the server that it can receive xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// Upper bound on the number of edit prediction rejections in a single request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
55
56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum UsageLimit {
59 Limited(i32),
60 Unlimited,
61}
62
63impl FromStr for UsageLimit {
64 type Err = anyhow::Error;
65
66 fn from_str(value: &str) -> Result<Self, Self::Err> {
67 match value {
68 "unlimited" => Ok(Self::Unlimited),
69 limit => limit
70 .parse::<i32>()
71 .map(Self::Limited)
72 .context("failed to parse limit"),
73 }
74 }
75}
76
77#[derive(Debug, Clone, Copy, PartialEq)]
78pub enum Plan {
79 V2(PlanV2),
80}
81
82#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
83#[serde(rename_all = "snake_case")]
84pub enum PlanV2 {
85 #[default]
86 ZedFree,
87 ZedPro,
88 ZedProTrial,
89}
90
91#[derive(
92 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
93)]
94#[serde(rename_all = "snake_case")]
95#[strum(serialize_all = "snake_case")]
96pub enum LanguageModelProvider {
97 Anthropic,
98 OpenAi,
99 Google,
100 XAi,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct PredictEditsBody {
105 #[serde(skip_serializing_if = "Option::is_none", default)]
106 pub outline: Option<String>,
107 pub input_events: String,
108 pub input_excerpt: String,
109 #[serde(skip_serializing_if = "Option::is_none", default)]
110 pub speculated_output: Option<String>,
111 /// Whether the user provided consent for sampling this interaction.
112 #[serde(default, alias = "data_collection_permission")]
113 pub can_collect_data: bool,
114 #[serde(skip_serializing_if = "Option::is_none", default)]
115 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
116 /// Info about the git repository state, only present when can_collect_data is true.
117 #[serde(skip_serializing_if = "Option::is_none", default)]
118 pub git_info: Option<PredictEditsGitInfo>,
119 /// The trigger for this request.
120 #[serde(default)]
121 pub trigger: PredictEditsRequestTrigger,
122}
123
124#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
125pub enum PredictEditsRequestTrigger {
126 Testing,
127 Diagnostics,
128 Cli,
129 #[default]
130 Other,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct PredictEditsGitInfo {
135 /// SHA of git HEAD commit at time of prediction.
136 #[serde(skip_serializing_if = "Option::is_none", default)]
137 pub head_sha: Option<String>,
138 /// URL of the remote called `origin`.
139 #[serde(skip_serializing_if = "Option::is_none", default)]
140 pub remote_origin_url: Option<String>,
141 /// URL of the remote called `upstream`.
142 #[serde(skip_serializing_if = "Option::is_none", default)]
143 pub remote_upstream_url: Option<String>,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct PredictEditsResponse {
148 pub request_id: String,
149 pub output_excerpt: String,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct AcceptEditPredictionBody {
154 pub request_id: String,
155}
156
157#[derive(Debug, Clone, Deserialize)]
158pub struct RejectEditPredictionsBody {
159 pub rejections: Vec<EditPredictionRejection>,
160}
161
162#[derive(Debug, Clone, Serialize)]
163pub struct RejectEditPredictionsBodyRef<'a> {
164 pub rejections: &'a [EditPredictionRejection],
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
168pub struct EditPredictionRejection {
169 pub request_id: String,
170 #[serde(default)]
171 pub reason: EditPredictionRejectReason,
172 pub was_shown: bool,
173}
174
175#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
176pub enum EditPredictionRejectReason {
177 /// New requests were triggered before this one completed
178 Canceled,
179 /// No edits returned
180 Empty,
181 /// Edits returned, but none remained after interpolation
182 InterpolatedEmpty,
183 /// The new prediction was preferred over the current one
184 Replaced,
185 /// The current prediction was preferred over the new one
186 CurrentPreferred,
187 /// The current prediction was discarded
188 #[default]
189 Discarded,
190}
191
192#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
193#[serde(rename_all = "snake_case")]
194pub enum CompletionIntent {
195 UserPrompt,
196 ToolResults,
197 ThreadSummarization,
198 ThreadContextSummarization,
199 CreateFile,
200 EditFile,
201 InlineAssist,
202 TerminalInlineAssist,
203 GenerateGitCommitMessage,
204}
205
206#[derive(Debug, Serialize, Deserialize)]
207pub struct CompletionBody {
208 #[serde(skip_serializing_if = "Option::is_none", default)]
209 pub thread_id: Option<String>,
210 #[serde(skip_serializing_if = "Option::is_none", default)]
211 pub prompt_id: Option<String>,
212 #[serde(skip_serializing_if = "Option::is_none", default)]
213 pub intent: Option<CompletionIntent>,
214 pub provider: LanguageModelProvider,
215 pub model: String,
216 pub provider_request: serde_json::Value,
217}
218
219#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
220#[serde(rename_all = "snake_case")]
221pub enum CompletionRequestStatus {
222 Queued {
223 position: usize,
224 },
225 Started,
226 Failed {
227 code: String,
228 message: String,
229 request_id: Uuid,
230 /// Retry duration in seconds.
231 retry_after: Option<f64>,
232 },
233 UsageUpdated {
234 amount: usize,
235 limit: UsageLimit,
236 },
237 ToolUseLimitReached,
238}
239
240#[derive(Serialize, Deserialize)]
241#[serde(rename_all = "snake_case")]
242pub enum CompletionEvent<T> {
243 Status(CompletionRequestStatus),
244 Event(T),
245}
246
247impl<T> CompletionEvent<T> {
248 pub fn into_status(self) -> Option<CompletionRequestStatus> {
249 match self {
250 Self::Status(status) => Some(status),
251 Self::Event(_) => None,
252 }
253 }
254
255 pub fn into_event(self) -> Option<T> {
256 match self {
257 Self::Event(event) => Some(event),
258 Self::Status(_) => None,
259 }
260 }
261}
262
263#[derive(Serialize, Deserialize)]
264pub struct WebSearchBody {
265 pub query: String,
266}
267
268#[derive(Debug, Serialize, Deserialize, Clone)]
269pub struct WebSearchResponse {
270 pub results: Vec<WebSearchResult>,
271}
272
273#[derive(Debug, Serialize, Deserialize, Clone)]
274pub struct WebSearchResult {
275 pub title: String,
276 pub url: String,
277 pub text: String,
278}
279
280#[derive(Serialize, Deserialize)]
281pub struct CountTokensBody {
282 pub provider: LanguageModelProvider,
283 pub model: String,
284 pub provider_request: serde_json::Value,
285}
286
287#[derive(Serialize, Deserialize)]
288pub struct CountTokensResponse {
289 pub tokens: usize,
290}
291
292#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
293pub struct LanguageModelId(pub Arc<str>);
294
295impl std::fmt::Display for LanguageModelId {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 write!(f, "{}", self.0)
298 }
299}
300
301#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
302pub struct LanguageModel {
303 pub provider: LanguageModelProvider,
304 pub id: LanguageModelId,
305 pub display_name: String,
306 pub max_token_count: usize,
307 pub max_token_count_in_max_mode: Option<usize>,
308 pub max_output_tokens: usize,
309 pub supports_tools: bool,
310 pub supports_images: bool,
311 pub supports_thinking: bool,
312 #[serde(default)]
313 pub supports_streaming_tools: bool,
314 /// Only used by OpenAI and xAI.
315 #[serde(default)]
316 pub supports_parallel_tool_calls: bool,
317}
318
319#[derive(Debug, Serialize, Deserialize)]
320pub struct ListModelsResponse {
321 pub models: Vec<LanguageModel>,
322 pub default_model: Option<LanguageModelId>,
323 pub default_fast_model: Option<LanguageModelId>,
324 pub recommended_models: Vec<LanguageModelId>,
325}
326
327#[derive(Debug, PartialEq, Serialize, Deserialize)]
328pub struct CurrentUsage {
329 pub edit_predictions: UsageData,
330}
331
332#[derive(Debug, PartialEq, Serialize, Deserialize)]
333pub struct UsageData {
334 pub used: u32,
335 pub limit: UsageLimit,
336}
337
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Each `PlanV2` variant deserializes from its snake_case name.
    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let parsed: PlanV2 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// `unlimited` and integers parse; anything else is rejected.
    #[test]
    fn test_usage_limit_from_str() {
        let accepted = [
            ("unlimited", UsageLimit::Unlimited),
            ("0", UsageLimit::Limited(0)),
            ("50", UsageLimit::Limited(50)),
        ];

        for (raw, expected) in accepted {
            assert_eq!(raw.parse::<UsageLimit>().unwrap(), expected);
        }

        for raw in ["not_a_number", "50xyz"] {
            assert!(raw.parse::<UsageLimit>().is_err());
        }
    }
}