1use std::str::FromStr;
2use std::sync::Arc;
3
4use anyhow::Context as _;
5use serde::{Deserialize, Serialize};
6use strum::{Display, EnumIter, EnumString};
7use uuid::Uuid;
8
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// A value of [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] identifying the model requests resource.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
/// A value of [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] identifying the edit predictions resource.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";
55
56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum UsageLimit {
59 Limited(i32),
60 Unlimited,
61}
62
63impl FromStr for UsageLimit {
64 type Err = anyhow::Error;
65
66 fn from_str(value: &str) -> Result<Self, Self::Err> {
67 match value {
68 "unlimited" => Ok(Self::Unlimited),
69 limit => limit
70 .parse::<i32>()
71 .map(Self::Limited)
72 .context("failed to parse limit"),
73 }
74 }
75}
76
/// The plan a user is currently on.
///
/// Serialized with snake_case variant names (`"zed_free"`, `"zed_pro"`,
/// `"zed_pro_trial"`). Each variant additionally accepts the alias listed on it
/// when deserializing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
    /// The default plan. Also deserializes from `"Free"`.
    #[default]
    #[serde(alias = "Free")]
    ZedFree,
    /// Also deserializes from `"ZedPro"`.
    #[serde(alias = "ZedPro")]
    ZedPro,
    /// Also deserializes from `"ZedProTrial"`.
    #[serde(alias = "ZedProTrial")]
    ZedProTrial,
}
88
89impl Plan {
90 pub fn as_str(&self) -> &'static str {
91 match self {
92 Plan::ZedFree => "zed_free",
93 Plan::ZedPro => "zed_pro",
94 Plan::ZedProTrial => "zed_pro_trial",
95 }
96 }
97
98 pub fn model_requests_limit(&self) -> UsageLimit {
99 match self {
100 Plan::ZedPro => UsageLimit::Limited(500),
101 Plan::ZedProTrial => UsageLimit::Limited(150),
102 Plan::ZedFree => UsageLimit::Limited(50),
103 }
104 }
105
106 pub fn edit_predictions_limit(&self) -> UsageLimit {
107 match self {
108 Plan::ZedPro => UsageLimit::Unlimited,
109 Plan::ZedProTrial => UsageLimit::Unlimited,
110 Plan::ZedFree => UsageLimit::Limited(2_000),
111 }
112 }
113}
114
115impl FromStr for Plan {
116 type Err = anyhow::Error;
117
118 fn from_str(value: &str) -> Result<Self, Self::Err> {
119 match value {
120 "zed_free" => Ok(Plan::ZedFree),
121 "zed_pro" => Ok(Plan::ZedPro),
122 "zed_pro_trial" => Ok(Plan::ZedProTrial),
123 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
124 }
125 }
126}
127
/// A provider of language models.
///
/// Both serde and strum (`EnumString` / `Display`) use snake_case names, so for
/// example `OpenAi` is represented as `"open_ai"`. `EnumIter` yields the variants
/// in declaration order.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
}
138
/// Request body for an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    /// Text describing the user's recent edit events.
    /// NOTE(review): the exact format is produced by the client — confirm.
    pub input_events: String,
    /// An excerpt of the file being edited.
    /// NOTE(review): presumably the file identified by
    /// `PredictEditsAdditionalContext::input_path` — confirm.
    pub input_excerpt: String,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Defaults to `false` when absent. Also accepted under the alias
    /// `data_collection_permission` when deserializing.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Note that this is no longer sent, in favor of `PredictEditsAdditionalContext`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true. Note that
    /// this is no longer sent, in favor of `PredictEditsAdditionalContext`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
}
154
/// Additional context only stored when can_collect_data is true for the corresponding edit
/// predictions request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsAdditionalContext {
    /// Path to the file in the repository that contains the input excerpt.
    pub input_path: String,
    /// Cursor position within the file that contains the input excerpt.
    pub cursor_point: Point,
    /// Cursor offset in bytes within the file that contains the input excerpt.
    pub cursor_offset: usize,
    /// Git repository state.
    ///
    /// Flattened, so its fields (`head_sha`, `remote_origin_url`, `remote_upstream_url`)
    /// appear at the top level of this object's serialized form.
    #[serde(flatten)]
    pub git_info: PredictEditsGitInfo,
    /// Diagnostics near the cursor position.
    ///
    /// Each entry pairs a `String` with the diagnostic group as raw, unparsed JSON.
    /// NOTE(review): the `String` is presumably the diagnostic source name — confirm.
    /// Omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<(String, Box<serde_json::value::RawValue>)>,
    /// True if the diagnostics were truncated.
    ///
    /// NOTE(review): unlike the neighbouring collection fields, this has no
    /// `#[serde(default)]`, so it must be present when deserializing.
    pub diagnostic_groups_truncated: bool,
    /// Recently active files that may be within this repository.
    ///
    /// Omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub recent_files: Vec<PredictEditsRecentFile>,
}

