1use std::str::FromStr;
2use std::sync::Arc;
3
4use anyhow::Context as _;
5use serde::{Deserialize, Serialize};
6use strum::{Display, EnumIter, EnumString};
7use uuid::Uuid;
8
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// Header value identifying the model-requests resource.
// NOTE(review): presumably sent as the value of `SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`;
// confirm against the server.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
/// Header value identifying the edit-predictions resource.
// NOTE(review): presumably sent as the value of `SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`;
// confirm against the server.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";
55
/// A usage limit: either a fixed numeric cap or no cap at all.
///
/// Variant names are snake_case in serde output (`limited` / `unlimited`).
/// The textual form accepted by `FromStr` is `"unlimited"` or a decimal `i32`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// A fixed cap.
    Limited(i32),
    /// No cap.
    Unlimited,
}
62
63impl FromStr for UsageLimit {
64 type Err = anyhow::Error;
65
66 fn from_str(value: &str) -> Result<Self, Self::Err> {
67 match value {
68 "unlimited" => Ok(Self::Unlimited),
69 limit => limit
70 .parse::<i32>()
71 .map(Self::Limited)
72 .context("failed to parse limit"),
73 }
74 }
75}
76
/// The subscription plan a user is on.
///
/// Serialized with snake_case variant names (e.g. `"zed_pro_trial_v2"`). When
/// deserializing, the PascalCase spellings listed in the `alias` attributes are
/// accepted as well. `ZedFree` is the `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
    #[default]
    #[serde(alias = "Free")]
    ZedFree,
    #[serde(alias = "ZedPro")]
    ZedPro,
    ZedProV2,
    #[serde(alias = "ZedProTrial")]
    ZedProTrial,
    ZedProTrialV2,
}
90
91impl FromStr for Plan {
92 type Err = anyhow::Error;
93
94 fn from_str(value: &str) -> Result<Self, Self::Err> {
95 match value {
96 "zed_free" => Ok(Plan::ZedFree),
97 "zed_pro" => Ok(Plan::ZedPro),
98 "zed_pro_trial" => Ok(Plan::ZedProTrial),
99 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
100 }
101 }
102}
103
/// A language model provider.
///
/// Both serde and strum (`EnumString` / `Display`) use snake_case names.
/// For example, `OpenAi` is written and parsed as `"open_ai"`. `EnumIter`
/// allows iterating over all providers.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
}
114
/// Body of a predict-edits request.
///
/// Optional fields are omitted when serializing if `None`, and default to `None`
/// when missing during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    pub input_events: String,
    pub input_excerpt: String,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Defaults to `false` when absent. Also accepted under the older key
    /// `data_collection_permission` when deserializing.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// List of `(String, JSON value)` pairs.
    // NOTE(review): the meaning of the string key and the JSON schema are not
    // visible here; confirm with the producer.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
}
132
/// Git repository state attached to a [`PredictEditsBody`] via its `git_info` field.
///
/// Every field is optional and is omitted when serializing if `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
145
/// Result of an edit prediction: the predicted `output_excerpt` together with a
/// `request_id` identifying the request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    pub request_id: Uuid,
    pub output_excerpt: String,
}

/// Body of a request that accepts an edit prediction, identified by `request_id`.
// NOTE(review): presumably `request_id` is the one returned in
// `PredictEditsResponse`; confirm with the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    pub request_id: Uuid,
}
156
/// Mode requested for a completion.
///
/// Variant order is significant: the derived `Ord` places `Normal` before `Max`.
// NOTE(review): `Max` presumably relates to `LanguageModel::supports_max_mode` and
// `max_token_count_in_max_mode`; confirm.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    Normal,
    Max,
}

