1use std::str::FromStr;
2use std::sync::Arc;
3
4use anyhow::Context as _;
5use serde::{Deserialize, Serialize};
6use strum::{Display, EnumIter, EnumString};
7use uuid::Uuid;
8
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// Header value identifying model requests as a limited resource.
///
/// NOTE(review): presumably sent as the value of
/// [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`]; confirm against the server.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
/// Header value identifying edit predictions as a limited resource.
///
/// NOTE(review): presumably sent as the value of
/// [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`]; confirm against the server.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";
55
56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum UsageLimit {
59 Limited(i32),
60 Unlimited,
61}
62
63impl FromStr for UsageLimit {
64 type Err = anyhow::Error;
65
66 fn from_str(value: &str) -> Result<Self, Self::Err> {
67 match value {
68 "unlimited" => Ok(Self::Unlimited),
69 limit => limit
70 .parse::<i32>()
71 .map(Self::Limited)
72 .context("failed to parse limit"),
73 }
74 }
75}
76
77#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
78#[serde(rename_all = "snake_case")]
79pub enum Plan {
80 #[default]
81 #[serde(alias = "Free")]
82 ZedFree,
83 #[serde(alias = "ZedPro")]
84 ZedPro,
85 #[serde(alias = "ZedProTrial")]
86 ZedProTrial,
87}
88
89impl FromStr for Plan {
90 type Err = anyhow::Error;
91
92 fn from_str(value: &str) -> Result<Self, Self::Err> {
93 match value {
94 "zed_free" => Ok(Plan::ZedFree),
95 "zed_pro" => Ok(Plan::ZedPro),
96 "zed_pro_trial" => Ok(Plan::ZedProTrial),
97 plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
98 }
99 }
100}
101
102#[derive(
103 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
104)]
105#[serde(rename_all = "snake_case")]
106#[strum(serialize_all = "snake_case")]
107pub enum LanguageModelProvider {
108 Anthropic,
109 OpenAi,
110 Google,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct PredictEditsBody {
115 #[serde(skip_serializing_if = "Option::is_none", default)]
116 pub outline: Option<String>,
117 pub input_events: String,
118 pub input_excerpt: String,
119 #[serde(skip_serializing_if = "Option::is_none", default)]
120 pub speculated_output: Option<String>,
121 /// Whether the user provided consent for sampling this interaction.
122 #[serde(default, alias = "data_collection_permission")]
123 pub can_collect_data: bool,
124 #[serde(skip_serializing_if = "Option::is_none", default)]
125 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
126 /// Info about the git repository state, only present when can_collect_data is true.
127 #[serde(skip_serializing_if = "Option::is_none", default)]
128 pub git_info: Option<PredictEditsGitInfo>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct PredictEditsGitInfo {
133 /// SHA of git HEAD commit at time of prediction.
134 #[serde(skip_serializing_if = "Option::is_none", default)]
135 pub head_sha: Option<String>,
136 /// URL of the remote called `origin`.
137 #[serde(skip_serializing_if = "Option::is_none", default)]
138 pub remote_origin_url: Option<String>,
139 /// URL of the remote called `upstream`.
140 #[serde(skip_serializing_if = "Option::is_none", default)]
141 pub remote_upstream_url: Option<String>,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct PredictEditsResponse {
146 pub request_id: Uuid,
147 pub output_excerpt: String,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct AcceptEditPredictionBody {
152 pub request_id: Uuid,
153}
154
155#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
156#[serde(rename_all = "snake_case")]
157pub enum CompletionMode {
158 Normal,
159 Max,
160}
161
162#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
163#[serde(rename_all = "snake_case")]
164pub enum CompletionIntent {
165 UserPrompt,
166 ToolResults,
167 ThreadSummarization,
168 ThreadContextSummarization,
169 CreateFile,
170 EditFile,
171 InlineAssist,
172 TerminalInlineAssist,
173 GenerateGitCommitMessage,
174}
175
176#[derive(Debug, Serialize, Deserialize)]
177pub struct CompletionBody {
178 #[serde(skip_serializing_if = "Option::is_none", default)]
179 pub thread_id: Option<String>,
180 #[serde(skip_serializing_if = "Option::is_none", default)]
181 pub prompt_id: Option<String>,
182 #[serde(skip_serializing_if = "Option::is_none", default)]
183 pub intent: Option<CompletionIntent>,
184 #[serde(skip_serializing_if = "Option::is_none", default)]
185 pub mode: Option<CompletionMode>,
186 pub provider: LanguageModelProvider,
187 pub model: String,
188 pub provider_request: serde_json::Value,
189}
190
191#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
192#[serde(rename_all = "snake_case")]
193pub enum CompletionRequestStatus {
194 Queued {
195 position: usize,
196 },
197 Started,
198 Failed {
199 code: String,
200 message: String,
201 request_id: Uuid,
202 /// Retry duration in seconds.
203 retry_after: Option<f64>,
204 },
205 UsageUpdated {
206 amount: usize,
207 limit: UsageLimit,
208 },
209 ToolUseLimitReached,
210}
211
212#[derive(Serialize, Deserialize)]
213#[serde(rename_all = "snake_case")]
214pub enum CompletionEvent<T> {
215 Status(CompletionRequestStatus),
216 Event(T),
217}
218
219impl<T> CompletionEvent<T> {
220 pub fn into_status(self) -> Option<CompletionRequestStatus> {
221 match self {
222 Self::Status(status) => Some(status),
223 Self::Event(_) => None,
224 }
225 }
226
227 pub fn into_event(self) -> Option<T> {
228 match self {
229 Self::Event(event) => Some(event),
230 Self::Status(_) => None,
231 }
232 }
233}
234
235#[derive(Serialize, Deserialize)]
236pub struct WebSearchBody {
237 pub query: String,
238}
239
240#[derive(Debug, Serialize, Deserialize, Clone)]
241pub struct WebSearchResponse {
242 pub results: Vec<WebSearchResult>,
243}
244
245#[derive(Debug, Serialize, Deserialize, Clone)]
246pub struct WebSearchResult {
247 pub title: String,
248 pub url: String,
249 pub text: String,
250}
251
252#[derive(Serialize, Deserialize)]
253pub struct CountTokensBody {
254 pub provider: LanguageModelProvider,
255 pub model: String,
256 pub provider_request: serde_json::Value,
257}
258
259#[derive(Serialize, Deserialize)]
260pub struct CountTokensResponse {
261 pub tokens: usize,
262}
263
264#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
265pub struct LanguageModelId(pub Arc<str>);
266
267impl std::fmt::Display for LanguageModelId {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 write!(f, "{}", self.0)
270 }
271}
272
273#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
274pub struct LanguageModel {
275 pub provider: LanguageModelProvider,
276 pub id: LanguageModelId,
277 pub display_name: String,
278 pub max_token_count: usize,
279 pub max_token_count_in_max_mode: Option<usize>,
280 pub max_output_tokens: usize,
281 pub supports_tools: bool,
282 pub supports_images: bool,
283 pub supports_thinking: bool,
284 pub supports_max_mode: bool,
285}
286
287#[derive(Debug, Serialize, Deserialize)]
288pub struct ListModelsResponse {
289 pub models: Vec<LanguageModel>,
290 pub default_model: LanguageModelId,
291 pub default_fast_model: LanguageModelId,
292 pub recommended_models: Vec<LanguageModelId>,
293}
294
295#[derive(Debug, Serialize, Deserialize)]
296pub struct GetSubscriptionResponse {
297 pub plan: Plan,
298 pub usage: Option<CurrentUsage>,
299}
300
301#[derive(Debug, PartialEq, Serialize, Deserialize)]
302pub struct CurrentUsage {
303 pub model_requests: UsageData,
304 pub edit_predictions: UsageData,
305}
306
307#[derive(Debug, PartialEq, Serialize, Deserialize)]
308pub struct UsageData {
309 pub used: u32,
310 pub limit: UsageLimit,
311}
312
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Canonical snake_case plan names deserialize to the matching variant.
    #[test]
    fn test_plan_deserialize_snake_case() {
        let cases = [
            ("zed_free", Plan::ZedFree),
            ("zed_pro", Plan::ZedPro),
            ("zed_pro_trial", Plan::ZedProTrial),
        ];
        for (raw, expected) in cases {
            let parsed: Plan = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// The serde aliases on `Plan` deserialize to the matching variant.
    #[test]
    fn test_plan_deserialize_aliases() {
        let cases = [
            ("Free", Plan::ZedFree),
            ("ZedPro", Plan::ZedPro),
            ("ZedProTrial", Plan::ZedProTrial),
        ];
        for (raw, expected) in cases {
            let parsed: Plan = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// `UsageLimit::from_str` accepts `unlimited` and integers, and rejects anything else.
    #[test]
    fn test_usage_limit_from_str() {
        let unlimited = UsageLimit::from_str("unlimited").unwrap();
        assert!(matches!(unlimited, UsageLimit::Unlimited));

        for n in [0, 50] {
            let parsed = UsageLimit::from_str(&n.to_string()).unwrap();
            assert!(matches!(parsed, UsageLimit::Limited(m) if m == n));
        }

        for bad in ["not_a_number", "50xyz"] {
            assert!(UsageLimit::from_str(bad).is_err());
        }
    }
}