1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The header value that identifies edit predictions as the resource being
/// referred to.
// NOTE(review): the header this value is sent in is not defined in this file —
// presumably a resource header paired with the edit-prediction usage headers
// above; confirm against the server.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
47
48#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
49#[serde(rename_all = "snake_case")]
50pub enum UsageLimit {
51 Limited(i32),
52 Unlimited,
53}
54
55impl FromStr for UsageLimit {
56 type Err = anyhow::Error;
57
58 fn from_str(value: &str) -> Result<Self, Self::Err> {
59 match value {
60 "unlimited" => Ok(Self::Unlimited),
61 limit => limit
62 .parse::<i32>()
63 .map(Self::Limited)
64 .context("failed to parse limit"),
65 }
66 }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq)]
70pub enum Plan {
71 V2(PlanV2),
72}
73
74impl Plan {
75 pub fn is_v2(&self) -> bool {
76 matches!(self, Self::V2(_))
77 }
78}
79
80#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
81#[serde(rename_all = "snake_case")]
82pub enum PlanV2 {
83 #[default]
84 ZedFree,
85 ZedPro,
86 ZedProTrial,
87}
88
89impl FromStr for PlanV2 {
90 type Err = anyhow::Error;
91
92 fn from_str(value: &str) -> Result<Self, Self::Err> {
93 match value {
94 "zed_free" => Ok(Self::ZedFree),
95 "zed_pro" => Ok(Self::ZedPro),
96 "zed_pro_trial" => Ok(Self::ZedProTrial),
97 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
98 }
99 }
100}
101
102#[derive(
103 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
104)]
105#[serde(rename_all = "snake_case")]
106#[strum(serialize_all = "snake_case")]
107pub enum LanguageModelProvider {
108 Anthropic,
109 OpenAi,
110 Google,
111 XAi,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct PredictEditsBody {
116 #[serde(skip_serializing_if = "Option::is_none", default)]
117 pub outline: Option<String>,
118 pub input_events: String,
119 pub input_excerpt: String,
120 #[serde(skip_serializing_if = "Option::is_none", default)]
121 pub speculated_output: Option<String>,
122 /// Whether the user provided consent for sampling this interaction.
123 #[serde(default, alias = "data_collection_permission")]
124 pub can_collect_data: bool,
125 #[serde(skip_serializing_if = "Option::is_none", default)]
126 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
127 /// Info about the git repository state, only present when can_collect_data is true.
128 #[serde(skip_serializing_if = "Option::is_none", default)]
129 pub git_info: Option<PredictEditsGitInfo>,
130 /// The trigger for this request.
131 #[serde(default)]
132 pub trigger: PredictEditsRequestTrigger,
133}
134
135#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
136pub enum PredictEditsRequestTrigger {
137 Diagnostics,
138 Cli,
139 #[default]
140 Other,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct PredictEditsGitInfo {
145 /// SHA of git HEAD commit at time of prediction.
146 #[serde(skip_serializing_if = "Option::is_none", default)]
147 pub head_sha: Option<String>,
148 /// URL of the remote called `origin`.
149 #[serde(skip_serializing_if = "Option::is_none", default)]
150 pub remote_origin_url: Option<String>,
151 /// URL of the remote called `upstream`.
152 #[serde(skip_serializing_if = "Option::is_none", default)]
153 pub remote_upstream_url: Option<String>,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct PredictEditsResponse {
158 pub request_id: String,
159 pub output_excerpt: String,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct AcceptEditPredictionBody {
164 pub request_id: String,
165}
166
167#[derive(Debug, Clone, Deserialize)]
168pub struct RejectEditPredictionsBody {
169 pub rejections: Vec<EditPredictionRejection>,
170}
171
172#[derive(Debug, Clone, Serialize)]
173pub struct RejectEditPredictionsBodyRef<'a> {
174 pub rejections: &'a [EditPredictionRejection],
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
178pub struct EditPredictionRejection {
179 pub request_id: String,
180 #[serde(default)]
181 pub reason: EditPredictionRejectReason,
182 pub was_shown: bool,
183}
184
185#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
186pub enum EditPredictionRejectReason {
187 /// New requests were triggered before this one completed
188 Canceled,
189 /// No edits returned
190 Empty,
191 /// Edits returned, but none remained after interpolation
192 InterpolatedEmpty,
193 /// The new prediction was preferred over the current one
194 Replaced,
195 /// The current prediction was preferred over the new one
196 CurrentPreferred,
197 /// The current prediction was discarded
198 #[default]
199 Discarded,
200}
201
202#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
203#[serde(rename_all = "snake_case")]
204pub enum CompletionIntent {
205 UserPrompt,
206 ToolResults,
207 ThreadSummarization,
208 ThreadContextSummarization,
209 CreateFile,
210 EditFile,
211 InlineAssist,
212 TerminalInlineAssist,
213 GenerateGitCommitMessage,
214}
215
216#[derive(Debug, Serialize, Deserialize)]
217pub struct CompletionBody {
218 #[serde(skip_serializing_if = "Option::is_none", default)]
219 pub thread_id: Option<String>,
220 #[serde(skip_serializing_if = "Option::is_none", default)]
221 pub prompt_id: Option<String>,
222 #[serde(skip_serializing_if = "Option::is_none", default)]
223 pub intent: Option<CompletionIntent>,
224 pub provider: LanguageModelProvider,
225 pub model: String,
226 pub provider_request: serde_json::Value,
227}
228
229#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
230#[serde(rename_all = "snake_case")]
231pub enum CompletionRequestStatus {
232 Queued {
233 position: usize,
234 },
235 Started,
236 Failed {
237 code: String,
238 message: String,
239 request_id: Uuid,
240 /// Retry duration in seconds.
241 retry_after: Option<f64>,
242 },
243 UsageUpdated {
244 amount: usize,
245 limit: UsageLimit,
246 },
247 ToolUseLimitReached,
248}
249
250#[derive(Serialize, Deserialize)]
251#[serde(rename_all = "snake_case")]
252pub enum CompletionEvent<T> {
253 Status(CompletionRequestStatus),
254 Event(T),
255}
256
257impl<T> CompletionEvent<T> {
258 pub fn into_status(self) -> Option<CompletionRequestStatus> {
259 match self {
260 Self::Status(status) => Some(status),
261 Self::Event(_) => None,
262 }
263 }
264
265 pub fn into_event(self) -> Option<T> {
266 match self {
267 Self::Event(event) => Some(event),
268 Self::Status(_) => None,
269 }
270 }
271}
272
273#[derive(Serialize, Deserialize)]
274pub struct WebSearchBody {
275 pub query: String,
276}
277
278#[derive(Debug, Serialize, Deserialize, Clone)]
279pub struct WebSearchResponse {
280 pub results: Vec<WebSearchResult>,
281}
282
283#[derive(Debug, Serialize, Deserialize, Clone)]
284pub struct WebSearchResult {
285 pub title: String,
286 pub url: String,
287 pub text: String,
288}
289
290#[derive(Serialize, Deserialize)]
291pub struct CountTokensBody {
292 pub provider: LanguageModelProvider,
293 pub model: String,
294 pub provider_request: serde_json::Value,
295}
296
297#[derive(Serialize, Deserialize)]
298pub struct CountTokensResponse {
299 pub tokens: usize,
300}
301
302#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
303pub struct LanguageModelId(pub Arc<str>);
304
305impl std::fmt::Display for LanguageModelId {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 write!(f, "{}", self.0)
308 }
309}
310
311#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
312pub struct LanguageModel {
313 pub provider: LanguageModelProvider,
314 pub id: LanguageModelId,
315 pub display_name: String,
316 pub max_token_count: usize,
317 pub max_token_count_in_max_mode: Option<usize>,
318 pub max_output_tokens: usize,
319 pub supports_tools: bool,
320 pub supports_images: bool,
321 pub supports_thinking: bool,
322 #[serde(default)]
323 pub supports_streaming_tools: bool,
324 /// Only used by OpenAI and xAI.
325 #[serde(default)]
326 pub supports_parallel_tool_calls: bool,
327}
328
329#[derive(Debug, Serialize, Deserialize)]
330pub struct ListModelsResponse {
331 pub models: Vec<LanguageModel>,
332 pub default_model: Option<LanguageModelId>,
333 pub default_fast_model: Option<LanguageModelId>,
334 pub recommended_models: Vec<LanguageModelId>,
335}
336
337#[derive(Debug, PartialEq, Serialize, Deserialize)]
338pub struct CurrentUsage {
339 pub edit_predictions: UsageData,
340}
341
342#[derive(Debug, PartialEq, Serialize, Deserialize)]
343pub struct UsageData {
344 pub used: u32,
345 pub limit: UsageLimit,
346}
347
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let plan = serde_json::from_value::<PlanV2>(json!("zed_free")).unwrap();
        assert_eq!(plan, PlanV2::ZedFree);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro")).unwrap();
        assert_eq!(plan, PlanV2::ZedPro);

        let plan = serde_json::from_value::<PlanV2>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, PlanV2::ZedProTrial);
    }

    /// `PlanV2`'s hand-written `FromStr` must accept exactly the names serde
    /// produces, so the two parsing paths cannot drift apart.
    #[test]
    fn test_plan_v2_from_str_matches_serde() {
        for plan in [PlanV2::ZedFree, PlanV2::ZedPro, PlanV2::ZedProTrial] {
            let serialized = serde_json::to_value(plan).unwrap();
            let name = serialized.as_str().unwrap();
            assert_eq!(PlanV2::from_str(name).unwrap(), plan);
        }

        for value in ["", "ZedPro", "zed-pro", "zed_enterprise"] {
            assert!(PlanV2::from_str(value).is_err(), "{value:?} should not parse");
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` (rather than `assert!(matches!(..))`) reports the
        // actual value on failure.
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz", "", "Unlimited"] {
            assert!(UsageLimit::from_str(value).is_err(), "{value:?} should not parse");
        }
    }
}