/// What a completion request is being made for.
///
/// Serialized with snake_case names (e.g. `"generate_git_commit_message"`).
/// Variant order determines the derived `Ord`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
177
/// Body of a completion request.
///
/// The optional metadata fields are omitted when serializing if `None`, and
/// default to `None` when missing.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    /// Why the completion is being requested.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    /// Requested completion mode.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub mode: Option<CompletionMode>,
    pub provider: LanguageModelProvider,
    pub model: String,
    /// Untyped JSON payload; this type does not interpret it.
    // NOTE(review): presumably the request in the provider's own format; confirm.
    pub provider_request: serde_json::Value,
}
192
/// Status of a completion request, carried by [`CompletionEvent::Status`].
///
/// Variant names are snake_case in serde output (e.g. `usage_updated`).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is queued at `position`.
    Queued {
        position: usize,
    },
    /// The request has started.
    Started,
    /// The request failed with the given `code` and `message`.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        // NOTE(review): `None` presumably means no retry hint was given; confirm.
        retry_after: Option<f64>,
    },
    /// Current usage `amount` measured against `limit`.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    // NOTE(review): presumably the in-stream counterpart of
    // `TOOL_USE_LIMIT_REACHED_HEADER_NAME`; confirm.
    ToolUseLimitReached,
}
213
/// Either a [`CompletionRequestStatus`] update or an event payload of type `T`.
///
/// Uses serde's default (externally tagged) representation with snake_case tags,
/// i.e. `{"status": ...}` or `{"event": ...}` in JSON.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
220
221impl<T> CompletionEvent<T> {
222 pub fn into_status(self) -> Option<CompletionRequestStatus> {
223 match self {
224 Self::Status(status) => Some(status),
225 Self::Event(_) => None,
226 }
227 }
228
229 pub fn into_event(self) -> Option<T> {
230 match self {
231 Self::Event(event) => Some(event),
232 Self::Status(_) => None,
233 }
234 }
235}
236
/// Body of a web search request.
#[derive(Serialize, Deserialize)]
pub struct WebSearchBody {
    /// The search query text.
    pub query: String,
}

/// Response to a web search: the list of results.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    pub results: Vec<WebSearchResult>,
}

/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    pub title: String,
    pub url: String,
    pub text: String,
}
253
/// Body of a count-tokens request.
///
/// It has the same `provider` / `model` / `provider_request` shape as
/// [`CompletionBody`].
#[derive(Serialize, Deserialize)]
pub struct CountTokensBody {
    pub provider: LanguageModelProvider,
    pub model: String,
    /// Untyped JSON payload; this type does not interpret it.
    pub provider_request: serde_json::Value,
}

/// Response to a count-tokens request.
#[derive(Serialize, Deserialize)]
pub struct CountTokensResponse {
    /// The token count.
    pub tokens: usize,
}
265
/// Identifier of a language model.
///
/// It wraps an `Arc<str>`, so clones share the string instead of copying it.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
268
269impl std::fmt::Display for LanguageModelId {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 write!(f, "{}", self.0)
272 }
273}
274
/// Description of a language model and its capabilities.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    pub provider: LanguageModelProvider,
    pub id: LanguageModelId,
    /// Human-readable name.
    pub display_name: String,
    pub max_token_count: usize,
    // NOTE(review): presumably the token limit when `CompletionMode::Max` is used,
    // `None` if it does not differ / is unsupported; confirm.
    pub max_token_count_in_max_mode: Option<usize>,
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    pub supports_max_mode: bool,
}

/// Response listing the available models.
///
/// It also names the default, default-fast and recommended models by id.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    pub default_model: LanguageModelId,
    pub default_fast_model: LanguageModelId,
    pub recommended_models: Vec<LanguageModelId>,
}
296
/// Response describing the user's subscription.
///
/// It contains the current [`Plan`] and, when present, current usage.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
    pub plan: Plan,
    pub usage: Option<CurrentUsage>,
}

/// Usage for each metered resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub model_requests: UsageData,
    pub edit_predictions: UsageData,
}

/// Amount `used` of a resource, alongside its `limit`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    pub used: u32,
    pub limit: UsageLimit,
}
314
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_deserialize_snake_case() {
        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_v2")).unwrap();
        assert_eq!(plan, Plan::ZedProV2);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial_v2")).unwrap();
        assert_eq!(plan, Plan::ZedProTrialV2);
    }

    #[test]
    fn test_plan_deserialize_aliases() {
        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// Known snake_case plan names parse; empty, unknown, or differently-cased
    /// names are rejected.
    #[test]
    fn test_plan_from_str() {
        assert_eq!(Plan::from_str("zed_free").unwrap(), Plan::ZedFree);
        assert_eq!(Plan::from_str("zed_pro").unwrap(), Plan::ZedPro);
        assert_eq!(Plan::from_str("zed_pro_trial").unwrap(), Plan::ZedProTrial);

        for value in ["", "zed_enterprise", "ZED_PRO"] {
            assert!(
                Plan::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` (rather than `matches!`) prints both sides on failure.
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }
}