1#[cfg(feature = "predict-edits")]
2pub mod predict_edits_v3;
3
4use std::str::FromStr;
5use std::sync::Arc;
6
7use anyhow::Context as _;
8use serde::{Deserialize, Serialize};
9use strum::{Display, EnumIter, EnumString};
10use uuid::Uuid;
11
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
/// An outdated token means the token's structure is incompatible with the current server expectations.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The header value identifying the edit predictions resource.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
    "x-zed-client-supports-stream-ended-request-completion-status";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
60
/// A usage limit for a metered resource: either a concrete count or unlimited.
///
/// Serialized in snake_case (`"limited"` / `"unlimited"`); see also the
/// [`FromStr`] impl, which parses the header-value form.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// Usage is capped at the contained number of uses.
    Limited(i32),
    /// No cap applies.
    Unlimited,
}
67
68impl FromStr for UsageLimit {
69 type Err = anyhow::Error;
70
71 fn from_str(value: &str) -> Result<Self, Self::Err> {
72 match value {
73 "unlimited" => Ok(Self::Unlimited),
74 limit => limit
75 .parse::<i32>()
76 .map(Self::Limited)
77 .context("failed to parse limit"),
78 }
79 }
80}
81
/// The upstream LLM providers the service can route requests to.
///
/// Both serde and strum serialize variants in snake_case (e.g. `open_ai`,
/// `x_ai`), so wire and display forms agree.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
93
/// Request body for an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    /// Optional outline context; omitted from the wire when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    /// Recent input events, pre-rendered as a string.
    pub input_events: String,
    /// The excerpt of the buffer the prediction should apply to.
    pub input_excerpt: String,
    /// A client-side guess at the output, when available; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    // The alias preserves compatibility with the older field name
    // `data_collection_permission`.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Diagnostic groups as (language server name, diagnostic payload)
    /// pairs — NOTE(review): tuple meaning inferred from usage elsewhere;
    /// confirm against the server.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// The trigger for this request.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
114
/// What caused an edit-prediction request to be issued.
///
/// NOTE(review): serde serializes these variants in PascalCase (no
/// `rename_all`) while strum's `AsRefStr` emits snake_case — confirm this
/// divergence between wire and display forms is intentional.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, strum::AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum PredictEditsRequestTrigger {
    Testing,
    Diagnostics,
    Cli,
    /// Fallback trigger; also the default when the field is absent.
    #[default]
    Other,
}
124
/// Git repository state attached to a prediction request.
///
/// Only sent when the user has consented to data collection (see
/// `PredictEditsBody::can_collect_data`); all fields are omitted from the
/// wire when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
137
/// Response body for an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Server-assigned id; echoed back in accept/reject reports.
    pub request_id: String,
    /// The predicted replacement for the input excerpt.
    pub output_excerpt: String,
}
143
/// Request body reporting that the user accepted an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    /// The id of the prediction request being accepted.
    pub request_id: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// End-to-end latency in milliseconds — presumably measured by the
    /// client from request to acceptance; confirm against the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
152
/// Request body (deserialize side) batching rejected edit predictions.
///
/// See [`MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`] for the per-request
/// cap; enforcement is not visible in this file.
#[derive(Debug, Clone, Deserialize)]
pub struct RejectEditPredictionsBody {
    pub rejections: Vec<EditPredictionRejection>,
}
157
/// Borrowed, serialize-only counterpart of [`RejectEditPredictionsBody`],
/// letting the sender serialize a slice of rejections without cloning.
#[derive(Debug, Clone, Serialize)]
pub struct RejectEditPredictionsBodyRef<'a> {
    pub rejections: &'a [EditPredictionRejection],
}
162
/// A single rejected edit prediction, reported back to the server.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EditPredictionRejection {
    /// The id of the prediction request being rejected.
    pub request_id: String,
    /// Why the prediction was rejected; defaults to
    /// [`EditPredictionRejectReason::Discarded`] when absent.
    #[serde(default)]
    pub reason: EditPredictionRejectReason,
    /// Whether the prediction was ever displayed to the user.
    pub was_shown: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// End-to-end latency in milliseconds — presumably measured by the
    /// client; confirm against the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
