1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// Resource header value identifying model requests
/// (see [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`]).
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";

/// Resource header value identifying edit predictions
/// (see [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`]).
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
60
61#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
62#[serde(rename_all = "snake_case")]
63pub enum UsageLimit {
64 Limited(i32),
65 Unlimited,
66}
67
68impl FromStr for UsageLimit {
69 type Err = anyhow::Error;
70
71 fn from_str(value: &str) -> Result<Self, Self::Err> {
72 match value {
73 "unlimited" => Ok(Self::Unlimited),
74 limit => limit
75 .parse::<i32>()
76 .map(Self::Limited)
77 .context("failed to parse limit"),
78 }
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq)]
83pub enum Plan {
84 V1(PlanV1),
85 V2(PlanV2),
86}
87
88impl Plan {
89 pub const fn is_v2(&self) -> bool {
90 matches!(self, Self::V2(_))
91 }
92}
93
94#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
95#[serde(rename_all = "snake_case")]
96pub enum PlanV1 {
97 #[default]
98 #[serde(alias = "Free")]
99 ZedFree,
100 #[serde(alias = "ZedPro")]
101 ZedPro,
102 #[serde(alias = "ZedProTrial")]
103 ZedProTrial,
104}
105
106impl FromStr for PlanV1 {
107 type Err = anyhow::Error;
108
109 fn from_str(value: &str) -> Result<Self, Self::Err> {
110 match value {
111 "zed_free" => Ok(Self::ZedFree),
112 "zed_pro" => Ok(Self::ZedPro),
113 "zed_pro_trial" => Ok(Self::ZedProTrial),
114 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
115 }
116 }
117}
118
119#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
120#[serde(rename_all = "snake_case")]
121pub enum PlanV2 {
122 #[default]
123 ZedFree,
124 ZedPro,
125 ZedProTrial,
126}
127
128impl FromStr for PlanV2 {
129 type Err = anyhow::Error;
130
131 fn from_str(value: &str) -> Result<Self, Self::Err> {
132 match value {
133 "zed_free" => Ok(Self::ZedFree),
134 "zed_pro" => Ok(Self::ZedPro),
135 "zed_pro_trial" => Ok(Self::ZedProTrial),
136 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
137 }
138 }
139}
140
141#[derive(
142 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
143)]
144#[serde(rename_all = "snake_case")]
145#[strum(serialize_all = "snake_case")]
146pub enum LanguageModelProvider {
147 Anthropic,
148 OpenAi,
149 Google,
150 XAi,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct PredictEditsBody {
155 #[serde(skip_serializing_if = "Option::is_none", default)]
156 pub outline: Option<String>,
157 pub input_events: String,
158 pub input_excerpt: String,
159 #[serde(skip_serializing_if = "Option::is_none", default)]
160 pub speculated_output: Option<String>,
161 /// Whether the user provided consent for sampling this interaction.
162 #[serde(default, alias = "data_collection_permission")]
163 pub can_collect_data: bool,
164 #[serde(skip_serializing_if = "Option::is_none", default)]
165 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
166 /// Info about the git repository state, only present when can_collect_data is true.
167 #[serde(skip_serializing_if = "Option::is_none", default)]
168 pub git_info: Option<PredictEditsGitInfo>,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct PredictEditsGitInfo {
173 /// SHA of git HEAD commit at time of prediction.
174 #[serde(skip_serializing_if = "Option::is_none", default)]
175 pub head_sha: Option<String>,
176 /// URL of the remote called `origin`.
177 #[serde(skip_serializing_if = "Option::is_none", default)]
178 pub remote_origin_url: Option<String>,
179 /// URL of the remote called `upstream`.
180 #[serde(skip_serializing_if = "Option::is_none", default)]
181 pub remote_upstream_url: Option<String>,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct PredictEditsResponse {
186 pub request_id: Uuid,
187 pub output_excerpt: String,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct AcceptEditPredictionBody {
192 pub request_id: Uuid,
193}
194
195#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
196#[serde(rename_all = "snake_case")]
197pub enum CompletionMode {
198 Normal,
199 Max,
200}
201
202#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
203#[serde(rename_all = "snake_case")]
204pub enum CompletionIntent {
205 UserPrompt,
206 ToolResults,
207 ThreadSummarization,
208 ThreadContextSummarization,
209 CreateFile,
210 EditFile,
211 InlineAssist,
212 TerminalInlineAssist,
213 GenerateGitCommitMessage,
214}
215
216#[derive(Debug, Serialize, Deserialize)]
217pub struct CompletionBody {
218 #[serde(skip_serializing_if = "Option::is_none", default)]
219 pub thread_id: Option<String>,
220 #[serde(skip_serializing_if = "Option::is_none", default)]
221 pub prompt_id: Option<String>,
222 #[serde(skip_serializing_if = "Option::is_none", default)]
223 pub intent: Option<CompletionIntent>,
224 #[serde(skip_serializing_if = "Option::is_none", default)]
225 pub mode: Option<CompletionMode>,
226 pub provider: LanguageModelProvider,
227 pub model: String,
228 pub provider_request: serde_json::Value,
229}
230
231#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
232#[serde(rename_all = "snake_case")]
233pub enum CompletionRequestStatus {
234 Queued {
235 position: usize,
236 },
237 Started,
238 Failed {
239 code: String,
240 message: String,
241 request_id: Uuid,
242 /// Retry duration in seconds.
243 retry_after: Option<f64>,
244 },
245 UsageUpdated {
246 amount: usize,
247 limit: UsageLimit,
248 },
249 ToolUseLimitReached,
250}
251
252#[derive(Serialize, Deserialize)]
253#[serde(rename_all = "snake_case")]
254pub enum CompletionEvent<T> {
255 Status(CompletionRequestStatus),
256 Event(T),
257}
258
259impl<T> CompletionEvent<T> {
260 pub fn into_status(self) -> Option<CompletionRequestStatus> {
261 match self {
262 Self::Status(status) => Some(status),
263 Self::Event(_) => None,
264 }
265 }
266
267 pub fn into_event(self) -> Option<T> {
268 match self {
269 Self::Event(event) => Some(event),
270 Self::Status(_) => None,
271 }
272 }
273}
274
275#[derive(Serialize, Deserialize)]
276pub struct WebSearchBody {
277 pub query: String,
278}
279
280#[derive(Debug, Serialize, Deserialize, Clone)]
281pub struct WebSearchResponse {
282 pub results: Vec<WebSearchResult>,
283}
284
285#[derive(Debug, Serialize, Deserialize, Clone)]
286pub struct WebSearchResult {
287 pub title: String,
288 pub url: String,
289 pub text: String,
290}
291
292#[derive(Serialize, Deserialize)]
293pub struct CountTokensBody {
294 pub provider: LanguageModelProvider,
295 pub model: String,
296 pub provider_request: serde_json::Value,
297}
298
299#[derive(Serialize, Deserialize)]
300pub struct CountTokensResponse {
301 pub tokens: usize,
302}
303
304#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
305pub struct LanguageModelId(pub Arc<str>);
306
307impl std::fmt::Display for LanguageModelId {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 write!(f, "{}", self.0)
310 }
311}
312
313#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
314pub struct LanguageModel {
315 pub provider: LanguageModelProvider,
316 pub id: LanguageModelId,
317 pub display_name: String,
318 pub max_token_count: usize,
319 pub max_token_count_in_max_mode: Option<usize>,
320 pub max_output_tokens: usize,
321 pub supports_tools: bool,
322 pub supports_images: bool,
323 pub supports_thinking: bool,
324 pub supports_max_mode: bool,
325 // only used by OpenAI and xAI
326 #[serde(default)]
327 pub supports_parallel_tool_calls: bool,
328}
329
330#[derive(Debug, Serialize, Deserialize)]
331pub struct ListModelsResponse {
332 pub models: Vec<LanguageModel>,
333 pub default_model: Option<LanguageModelId>,
334 pub default_fast_model: Option<LanguageModelId>,
335 pub recommended_models: Vec<LanguageModelId>,
336}
337
338#[derive(Debug, Serialize, Deserialize)]
339pub struct GetSubscriptionResponse {
340 pub plan: PlanV1,
341 pub usage: Option<CurrentUsage>,
342}
343
344#[derive(Debug, PartialEq, Serialize, Deserialize)]
345pub struct CurrentUsage {
346 pub model_requests: UsageData,
347 pub edit_predictions: UsageData,
348}
349
350#[derive(Debug, PartialEq, Serialize, Deserialize)]
351pub struct UsageData {
352 pub used: u32,
353 pub limit: UsageLimit,
354}
355
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV1::ZedFree),
            ("zed_pro", PlanV1::ZedPro),
            ("zed_pro_trial", PlanV1::ZedProTrial),
        ];
        for (input, expected) in cases {
            let plan: PlanV1 = serde_json::from_value(json!(input)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    #[test]
    fn test_plan_v1_deserialize_aliases() {
        // Legacy spellings accepted via `#[serde(alias = ...)]`.
        let cases = [
            ("Free", PlanV1::ZedFree),
            ("ZedPro", PlanV1::ZedPro),
            ("ZedProTrial", PlanV1::ZedProTrial),
        ];
        for (input, expected) in cases {
            let plan: PlanV1 = serde_json::from_value(json!(input)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];
        for (input, expected) in cases {
            let plan: PlanV2 = serde_json::from_value(json!(input)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        let unlimited = UsageLimit::from_str("unlimited").unwrap();
        assert!(matches!(unlimited, UsageLimit::Unlimited));

        for (input, expected) in [("0", 0), ("50", 50)] {
            let limit = UsageLimit::from_str(input).unwrap();
            assert!(matches!(limit, UsageLimit::Limited(n) if n == expected));
        }

        for input in ["not_a_number", "50xyz"] {
            assert!(UsageLimit::from_str(input).is_err());
        }
    }
}