cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
 11/// The name of the header used to indicate which version of Zed the client is running.
 12pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 13
 14/// The name of the header used to indicate when a request failed due to an
 15/// expired LLM token.
 16///
 17/// The client may use this as a signal to refresh the token.
 18pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 19
 20/// The name of the header used to indicate when a request failed due to an outdated LLM token.
 21///
 22/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
 23///
 24/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
 25/// An outdated token means the token's structure is incompatible with the current server expectations.
 26pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";
 27
 28/// The name of the header used to indicate the usage limit for edit predictions.
 29pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 30
 31/// The name of the header used to indicate the usage amount for edit predictions.
 32pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 33
 34pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 35
 36/// The name of the header used to indicate the minimum required Zed version.
 37///
 38/// This can be used to force a Zed upgrade in order to continue communicating
 39/// with the LLM service.
 40pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 41
 42/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 43pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 44    "x-zed-client-supports-status-messages";
 45
 46/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
 47pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
 48    "x-zed-client-supports-stream-ended-request-completion-status";
 49
 50/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 51pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 52    "x-zed-server-supports-status-messages";
 53
 54/// The name of the header used by the client to indicate that it supports receiving xAI models.
 55pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
 56
 57/// The maximum number of edit predictions that can be rejected per request.
 58pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 59
 60#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 61#[serde(rename_all = "snake_case")]
 62pub enum UsageLimit {
 63    Limited(i32),
 64    Unlimited,
 65}
 66
 67impl FromStr for UsageLimit {
 68    type Err = anyhow::Error;
 69
 70    fn from_str(value: &str) -> Result<Self, Self::Err> {
 71        match value {
 72            "unlimited" => Ok(Self::Unlimited),
 73            limit => limit
 74                .parse::<i32>()
 75                .map(Self::Limited)
 76                .context("failed to parse limit"),
 77        }
 78    }
 79}
 80
 81#[derive(
 82    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
 83)]
 84#[serde(rename_all = "snake_case")]
 85#[strum(serialize_all = "snake_case")]
 86pub enum LanguageModelProvider {
 87    Anthropic,
 88    OpenAi,
 89    Google,
 90    XAi,
 91}
 92
 93#[derive(Debug, Clone, Serialize, Deserialize)]
 94pub struct PredictEditsBody {
 95    #[serde(skip_serializing_if = "Option::is_none", default)]
 96    pub outline: Option<String>,
 97    pub input_events: String,
 98    pub input_excerpt: String,
 99    #[serde(skip_serializing_if = "Option::is_none", default)]
100    pub speculated_output: Option<String>,
101    /// Whether the user provided consent for sampling this interaction.
102    #[serde(default, alias = "data_collection_permission")]
103    pub can_collect_data: bool,
104    #[serde(skip_serializing_if = "Option::is_none", default)]
105    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
106    /// Info about the git repository state, only present when can_collect_data is true.
107    #[serde(skip_serializing_if = "Option::is_none", default)]
108    pub git_info: Option<PredictEditsGitInfo>,
109    /// The trigger for this request.
110    #[serde(default)]
111    pub trigger: PredictEditsRequestTrigger,
112}
113
114#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, strum::AsRefStr)]
115#[strum(serialize_all = "snake_case")]
116pub enum PredictEditsRequestTrigger {
117    Testing,
118    Diagnostics,
119    Cli,
120    #[default]
121    Other,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct PredictEditsGitInfo {
126    /// SHA of git HEAD commit at time of prediction.
127    #[serde(skip_serializing_if = "Option::is_none", default)]
128    pub head_sha: Option<String>,
129    /// URL of the remote called `origin`.
130    #[serde(skip_serializing_if = "Option::is_none", default)]
131    pub remote_origin_url: Option<String>,
132    /// URL of the remote called `upstream`.
133    #[serde(skip_serializing_if = "Option::is_none", default)]
134    pub remote_upstream_url: Option<String>,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct PredictEditsResponse {
139    pub request_id: String,
140    pub output_excerpt: String,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct AcceptEditPredictionBody {
145    pub request_id: String,
146    #[serde(default, skip_serializing_if = "Option::is_none")]
147    pub model_version: Option<String>,
148    #[serde(default, skip_serializing_if = "Option::is_none")]
149    pub e2e_latency_ms: Option<u128>,
150}
151
152#[derive(Debug, Clone, Deserialize)]
153pub struct RejectEditPredictionsBody {
154    pub rejections: Vec<EditPredictionRejection>,
155}
156
157#[derive(Debug, Clone, Serialize)]
158pub struct RejectEditPredictionsBodyRef<'a> {
159    pub rejections: &'a [EditPredictionRejection],
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
163pub struct EditPredictionRejection {
164    pub request_id: String,
165    #[serde(default)]
166    pub reason: EditPredictionRejectReason,
167    pub was_shown: bool,
168    #[serde(default, skip_serializing_if = "Option::is_none")]
169    pub model_version: Option<String>,
170    #[serde(default, skip_serializing_if = "Option::is_none")]
171    pub e2e_latency_ms: Option<u128>,
172}
173
174#[derive(
175    Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, strum::AsRefStr,
176)]
177#[strum(serialize_all = "snake_case")]
178pub enum EditPredictionRejectReason {
179    /// New requests were triggered before this one completed
180    Canceled,
181    /// No edits returned
182    Empty,
183    /// Edits returned, but none remained after interpolation
184    InterpolatedEmpty,
185    /// The new prediction was preferred over the current one
186    Replaced,
187    /// The current prediction was preferred over the new one
188    CurrentPreferred,
189    /// The current prediction was discarded
190    #[default]
191    Discarded,
192    /// The current prediction was explicitly rejected by the user
193    Rejected,
194}
195
196#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
197#[serde(rename_all = "snake_case")]
198pub enum CompletionIntent {
199    UserPrompt,
200    ToolResults,
201    ThreadSummarization,
202    ThreadContextSummarization,
203    CreateFile,
204    EditFile,
205    InlineAssist,
206    TerminalInlineAssist,
207    GenerateGitCommitMessage,
208}
209
210#[derive(Debug, Serialize, Deserialize)]
211pub struct CompletionBody {
212    #[serde(skip_serializing_if = "Option::is_none", default)]
213    pub thread_id: Option<String>,
214    #[serde(skip_serializing_if = "Option::is_none", default)]
215    pub prompt_id: Option<String>,
216    #[serde(skip_serializing_if = "Option::is_none", default)]
217    pub intent: Option<CompletionIntent>,
218    pub provider: LanguageModelProvider,
219    pub model: String,
220    pub provider_request: serde_json::Value,
221}
222
223#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
224#[serde(rename_all = "snake_case")]
225pub enum CompletionRequestStatus {
226    Queued {
227        position: usize,
228    },
229    Started,
230    Failed {
231        code: String,
232        message: String,
233        request_id: Uuid,
234        /// Retry duration in seconds.
235        retry_after: Option<f64>,
236    },
237    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
238    StreamEnded,
239    #[serde(other)]
240    Unknown,
241}
242
243#[derive(Serialize, Deserialize)]
244#[serde(rename_all = "snake_case")]
245pub enum CompletionEvent<T> {
246    Status(CompletionRequestStatus),
247    Event(T),
248}
249
250impl<T> CompletionEvent<T> {
251    pub fn into_status(self) -> Option<CompletionRequestStatus> {
252        match self {
253            Self::Status(status) => Some(status),
254            Self::Event(_) => None,
255        }
256    }
257
258    pub fn into_event(self) -> Option<T> {
259        match self {
260            Self::Event(event) => Some(event),
261            Self::Status(_) => None,
262        }
263    }
264}
265
266#[derive(Serialize, Deserialize)]
267pub struct WebSearchBody {
268    pub query: String,
269}
270
271#[derive(Debug, Serialize, Deserialize, Clone)]
272pub struct WebSearchResponse {
273    pub results: Vec<WebSearchResult>,
274}
275
276#[derive(Debug, Serialize, Deserialize, Clone)]
277pub struct WebSearchResult {
278    pub title: String,
279    pub url: String,
280    pub text: String,
281}
282
283#[derive(Serialize, Deserialize)]
284pub struct CountTokensBody {
285    pub provider: LanguageModelProvider,
286    pub model: String,
287    pub provider_request: serde_json::Value,
288}
289
290#[derive(Serialize, Deserialize)]
291pub struct CountTokensResponse {
292    pub tokens: usize,
293}
294
295#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
296pub struct LanguageModelId(pub Arc<str>);
297
298impl std::fmt::Display for LanguageModelId {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        write!(f, "{}", self.0)
301    }
302}
303
304#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
305pub struct LanguageModel {
306    pub provider: LanguageModelProvider,
307    pub id: LanguageModelId,
308    pub display_name: String,
309    #[serde(default)]
310    pub is_latest: bool,
311    pub max_token_count: usize,
312    pub max_token_count_in_max_mode: Option<usize>,
313    pub max_output_tokens: usize,
314    pub supports_tools: bool,
315    pub supports_images: bool,
316    pub supports_thinking: bool,
317    #[serde(default)]
318    pub supports_fast_mode: bool,
319    pub supported_effort_levels: Vec<SupportedEffortLevel>,
320    #[serde(default)]
321    pub supports_streaming_tools: bool,
322    /// Only used by OpenAI and xAI.
323    #[serde(default)]
324    pub supports_parallel_tool_calls: bool,
325}
326
327#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
328pub struct SupportedEffortLevel {
329    pub name: Arc<str>,
330    pub value: Arc<str>,
331    #[serde(default, skip_serializing_if = "Option::is_none")]
332    pub is_default: Option<bool>,
333}
334
335#[derive(Debug, Serialize, Deserialize)]
336pub struct ListModelsResponse {
337    pub models: Vec<LanguageModel>,
338    pub default_model: Option<LanguageModelId>,
339    pub default_fast_model: Option<LanguageModelId>,
340    pub recommended_models: Vec<LanguageModelId>,
341}
342
343#[derive(Debug, PartialEq, Serialize, Deserialize)]
344pub struct CurrentUsage {
345    pub edit_predictions: UsageData,
346}
347
348#[derive(Debug, PartialEq, Serialize, Deserialize)]
349pub struct UsageData {
350    pub used: u32,
351    pub limit: UsageLimit,
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` (UsageLimit is Debug + PartialEq) reports the actual value on failure.
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));
        assert_eq!(
            UsageLimit::from_str("2147483647").unwrap(),
            UsageLimit::Limited(i32::MAX)
        );

        // Anything that is neither the exact `unlimited` keyword nor a valid `i32` is rejected:
        // non-numeric text, trailing garbage, empty input, leading whitespace, overflow, and
        // a differently-cased keyword.
        for value in ["not_a_number", "50xyz", "", " 50", "2147483648", "Unlimited"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }
}