1pub mod predict_edits_v3;
2
3use std::str::FromStr;
4use std::sync::Arc;
5
6use anyhow::Context as _;
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumIter, EnumString};
9use uuid::Uuid;
10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
/// An outdated token means the token's structure is incompatible with the current server expectations.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// Presumably the header value that names the edit predictions resource.
///
/// NOTE(review): the header that carries this value is not visible in this file — confirm
/// against the callers.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
    "x-zed-client-supports-stream-ended-request-completion-status";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
59
/// A usage cap that is either a concrete number or unbounded.
///
/// Serde uses snake_case variant names for this enum. The [`FromStr`] impl is a
/// separate textual format: the literal `unlimited` or a decimal `i32`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// A finite limit with the given value.
    Limited(i32),
    /// No limit applies.
    Unlimited,
}
66
67impl FromStr for UsageLimit {
68 type Err = anyhow::Error;
69
70 fn from_str(value: &str) -> Result<Self, Self::Err> {
71 match value {
72 "unlimited" => Ok(Self::Unlimited),
73 limit => limit
74 .parse::<i32>()
75 .map(Self::Limited)
76 .context("failed to parse limit"),
77 }
78 }
79}
80
/// The upstream provider of a language model.
///
/// Both serde and strum (`EnumString`, `Display`) are configured to use
/// snake_case names for the variants; `EnumIter` allows iterating over all of them.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
92
/// Request payload for an edit prediction.
///
/// Optional fields are omitted from the serialized output when `None` and
/// default to `None` when missing from the input.
// NOTE(review): the endpoint that consumes this body is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    /// Optional outline text. NOTE(review): presumably an outline of the file being edited — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    /// Required string of input events. NOTE(review): exact format is not shown here.
    pub input_events: String,
    /// Required input excerpt. NOTE(review): exact format is not shown here.
    pub input_excerpt: String,
    /// Optional speculated output. NOTE(review): semantics are not shown here.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Defaults to `false` when absent; the key `data_collection_permission`
    /// is also accepted when deserializing.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Optional list of `(String, JSON value)` pairs.
    /// NOTE(review): presumably diagnostic group name and payload — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// The trigger for this request.
    ///
    /// Falls back to the enum's `Default` when absent from the input.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
113
/// What caused an edit prediction request to be issued.
///
/// The `snake_case` setting below applies to strum (`AsRefStr`) only. There is
/// no serde rename attribute on this enum, so serde uses the variant names
/// exactly as written.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, strum::AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum PredictEditsRequestTrigger {
    Testing,
    Diagnostics,
    Cli,
    /// The default trigger, used when none is specified.
    #[default]
    Other,
}
123
/// Git repository details attached to an edit prediction request.
///
/// Every field is optional: `None` values are skipped when serializing and
/// missing keys deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
136
/// Response payload for an edit prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request.
    /// NOTE(review): presumably the id later echoed in accept/reject bodies — confirm.
    pub request_id: String,
    /// The output excerpt. NOTE(review): exact format is not shown in this file.
    pub output_excerpt: String,
}
142
/// Request payload reporting that an edit prediction was accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    /// Identifier of the prediction request being accepted.
    pub request_id: String,
    /// Optional model version; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// Optional latency value; omitted from the output when `None`.
    /// NOTE(review): presumably end-to-end latency in milliseconds, per the name — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