/// Git repository state captured alongside an edit prediction request.
///
/// Every field is optional. A field is omitted from the serialized output when it
/// is `None` and defaults to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}

/// A zero-indexed point in a text buffer consisting of a row and column.
///
/// NOTE(review): the unit of `column` (bytes vs. characters) is not specified here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Point {
    pub row: u32,
    pub column: u32,
}

/// A recently active file sent as edit prediction context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRecentFile {
    /// Path to a file within the repository.
    pub path: String,
    /// Most recent cursor position within the file.
    pub cursor_point: Point,
    /// Milliseconds between the editor for this file being active and the request time.
    pub active_to_now_ms: u32,
}
206
/// Response to an edit prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of this prediction.
    /// NOTE(review): presumably the id later sent back in `AcceptEditPredictionBody` — confirm.
    pub request_id: Uuid,
    /// The predicted excerpt.
    /// NOTE(review): presumably a rewrite of the request's `input_excerpt` — confirm.
    pub output_excerpt: String,
}

/// Request body reporting that an edit prediction was accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    /// NOTE(review): presumably the `request_id` from the accepted `PredictEditsResponse` — confirm.
    pub request_id: Uuid,
}
217
/// The mode requested for a completion.
///
/// Serialized as snake_case (`"normal"`, `"max"`). Because `Ord` is derived, the
/// declaration order (`Normal` < `Max`) is significant; do not reorder variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    Normal,
    /// NOTE(review): presumably the "max mode" gated by `LanguageModel::supports_max_mode` — confirm.
    Max,
}

