1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The name of the header used to indicate the resource for which the
/// subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// Resource identifier string for model requests.
///
/// NOTE(review): presumably used as a value of the
/// [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] header — confirm against callers.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
/// Resource identifier string for edit predictions.
///
/// NOTE(review): presumably used as a value of the
/// [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] header — confirm against callers.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate that the maximum number of
/// consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it
/// supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it
/// supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";
57
58#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum UsageLimit {
61 Limited(i32),
62 Unlimited,
63}
64
65impl FromStr for UsageLimit {
66 type Err = anyhow::Error;
67
68 fn from_str(value: &str) -> Result<Self, Self::Err> {
69 match value {
70 "unlimited" => Ok(Self::Unlimited),
71 limit => limit
72 .parse::<i32>()
73 .map(Self::Limited)
74 .context("failed to parse limit"),
75 }
76 }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq)]
80pub enum Plan {
81 V1(PlanV1),
82 V2(PlanV2),
83}
84
85impl Plan {
86 pub fn is_v2(&self) -> bool {
87 matches!(self, Self::V2(_))
88 }
89}
90
91#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
92#[serde(rename_all = "snake_case")]
93pub enum PlanV1 {
94 #[default]
95 #[serde(alias = "Free")]
96 ZedFree,
97 #[serde(alias = "ZedPro")]
98 ZedPro,
99 #[serde(alias = "ZedProTrial")]
100 ZedProTrial,
101}
102
103impl FromStr for PlanV1 {
104 type Err = anyhow::Error;
105
106 fn from_str(value: &str) -> Result<Self, Self::Err> {
107 match value {
108 "zed_free" => Ok(Self::ZedFree),
109 "zed_pro" => Ok(Self::ZedPro),
110 "zed_pro_trial" => Ok(Self::ZedProTrial),
111 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
112 }
113 }
114}
115
116#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
117#[serde(rename_all = "snake_case")]
118pub enum PlanV2 {
119 #[default]
120 ZedFree,
121 ZedPro,
122 ZedProTrial,
123}
124
125impl FromStr for PlanV2 {
126 type Err = anyhow::Error;
127
128 fn from_str(value: &str) -> Result<Self, Self::Err> {
129 match value {
130 "zed_free" => Ok(Self::ZedFree),
131 "zed_pro" => Ok(Self::ZedPro),
132 "zed_pro_trial" => Ok(Self::ZedProTrial),
133 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
134 }
135 }
136}
137
138#[derive(
139 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
140)]
141#[serde(rename_all = "snake_case")]
142#[strum(serialize_all = "snake_case")]
143pub enum LanguageModelProvider {
144 Anthropic,
145 OpenAi,
146 Google,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct PredictEditsBody {
151 #[serde(skip_serializing_if = "Option::is_none", default)]
152 pub outline: Option<String>,
153 pub input_events: String,
154 pub input_excerpt: String,
155 #[serde(skip_serializing_if = "Option::is_none", default)]
156 pub speculated_output: Option<String>,
157 /// Whether the user provided consent for sampling this interaction.
158 #[serde(default, alias = "data_collection_permission")]
159 pub can_collect_data: bool,
160 #[serde(skip_serializing_if = "Option::is_none", default)]
161 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
162 /// Info about the git repository state, only present when can_collect_data is true.
163 #[serde(skip_serializing_if = "Option::is_none", default)]
164 pub git_info: Option<PredictEditsGitInfo>,
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct PredictEditsGitInfo {
169 /// SHA of git HEAD commit at time of prediction.
170 #[serde(skip_serializing_if = "Option::is_none", default)]
171 pub head_sha: Option<String>,
172 /// URL of the remote called `origin`.
173 #[serde(skip_serializing_if = "Option::is_none", default)]
174 pub remote_origin_url: Option<String>,
175 /// URL of the remote called `upstream`.
176 #[serde(skip_serializing_if = "Option::is_none", default)]
177 pub remote_upstream_url: Option<String>,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct PredictEditsResponse {
182 pub request_id: Uuid,
183 pub output_excerpt: String,
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct AcceptEditPredictionBody {
188 pub request_id: Uuid,
189}
190
191#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
192#[serde(rename_all = "snake_case")]
193pub enum CompletionMode {
194 Normal,
195 Max,
196}
197
198#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
199#[serde(rename_all = "snake_case")]
200pub enum CompletionIntent {
201 UserPrompt,
202 ToolResults,
203 ThreadSummarization,
204 ThreadContextSummarization,
205 CreateFile,
206 EditFile,
207 InlineAssist,
208 TerminalInlineAssist,
209 GenerateGitCommitMessage,
210}
211
212#[derive(Debug, Serialize, Deserialize)]
213pub struct CompletionBody {
214 #[serde(skip_serializing_if = "Option::is_none", default)]
215 pub thread_id: Option<String>,
216 #[serde(skip_serializing_if = "Option::is_none", default)]
217 pub prompt_id: Option<String>,
218 #[serde(skip_serializing_if = "Option::is_none", default)]
219 pub intent: Option<CompletionIntent>,
220 #[serde(skip_serializing_if = "Option::is_none", default)]
221 pub mode: Option<CompletionMode>,
222 pub provider: LanguageModelProvider,
223 pub model: String,
224 pub provider_request: serde_json::Value,
225}
226
227#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
228#[serde(rename_all = "snake_case")]
229pub enum CompletionRequestStatus {
230 Queued {
231 position: usize,
232 },
233 Started,
234 Failed {
235 code: String,
236 message: String,
237 request_id: Uuid,
238 /// Retry duration in seconds.
239 retry_after: Option<f64>,
240 },
241 UsageUpdated {
242 amount: usize,
243 limit: UsageLimit,
244 },
245 ToolUseLimitReached,
246}
247
248#[derive(Serialize, Deserialize)]
249#[serde(rename_all = "snake_case")]
250pub enum CompletionEvent<T> {
251 Status(CompletionRequestStatus),
252 Event(T),
253}
254
255impl<T> CompletionEvent<T> {
256 pub fn into_status(self) -> Option<CompletionRequestStatus> {
257 match self {
258 Self::Status(status) => Some(status),
259 Self::Event(_) => None,
260 }
261 }
262
263 pub fn into_event(self) -> Option<T> {
264 match self {
265 Self::Event(event) => Some(event),
266 Self::Status(_) => None,
267 }
268 }
269}
270
271#[derive(Serialize, Deserialize)]
272pub struct WebSearchBody {
273 pub query: String,
274}
275
276#[derive(Debug, Serialize, Deserialize, Clone)]
277pub struct WebSearchResponse {
278 pub results: Vec<WebSearchResult>,
279}
280
281#[derive(Debug, Serialize, Deserialize, Clone)]
282pub struct WebSearchResult {
283 pub title: String,
284 pub url: String,
285 pub text: String,
286}
287
288#[derive(Serialize, Deserialize)]
289pub struct CountTokensBody {
290 pub provider: LanguageModelProvider,
291 pub model: String,
292 pub provider_request: serde_json::Value,
293}
294
295#[derive(Serialize, Deserialize)]
296pub struct CountTokensResponse {
297 pub tokens: usize,
298}
299
300#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
301pub struct LanguageModelId(pub Arc<str>);
302
303impl std::fmt::Display for LanguageModelId {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 write!(f, "{}", self.0)
306 }
307}
308
309#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
310pub struct LanguageModel {
311 pub provider: LanguageModelProvider,
312 pub id: LanguageModelId,
313 pub display_name: String,
314 pub max_token_count: usize,
315 pub max_token_count_in_max_mode: Option<usize>,
316 pub max_output_tokens: usize,
317 pub supports_tools: bool,
318 pub supports_images: bool,
319 pub supports_thinking: bool,
320 pub supports_max_mode: bool,
321}
322
323#[derive(Debug, Serialize, Deserialize)]
324pub struct ListModelsResponse {
325 pub models: Vec<LanguageModel>,
326 pub default_model: Option<LanguageModelId>,
327 pub default_fast_model: Option<LanguageModelId>,
328 pub recommended_models: Vec<LanguageModelId>,
329}
330
331#[derive(Debug, Serialize, Deserialize)]
332pub struct GetSubscriptionResponse {
333 pub plan: PlanV1,
334 pub usage: Option<CurrentUsage>,
335}
336
337#[derive(Debug, PartialEq, Serialize, Deserialize)]
338pub struct CurrentUsage {
339 pub model_requests: UsageData,
340 pub edit_predictions: UsageData,
341}
342
343#[derive(Debug, PartialEq, Serialize, Deserialize)]
344pub struct UsageData {
345 pub used: u32,
346 pub limit: UsageLimit,
347}
348
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// The snake_case names deserialize to the matching `PlanV1` variants.
    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV1::ZedFree),
            ("zed_pro", PlanV1::ZedPro),
            ("zed_pro_trial", PlanV1::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let plan: PlanV1 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// The serde aliases deserialize to the matching `PlanV1` variants.
    #[test]
    fn test_plan_v1_deserialize_aliases() {
        let cases = [
            ("Free", PlanV1::ZedFree),
            ("ZedPro", PlanV1::ZedPro),
            ("ZedProTrial", PlanV1::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let plan: PlanV1 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// The snake_case names deserialize to the matching `PlanV2` variants.
    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let plan: PlanV2 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// `UsageLimit` parses `"unlimited"` and integers, and rejects other input.
    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!("unlimited".parse::<UsageLimit>().unwrap(), UsageLimit::Unlimited);
        assert_eq!("0".parse::<UsageLimit>().unwrap(), UsageLimit::Limited(0));
        assert_eq!("50".parse::<UsageLimit>().unwrap(), UsageLimit::Limited(50));

        for invalid in ["not_a_number", "50xyz"] {
            assert!(invalid.parse::<UsageLimit>().is_err());
        }
    }
}