151
/// Owned, deserialize-only payload carrying a batch of edit prediction rejections.
///
/// [`RejectEditPredictionsBodyRef`] is the borrowed, serialize-only type with
/// the same single `rejections` field.
#[derive(Debug, Clone, Deserialize)]
pub struct RejectEditPredictionsBody {
    /// The rejections contained in this request.
    pub rejections: Vec<EditPredictionRejection>,
}
156
/// Borrowed, serialize-only payload carrying a batch of edit prediction rejections.
///
/// It has the same single `rejections` field as [`RejectEditPredictionsBody`]
/// but holds a slice, so serializing does not require an owned `Vec`.
#[derive(Debug, Clone, Serialize)]
pub struct RejectEditPredictionsBodyRef<'a> {
    /// The rejections to serialize.
    pub rejections: &'a [EditPredictionRejection],
}
161
/// A single rejected edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EditPredictionRejection {
    /// Identifier of the prediction request that was rejected.
    pub request_id: String,
    /// Why the prediction was rejected; falls back to the enum's `Default` when absent.
    #[serde(default)]
    pub reason: EditPredictionRejectReason,
    /// NOTE(review): presumably whether the prediction was displayed to the user — confirm.
    pub was_shown: bool,
    /// Optional model version; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// Optional latency value; omitted from the output when `None`.
    /// NOTE(review): presumably end-to-end latency in milliseconds, per the name — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