/// The purpose for which a completion was requested.
///
/// Serialized as snake_case, e.g. `ThreadContextSummarization` as
/// `"thread_context_summarization"`. Because `Ord` is derived, declaration order
/// is significant.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
238
/// Request body for a completion.
///
/// The optional metadata fields (`thread_id`, `prompt_id`, `intent`, `mode`) are
/// omitted from the serialized output when `None` and default to `None` when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub mode: Option<CompletionMode>,
    /// The provider of `model`.
    pub provider: LanguageModelProvider,
    /// Identifier of the model to use.
    pub model: String,
    /// Provider-specific request payload, carried as opaque JSON.
    /// NOTE(review): presumably in `provider`'s native request format — confirm.
    pub provider_request: serde_json::Value,
}
253
/// A status message about a completion request, carried by `CompletionEvent::Status`.
///
/// Serialized with serde's externally tagged representation and snake_case variant
/// names, e.g. `"started"` or `{"queued": {"position": 3}}`.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is waiting to be processed.
    Queued {
        /// NOTE(review): position in the queue; whether it is zero- or one-based is not shown here — confirm.
        position: usize,
    },
    /// Processing of the request has started.
    Started,
    /// The request failed.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// Reports updated usage against a limit.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    /// The maximum number of consecutive tool uses has been reached.
    ToolUseLimitReached,
}
274
275#[derive(Serialize, Deserialize)]
276#[serde(rename_all = "snake_case")]
277pub enum CompletionEvent<T> {
278 Status(CompletionRequestStatus),
279 Event(T),
280}
281
282impl<T> CompletionEvent<T> {
283 pub fn into_status(self) -> Option<CompletionRequestStatus> {
284 match self {
285 Self::Status(status) => Some(status),
286 Self::Event(_) => None,
287 }
288 }
289
290 pub fn into_event(self) -> Option<T> {
291 match self {
292 Self::Event(event) => Some(event),
293 Self::Status(_) => None,
294 }
295 }
296}
297
298#[derive(Serialize, Deserialize)]
299pub struct WebSearchBody {
300 pub query: String,
301}
302
303#[derive(Debug, Serialize, Deserialize, Clone)]
304pub struct WebSearchResponse {
305 pub results: Vec<WebSearchResult>,
306}
307
308#[derive(Debug, Serialize, Deserialize, Clone)]
309pub struct WebSearchResult {
310 pub title: String,
311 pub url: String,
312 pub text: String,
313}
314
315#[derive(Serialize, Deserialize)]
316pub struct CountTokensBody {
317 pub provider: LanguageModelProvider,
318 pub model: String,
319 pub provider_request: serde_json::Value,
320}
321
322#[derive(Serialize, Deserialize)]
323pub struct CountTokensResponse {
324 pub tokens: usize,
325}
326
327#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
328pub struct LanguageModelId(pub Arc<str>);
329
330impl std::fmt::Display for LanguageModelId {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 write!(f, "{}", self.0)
333 }
334}
335
/// Metadata describing an available language model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    /// The provider of this model.
    pub provider: LanguageModelProvider,
    /// Identifier of this model.
    pub id: LanguageModelId,
    /// Human-readable name of this model.
    pub display_name: String,
    /// Maximum token count for this model.
    pub max_token_count: usize,
    /// Maximum token count when max mode is used.
    /// NOTE(review): presumably `None` when max mode does not change the limit — confirm.
    pub max_token_count_in_max_mode: Option<usize>,
    /// Maximum number of output tokens.
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    pub supports_max_mode: bool,
}

/// Response listing the available language models.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    /// NOTE(review): presumably the `id` of one of `models` — confirm.
    pub default_model: LanguageModelId,
    /// NOTE(review): presumably the `id` of one of `models` — confirm.
    pub default_fast_model: LanguageModelId,
    pub recommended_models: Vec<LanguageModelId>,
}
357
/// Response describing a user's subscription.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
    /// The plan the user is on.
    pub plan: Plan,
    /// Current usage, if provided.
    /// NOTE(review): when the server omits this is not shown here — confirm.
    pub usage: Option<CurrentUsage>,
}

/// Current usage of each metered resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub model_requests: UsageData,
    pub edit_predictions: UsageData,
}

/// Usage of a single resource, together with the limit that applies to it.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    /// Amount used so far.
    pub used: u32,
    /// The limit that applies.
    pub limit: UsageLimit,
}
375
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_deserialize_snake_case() {
        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    #[test]
    fn test_plan_deserialize_aliases() {
        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// `as_str`, `FromStr`, and serde must all agree on each plan's wire name.
    #[test]
    fn test_plan_as_str_round_trips() {
        for plan in [Plan::ZedFree, Plan::ZedPro, Plan::ZedProTrial] {
            assert_eq!(Plan::from_str(plan.as_str()).unwrap(), plan);
            assert_eq!(serde_json::to_value(plan).unwrap(), json!(plan.as_str()));
        }
    }

    #[test]
    fn test_plan_from_str_rejects_unknown() {
        for value in ["", "zed_enterprise", "ZED_PRO"] {
            assert!(
                Plan::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["", "not_a_number", "50xyz"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_usage_limit_serde() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Unlimited).unwrap(),
            json!("unlimited")
        );
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(50)).unwrap(),
            json!({ "limited": 50 })
        );
        assert_eq!(
            serde_json::from_value::<UsageLimit>(json!({ "limited": 50 })).unwrap(),
            UsageLimit::Limited(50)
        );
    }
}