1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
/// An outdated token means the token's structure is incompatible with the current server expectations.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The header value that names the edit predictions resource.
///
/// NOTE(review): none of the header names declared alongside this constant is
/// visibly paired with it; presumably it is sent as the value of a
/// resource-identifying header defined elsewhere. Confirm against callers.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
55
56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum UsageLimit {
59 Limited(i32),
60 Unlimited,
61}
62
63impl FromStr for UsageLimit {
64 type Err = anyhow::Error;
65
66 fn from_str(value: &str) -> Result<Self, Self::Err> {
67 match value {
68 "unlimited" => Ok(Self::Unlimited),
69 limit => limit
70 .parse::<i32>()
71 .map(Self::Limited)
72 .context("failed to parse limit"),
73 }
74 }
75}
76
77#[derive(
78 Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
79)]
80#[serde(rename_all = "snake_case")]
81#[strum(serialize_all = "snake_case")]
82pub enum LanguageModelProvider {
83 Anthropic,
84 OpenAi,
85 Google,
86 XAi,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct PredictEditsBody {
91 #[serde(skip_serializing_if = "Option::is_none", default)]
92 pub outline: Option<String>,
93 pub input_events: String,
94 pub input_excerpt: String,
95 #[serde(skip_serializing_if = "Option::is_none", default)]
96 pub speculated_output: Option<String>,
97 /// Whether the user provided consent for sampling this interaction.
98 #[serde(default, alias = "data_collection_permission")]
99 pub can_collect_data: bool,
100 #[serde(skip_serializing_if = "Option::is_none", default)]
101 pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
102 /// Info about the git repository state, only present when can_collect_data is true.
103 #[serde(skip_serializing_if = "Option::is_none", default)]
104 pub git_info: Option<PredictEditsGitInfo>,
105 /// The trigger for this request.
106 #[serde(default)]
107 pub trigger: PredictEditsRequestTrigger,
108}
109
110#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
111pub enum PredictEditsRequestTrigger {
112 Testing,
113 Diagnostics,
114 Cli,
115 #[default]
116 Other,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct PredictEditsGitInfo {
121 /// SHA of git HEAD commit at time of prediction.
122 #[serde(skip_serializing_if = "Option::is_none", default)]
123 pub head_sha: Option<String>,
124 /// URL of the remote called `origin`.
125 #[serde(skip_serializing_if = "Option::is_none", default)]
126 pub remote_origin_url: Option<String>,
127 /// URL of the remote called `upstream`.
128 #[serde(skip_serializing_if = "Option::is_none", default)]
129 pub remote_upstream_url: Option<String>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct PredictEditsResponse {
134 pub request_id: String,
135 pub output_excerpt: String,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct AcceptEditPredictionBody {
140 pub request_id: String,
141}
142
143#[derive(Debug, Clone, Deserialize)]
144pub struct RejectEditPredictionsBody {
145 pub rejections: Vec<EditPredictionRejection>,
146}
147
148#[derive(Debug, Clone, Serialize)]
149pub struct RejectEditPredictionsBodyRef<'a> {
150 pub rejections: &'a [EditPredictionRejection],
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154pub struct EditPredictionRejection {
155 pub request_id: String,
156 #[serde(default)]
157 pub reason: EditPredictionRejectReason,
158 pub was_shown: bool,
159}
160
161#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
162pub enum EditPredictionRejectReason {
163 /// New requests were triggered before this one completed
164 Canceled,
165 /// No edits returned
166 Empty,
167 /// Edits returned, but none remained after interpolation
168 InterpolatedEmpty,
169 /// The new prediction was preferred over the current one
170 Replaced,
171 /// The current prediction was preferred over the new one
172 CurrentPreferred,
173 /// The current prediction was discarded
174 #[default]
175 Discarded,
176}
177
178#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
179#[serde(rename_all = "snake_case")]
180pub enum CompletionIntent {
181 UserPrompt,
182 ToolResults,
183 ThreadSummarization,
184 ThreadContextSummarization,
185 CreateFile,
186 EditFile,
187 InlineAssist,
188 TerminalInlineAssist,
189 GenerateGitCommitMessage,
190}
191
192#[derive(Debug, Serialize, Deserialize)]
193pub struct CompletionBody {
194 #[serde(skip_serializing_if = "Option::is_none", default)]
195 pub thread_id: Option<String>,
196 #[serde(skip_serializing_if = "Option::is_none", default)]
197 pub prompt_id: Option<String>,
198 #[serde(skip_serializing_if = "Option::is_none", default)]
199 pub intent: Option<CompletionIntent>,
200 pub provider: LanguageModelProvider,
201 pub model: String,
202 pub provider_request: serde_json::Value,
203}
204
205#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
206#[serde(rename_all = "snake_case")]
207pub enum CompletionRequestStatus {
208 Queued {
209 position: usize,
210 },
211 Started,
212 Failed {
213 code: String,
214 message: String,
215 request_id: Uuid,
216 /// Retry duration in seconds.
217 retry_after: Option<f64>,
218 },
219 UsageUpdated {
220 amount: usize,
221 limit: UsageLimit,
222 },
223 ToolUseLimitReached,
224}
225
226#[derive(Serialize, Deserialize)]
227#[serde(rename_all = "snake_case")]
228pub enum CompletionEvent<T> {
229 Status(CompletionRequestStatus),
230 Event(T),
231}
232
233impl<T> CompletionEvent<T> {
234 pub fn into_status(self) -> Option<CompletionRequestStatus> {
235 match self {
236 Self::Status(status) => Some(status),
237 Self::Event(_) => None,
238 }
239 }
240
241 pub fn into_event(self) -> Option<T> {
242 match self {
243 Self::Event(event) => Some(event),
244 Self::Status(_) => None,
245 }
246 }
247}
248
249#[derive(Serialize, Deserialize)]
250pub struct WebSearchBody {
251 pub query: String,
252}
253
254#[derive(Debug, Serialize, Deserialize, Clone)]
255pub struct WebSearchResponse {
256 pub results: Vec<WebSearchResult>,
257}
258
259#[derive(Debug, Serialize, Deserialize, Clone)]
260pub struct WebSearchResult {
261 pub title: String,
262 pub url: String,
263 pub text: String,
264}
265
266#[derive(Serialize, Deserialize)]
267pub struct CountTokensBody {
268 pub provider: LanguageModelProvider,
269 pub model: String,
270 pub provider_request: serde_json::Value,
271}
272
273#[derive(Serialize, Deserialize)]
274pub struct CountTokensResponse {
275 pub tokens: usize,
276}
277
278#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
279pub struct LanguageModelId(pub Arc<str>);
280
281impl std::fmt::Display for LanguageModelId {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 write!(f, "{}", self.0)
284 }
285}
286
287#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
288pub struct LanguageModel {
289 pub provider: LanguageModelProvider,
290 pub id: LanguageModelId,
291 pub display_name: String,
292 pub max_token_count: usize,
293 pub max_token_count_in_max_mode: Option<usize>,
294 pub max_output_tokens: usize,
295 pub supports_tools: bool,
296 pub supports_images: bool,
297 pub supports_thinking: bool,
298 #[serde(default)]
299 pub supports_streaming_tools: bool,
300 /// Only used by OpenAI and xAI.
301 #[serde(default)]
302 pub supports_parallel_tool_calls: bool,
303}
304
305#[derive(Debug, Serialize, Deserialize)]
306pub struct ListModelsResponse {
307 pub models: Vec<LanguageModel>,
308 pub default_model: Option<LanguageModelId>,
309 pub default_fast_model: Option<LanguageModelId>,
310 pub recommended_models: Vec<LanguageModelId>,
311}
312
313#[derive(Debug, PartialEq, Serialize, Deserialize)]
314pub struct CurrentUsage {
315 pub edit_predictions: UsageData,
316}
317
318#[derive(Debug, PartialEq, Serialize, Deserialize)]
319pub struct UsageData {
320 pub used: u32,
321 pub limit: UsageLimit,
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `"unlimited"` and decimal integers parse to the matching
    /// variants, and that malformed input is rejected.
    #[test]
    fn test_usage_limit_from_str() {
        let unlimited: UsageLimit = "unlimited".parse().unwrap();
        assert_eq!(unlimited, UsageLimit::Unlimited);

        for (input, expected) in [("0", 0), ("50", 50)] {
            let parsed: UsageLimit = input.parse().unwrap();
            assert_eq!(parsed, UsageLimit::Limited(expected));
        }

        for malformed in ["not_a_number", "50xyz"] {
            assert!(malformed.parse::<UsageLimit>().is_err());
        }
    }
}