173
/// The reason an edit prediction was rejected.
///
/// The `snake_case` setting below applies to strum (`AsRefStr`) only. There is
/// no serde rename attribute on this enum, so serde uses the variant names
/// exactly as written.
#[derive(
    Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, strum::AsRefStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum EditPredictionRejectReason {
    /// New requests were triggered before this one completed
    Canceled,
    /// No edits returned
    Empty,
    /// Edits returned, but none remained after interpolation
    InterpolatedEmpty,
    /// The new prediction was preferred over the current one
    Replaced,
    /// The current prediction was preferred over the new one
    CurrentPreferred,
    /// The current prediction was discarded
    ///
    /// This is the `Default` variant.
    #[default]
    Discarded,
    /// The current prediction was explicitly rejected by the user
    Rejected,
}
195
/// The purpose a completion request is made for.
///
/// Serialized by serde with snake_case variant names. The derived ordering
/// (`PartialOrd`/`Ord`) follows the declaration order of the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    Subagent,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
210
/// Request payload for a completion.
///
/// The optional fields are skipped when serializing if `None` and default to
/// `None` when missing from the input.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    /// Optional thread identifier.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    /// Optional prompt identifier.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    /// Optional intent describing what the completion is for.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    /// The provider the request targets.
    pub provider: LanguageModelProvider,
    /// The model name. NOTE(review): presumably a provider-specific model id — confirm.
    pub model: String,
    /// Arbitrary JSON carried as-is.
    /// NOTE(review): presumably the provider-native request body — confirm.
    pub provider_request: serde_json::Value,
}
223
/// Status of a completion request.
///
/// Serialized by serde with snake_case variant names.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is queued. NOTE(review): `position` is presumably the place in the queue — confirm.
    Queued {
        position: usize,
    },
    /// The request has started.
    Started,
    /// The request failed.
    Failed {
        /// Error code. NOTE(review): the set of possible codes is not shown in this file.
        code: String,
        /// Error message accompanying the code.
        message: String,
        /// Identifier of the failed request.
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
    StreamEnded,
    /// Catch-all that serde selects when deserializing a status tag that matches
    /// none of the other variants.
    #[serde(other)]
    Unknown,
}
243
244#[derive(Serialize, Deserialize)]
245#[serde(rename_all = "snake_case")]
246pub enum CompletionEvent<T> {
247 Status(CompletionRequestStatus),
248 Event(T),
249}
250
251impl<T> CompletionEvent<T> {
252 pub fn into_status(self) -> Option<CompletionRequestStatus> {
253 match self {
254 Self::Status(status) => Some(status),
255 Self::Event(_) => None,
256 }
257 }
258
259 pub fn into_event(self) -> Option<T> {
260 match self {
261 Self::Event(event) => Some(event),
262 Self::Status(_) => None,
263 }
264 }
265}
266
267#[derive(Serialize, Deserialize)]
268pub struct WebSearchBody {
269 pub query: String,
270}
271
/// Response payload for a web search.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    /// The search results. NOTE(review): any ordering guarantee is not shown in this file.
    pub results: Vec<WebSearchResult>,
}
276
/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    /// Title of the result.
    pub title: String,
    /// URL of the result.
    pub url: String,
    /// Text of the result. NOTE(review): presumably a snippet or page content — confirm.
    pub text: String,
}
283
284#[derive(Serialize, Deserialize)]
285pub struct CountTokensBody {
286 pub provider: LanguageModelProvider,
287 pub model: String,
288 pub provider_request: serde_json::Value,
289}
290
291#[derive(Serialize, Deserialize)]
292pub struct CountTokensResponse {
293 pub tokens: usize,
294}
295
/// Identifier of a language model.
///
/// A newtype over `Arc<str>`, so cloning bumps a reference count rather than
/// copying the string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
298
299impl std::fmt::Display for LanguageModelId {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 write!(f, "{}", self.0)
302 }
303}
304
/// Description of a language model and its capabilities.
///
/// Fields marked `#[serde(default)]` may be absent from the input and then
/// deserialize to `false`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    /// The provider of this model.
    pub provider: LanguageModelProvider,
    /// The model's identifier.
    pub id: LanguageModelId,
    /// Name of the model. NOTE(review): presumably the user-facing label — confirm.
    pub display_name: String,
    /// NOTE(review): presumably marks the newest model in its family — confirm.
    #[serde(default)]
    pub is_latest: bool,
    /// Maximum token count. NOTE(review): presumably the context window size — confirm.
    pub max_token_count: usize,
    /// Optional maximum token count in "max mode", when one is specified.
    pub max_token_count_in_max_mode: Option<usize>,
    /// Maximum number of output tokens.
    pub max_output_tokens: usize,
    /// Whether the model supports tools.
    pub supports_tools: bool,
    /// Whether the model supports images.
    pub supports_images: bool,
    /// Whether the model supports thinking.
    pub supports_thinking: bool,
    /// Whether the model supports fast mode.
    #[serde(default)]
    pub supports_fast_mode: bool,
    /// The effort levels this model supports; required in the input (no serde default).
    pub supported_effort_levels: Vec<SupportedEffortLevel>,
    /// Whether the model supports streaming tools.
    #[serde(default)]
    pub supports_streaming_tools: bool,
    /// Only used by OpenAI and xAI.
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
327
/// An effort level that a model supports.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SupportedEffortLevel {
    /// Name of the level. NOTE(review): presumably the display label — confirm.
    pub name: Arc<str>,
    /// Value of the level. NOTE(review): presumably what is sent to the provider — confirm.
    pub value: Arc<str>,
    /// Optional default marker; omitted from the output when `None` and `None`
    /// when absent from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub is_default: Option<bool>,
}
335
/// Response payload listing language models.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    /// The listed models.
    pub models: Vec<LanguageModel>,
    /// Identifier of the default model, if any.
    pub default_model: Option<LanguageModelId>,
    /// Identifier of the default fast model, if any.
    pub default_fast_model: Option<LanguageModelId>,
    /// Identifiers of recommended models.
    /// NOTE(review): presumably these refer to entries in `models` — confirm.
    pub recommended_models: Vec<LanguageModelId>,
}
343
/// Current usage figures; only edit predictions are tracked here.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    /// Usage data for edit predictions.
    pub edit_predictions: UsageData,
}
348
/// An amount used, paired with the limit that applies to it.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    /// Amount used so far.
    pub used: u32,
    /// The applicable limit, which may be unlimited.
    pub limit: UsageLimit,
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// `UsageLimit` parses `unlimited` and decimal integers, and rejects
    /// anything else.
    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!("unlimited".parse::<UsageLimit>().unwrap(), UsageLimit::Unlimited);

        for (input, expected) in [("0", 0), ("50", 50)] {
            assert_eq!(input.parse::<UsageLimit>().unwrap(), UsageLimit::Limited(expected));
        }

        for invalid in ["not_a_number", "50xyz"] {
            assert!(invalid.parse::<UsageLimit>().is_err());
        }
    }
}