174
/// Why an edit prediction was not accepted.
///
/// NOTE(review): serde serializes these variants in PascalCase (no
/// `rename_all`) while strum's `AsRefStr` emits snake_case — confirm this
/// divergence between wire and display forms is intentional.
#[derive(
    Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, strum::AsRefStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum EditPredictionRejectReason {
    /// New requests were triggered before this one completed
    Canceled,
    /// No edits returned
    Empty,
    /// Edits returned, but none remained after interpolation
    InterpolatedEmpty,
    /// The new prediction was preferred over the current one
    Replaced,
    /// The current prediction was preferred over the new one
    CurrentPreferred,
    /// The current prediction was discarded
    #[default]
    Discarded,
    /// The current prediction was explicitly rejected by the user
    Rejected,
}
196
/// Request body for a completion request proxied to an LLM provider.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    /// Id of the thread this completion belongs to; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    /// Id of the prompt within the thread; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    /// Which upstream provider should serve the request.
    pub provider: LanguageModelProvider,
    /// The provider-specific model name.
    pub model: String,
    /// The raw request payload, in the provider's own JSON schema.
    pub provider_request: serde_json::Value,
}
207
/// Status updates the server interleaves into a completion stream.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The request has started.
    Started,
    /// The request failed.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
    StreamEnded,
    /// Fallback for status kinds this client version does not recognize.
    #[serde(other)]
    Unknown,
}
227
/// A single item in a completion stream: either a server status update or
/// a provider-specific event payload of type `T`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
234
235impl<T> CompletionEvent<T> {
236 pub fn into_status(self) -> Option<CompletionRequestStatus> {
237 match self {
238 Self::Status(status) => Some(status),
239 Self::Event(_) => None,
240 }
241 }
242
243 pub fn into_event(self) -> Option<T> {
244 match self {
245 Self::Event(event) => Some(event),
246 Self::Status(_) => None,
247 }
248 }
249}
250
/// Request body for a web search.
#[derive(Serialize, Deserialize)]
pub struct WebSearchBody {
    /// The search query string.
    pub query: String,
}
255
/// Response body for a web search.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    pub results: Vec<WebSearchResult>,
}
260
/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    /// Title of the result page.
    pub title: String,
    /// URL of the result page.
    pub url: String,
    /// Extracted text content of the result.
    pub text: String,
}
267
/// Request body for counting the tokens of a provider request.
#[derive(Serialize, Deserialize)]
pub struct CountTokensBody {
    /// Which upstream provider's tokenizer to use.
    pub provider: LanguageModelProvider,
    /// The provider-specific model name.
    pub model: String,
    /// The raw request payload, in the provider's own JSON schema.
    pub provider_request: serde_json::Value,
}
274
/// Response body for a token-count request.
#[derive(Serialize, Deserialize)]
pub struct CountTokensResponse {
    /// Number of tokens the request would consume.
    pub tokens: usize,
}
279
/// Newtype identifier for a language model; `Arc<str>` makes clones cheap.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
282
283impl std::fmt::Display for LanguageModelId {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 write!(f, "{}", self.0)
286 }
287}
288
/// Metadata for a model offered by the service: identity, token limits,
/// and capability flags.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    pub provider: LanguageModelProvider,
    pub id: LanguageModelId,
    pub display_name: String,
    /// Whether this is the latest revision of the model; defaults to
    /// `false` when absent.
    #[serde(default)]
    pub is_latest: bool,
    /// Maximum context window, in tokens.
    pub max_token_count: usize,
    /// Context window when "max mode" is active — TODO confirm the exact
    /// semantics of max mode; `None` when the model has no such mode.
    pub max_token_count_in_max_mode: Option<usize>,
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    #[serde(default)]
    pub supports_fast_mode: bool,
    /// Reasoning-effort levels the model accepts; empty when unsupported.
    pub supported_effort_levels: Vec<SupportedEffortLevel>,
    #[serde(default)]
    pub supports_streaming_tools: bool,
    /// Only used by OpenAI and xAI.
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
311
/// A reasoning-effort level a model supports.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SupportedEffortLevel {
    /// Human-readable label for the level.
    pub name: Arc<str>,
    /// Value sent to the provider for this level.
    pub value: Arc<str>,
    /// Whether this level is the model's default; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub is_default: Option<bool>,
}
319
/// Response body listing the models available to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    /// The model to select by default, if any.
    pub default_model: Option<LanguageModelId>,
    /// The default model for fast/low-latency use, if any.
    pub default_fast_model: Option<LanguageModelId>,
    pub recommended_models: Vec<LanguageModelId>,
}
327
/// The user's current metered usage, per resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub edit_predictions: UsageData,
}
332
/// Usage of a single metered resource: amount consumed versus its limit.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    /// Units consumed so far.
    pub used: u32,
    /// The applicable cap.
    pub limit: UsageLimit,
}
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_limit_from_str() {
        // The literal "unlimited" maps to the Unlimited variant.
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );

        // Numeric strings parse into Limited(n).
        for n in [0, 50] {
            assert_eq!(
                UsageLimit::from_str(&n.to_string()).unwrap(),
                UsageLimit::Limited(n)
            );
        }

        // Non-numeric and trailing-garbage inputs are rejected.
        assert!(UsageLimit::from_str("not_a_number").is_err());
        assert!(UsageLimit::from_str("50xyz").is_err());
    }
}
359}