1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
/// An outdated token means the token's structure is incompatible with the current server expectations.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// A header *value* (not a header name) identifying the edit predictions resource.
///
/// Unlike the other constants here, this is the literal string `"edit_predictions"` that is
/// placed in a header rather than the header's name.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
55
56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum UsageLimit {
59 Limited(i32),
60 Unlimited,
61}
62
63impl FromStr for UsageLimit {
64 type Err = anyhow::Error;
65
66 fn from_str(value: &str) -> Result<Self, Self::Err> {
67 match value {
68 "unlimited" => Ok(Self::Unlimited),
69 limit => limit
70 .parse::<i32>()
71 .map(Self::Limited)
72 .context("failed to parse limit"),
73 }
74 }
75}
76
77#[derive(Debug, Clone, Copy, PartialEq)]
78pub enum Plan {
79 V2(PlanV2),
80}
81
82impl Plan {
83 pub fn is_v2(&self) -> bool {
84 matches!(self, Self::V2(_))
85 }
86}
87
88#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
89#[serde(rename_all = "snake_case")]
90pub enum PlanV2 {
91 #[default]
92 ZedFree,
93 ZedPro,
94 ZedProTrial,
95}
96
97impl FromStr for PlanV2 {
98 type Err = anyhow::Error;
99
100 fn from_str(value: &str) -> Result<Self, Self::Err> {
101 match value {
102 "zed_free" => Ok(Self::ZedFree),
103 "zed_pro" => Ok(Self::ZedPro),
104 "zed_pro_trial" => Ok(Self::ZedProTrial),
105 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
106 }
107 }
108}
109
110#[derive(
111 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
112)]
113#[serde(rename_all = "snake_case")]
114#[strum(serialize_all = "snake_case")]
115pub enum LanguageModelProvider {
116 Anthropic,
117 OpenAi,
118 Google,
119 XAi,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct PredictEditsBody {
124 #[serde(skip_serializing_if = "Option::is_none", default)]
125 pub outline: Option<String>,
126 pub input_events: String,
127 pub input_excerpt: String,
128 #[serde(skip_serializing_if = "Option::is_none", default)]
129 pub speculated_output: Option<String>,
130 /// Whether the user provided consent for sampling this interaction.
131 #[serde(default, alias = "data_collection_permission")]
132 pub can_collect_data: bool,
133 #[serde(skip_serializing_if = "Option::is_none", default)]
134 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
135 /// Info about the git repository state, only present when can_collect_data is true.
136 #[serde(skip_serializing_if = "Option::is_none", default)]
137 pub git_info: Option<PredictEditsGitInfo>,
138 /// The trigger for this request.
139 #[serde(default)]
140 pub trigger: PredictEditsRequestTrigger,
141}
142
143#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
144pub enum PredictEditsRequestTrigger {
145 Testing,
146 Diagnostics,
147 Cli,
148 #[default]
149 Other,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct PredictEditsGitInfo {
154 /// SHA of git HEAD commit at time of prediction.
155 #[serde(skip_serializing_if = "Option::is_none", default)]
156 pub head_sha: Option<String>,
157 /// URL of the remote called `origin`.
158 #[serde(skip_serializing_if = "Option::is_none", default)]
159 pub remote_origin_url: Option<String>,
160 /// URL of the remote called `upstream`.
161 #[serde(skip_serializing_if = "Option::is_none", default)]
162 pub remote_upstream_url: Option<String>,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct PredictEditsResponse {
167 pub request_id: String,
168 pub output_excerpt: String,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct AcceptEditPredictionBody {
173 pub request_id: String,
174}
175
176#[derive(Debug, Clone, Deserialize)]
177pub struct RejectEditPredictionsBody {
178 pub rejections: Vec<EditPredictionRejection>,
179}
180
181#[derive(Debug, Clone, Serialize)]
182pub struct RejectEditPredictionsBodyRef<'a> {
183 pub rejections: &'a [EditPredictionRejection],
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187pub struct EditPredictionRejection {
188 pub request_id: String,
189 #[serde(default)]
190 pub reason: EditPredictionRejectReason,
191 pub was_shown: bool,
192}
193
194#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
195pub enum EditPredictionRejectReason {
196 /// New requests were triggered before this one completed
197 Canceled,
198 /// No edits returned
199 Empty,
200 /// Edits returned, but none remained after interpolation
201 InterpolatedEmpty,
202 /// The new prediction was preferred over the current one
203 Replaced,
204 /// The current prediction was preferred over the new one
205 CurrentPreferred,
206 /// The current prediction was discarded
207 #[default]
208 Discarded,
209}
210
211#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
212#[serde(rename_all = "snake_case")]
213pub enum CompletionIntent {
214 UserPrompt,
215 ToolResults,
216 ThreadSummarization,
217 ThreadContextSummarization,
218 CreateFile,
219 EditFile,
220 InlineAssist,
221 TerminalInlineAssist,
222 GenerateGitCommitMessage,
223}
224
225#[derive(Debug, Serialize, Deserialize)]
226pub struct CompletionBody {
227 #[serde(skip_serializing_if = "Option::is_none", default)]
228 pub thread_id: Option<String>,
229 #[serde(skip_serializing_if = "Option::is_none", default)]
230 pub prompt_id: Option<String>,
231 #[serde(skip_serializing_if = "Option::is_none", default)]
232 pub intent: Option<CompletionIntent>,
233 pub provider: LanguageModelProvider,
234 pub model: String,
235 pub provider_request: serde_json::Value,
236}
237
238#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
239#[serde(rename_all = "snake_case")]
240pub enum CompletionRequestStatus {
241 Queued {
242 position: usize,
243 },
244 Started,
245 Failed {
246 code: String,
247 message: String,
248 request_id: Uuid,
249 /// Retry duration in seconds.
250 retry_after: Option<f64>,
251 },
252 UsageUpdated {
253 amount: usize,
254 limit: UsageLimit,
255 },
256 ToolUseLimitReached,
257}
258
259#[derive(Serialize, Deserialize)]
260#[serde(rename_all = "snake_case")]
261pub enum CompletionEvent<T> {
262 Status(CompletionRequestStatus),
263 Event(T),
264}
265
266impl<T> CompletionEvent<T> {
267 pub fn into_status(self) -> Option<CompletionRequestStatus> {
268 match self {
269 Self::Status(status) => Some(status),
270 Self::Event(_) => None,
271 }
272 }
273
274 pub fn into_event(self) -> Option<T> {
275 match self {
276 Self::Event(event) => Some(event),
277 Self::Status(_) => None,
278 }
279 }
280}
281
282#[derive(Serialize, Deserialize)]
283pub struct WebSearchBody {
284 pub query: String,
285}
286
287#[derive(Debug, Serialize, Deserialize, Clone)]
288pub struct WebSearchResponse {
289 pub results: Vec<WebSearchResult>,
290}
291
292#[derive(Debug, Serialize, Deserialize, Clone)]
293pub struct WebSearchResult {
294 pub title: String,
295 pub url: String,
296 pub text: String,
297}
298
299#[derive(Serialize, Deserialize)]
300pub struct CountTokensBody {
301 pub provider: LanguageModelProvider,
302 pub model: String,
303 pub provider_request: serde_json::Value,
304}
305
306#[derive(Serialize, Deserialize)]
307pub struct CountTokensResponse {
308 pub tokens: usize,
309}
310
311#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
312pub struct LanguageModelId(pub Arc<str>);
313
314impl std::fmt::Display for LanguageModelId {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 write!(f, "{}", self.0)
317 }
318}
319
320#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
321pub struct LanguageModel {
322 pub provider: LanguageModelProvider,
323 pub id: LanguageModelId,
324 pub display_name: String,
325 pub max_token_count: usize,
326 pub max_token_count_in_max_mode: Option<usize>,
327 pub max_output_tokens: usize,
328 pub supports_tools: bool,
329 pub supports_images: bool,
330 pub supports_thinking: bool,
331 #[serde(default)]
332 pub supports_streaming_tools: bool,
333 /// Only used by OpenAI and xAI.
334 #[serde(default)]
335 pub supports_parallel_tool_calls: bool,
336}
337
338#[derive(Debug, Serialize, Deserialize)]
339pub struct ListModelsResponse {
340 pub models: Vec<LanguageModel>,
341 pub default_model: Option<LanguageModelId>,
342 pub default_fast_model: Option<LanguageModelId>,
343 pub recommended_models: Vec<LanguageModelId>,
344}
345
346#[derive(Debug, PartialEq, Serialize, Deserialize)]
347pub struct CurrentUsage {
348 pub edit_predictions: UsageData,
349}
350
351#[derive(Debug, PartialEq, Serialize, Deserialize)]
352pub struct UsageData {
353 pub used: u32,
354 pub limit: UsageLimit,
355}
356
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV2>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV2::ZedFree);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV2::ZedPro);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV2::ZedProTrial);
    }

    /// `PlanV2::from_str` is hand-written, so make sure it stays in sync with the serde names.
    #[test]
    fn test_plan_v2_from_str_matches_serde() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];
        for (value, expected) in cases {
            assert_eq!(PlanV2::from_str(value).unwrap(), expected);
            assert_eq!(serde_json::from_value::<PlanV2>(json!(value)).unwrap(), expected);
        }

        for value in ["", "ZedPro", "zed-pro", "zed_business"] {
            assert!(PlanV2::from_str(value).is_err(), "expected {value:?} to be rejected");
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` (rather than `assert!(matches!(..))`) reports the actual value on failure.
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz", ""] {
            assert!(UsageLimit::from_str(value).is_err(), "expected {value:?} to be rejected");
        }
    }

    /// Pins the wire format of `UsageLimit`.
    #[test]
    fn test_usage_limit_serde_round_trip() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(5)).unwrap(),
            json!({ "limited": 5 })
        );
        assert_eq!(
            serde_json::to_value(UsageLimit::Unlimited).unwrap(),
            json!("unlimited")
        );

        for limit in [UsageLimit::Limited(0), UsageLimit::Limited(50), UsageLimit::Unlimited] {
            let value = serde_json::to_value(limit).unwrap();
            assert_eq!(serde_json::from_value::<UsageLimit>(value).unwrap(), limit);
        }
    }
}