cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
 11/// The name of the header used to indicate which version of Zed the client is running.
 12pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 13
 14/// The name of the header used to indicate when a request failed due to an
 15/// expired LLM token.
 16///
 17/// The client may use this as a signal to refresh the token.
 18pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 19
 20/// The name of the header used to indicate when a request failed due to an outdated LLM token.
 21///
 22/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
 23///
 24/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
 25/// An outdated token means the token's structure is incompatible with the current server expectations.
 26pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";
 27
 28/// The name of the header used to indicate the usage limit for edit predictions.
 29pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 30
 31/// The name of the header used to indicate the usage amount for edit predictions.
 32pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 33
 34pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 35
 36/// The name of the header used to indicate the minimum required Zed version.
 37///
 38/// This can be used to force a Zed upgrade in order to continue communicating
 39/// with the LLM service.
 40pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 41
 42/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 43pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 44    "x-zed-client-supports-status-messages";
 45
 46/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 47pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 48    "x-zed-server-supports-status-messages";
 49
 50/// The name of the header used by the client to indicate that it supports receiving xAI models.
 51pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
 52
 53/// The maximum number of edit predictions that can be rejected per request.
 54pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 55
 56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 57#[serde(rename_all = "snake_case")]
 58pub enum UsageLimit {
 59    Limited(i32),
 60    Unlimited,
 61}
 62
 63impl FromStr for UsageLimit {
 64    type Err = anyhow::Error;
 65
 66    fn from_str(value: &str) -> Result<Self, Self::Err> {
 67        match value {
 68            "unlimited" => Ok(Self::Unlimited),
 69            limit => limit
 70                .parse::<i32>()
 71                .map(Self::Limited)
 72                .context("failed to parse limit"),
 73        }
 74    }
 75}
 76
 77#[derive(
 78    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
 79)]
 80#[serde(rename_all = "snake_case")]
 81#[strum(serialize_all = "snake_case")]
 82pub enum LanguageModelProvider {
 83    Anthropic,
 84    OpenAi,
 85    Google,
 86    XAi,
 87}
 88
 89#[derive(Debug, Clone, Serialize, Deserialize)]
 90pub struct PredictEditsBody {
 91    #[serde(skip_serializing_if = "Option::is_none", default)]
 92    pub outline: Option<String>,
 93    pub input_events: String,
 94    pub input_excerpt: String,
 95    #[serde(skip_serializing_if = "Option::is_none", default)]
 96    pub speculated_output: Option<String>,
 97    /// Whether the user provided consent for sampling this interaction.
 98    #[serde(default, alias = "data_collection_permission")]
 99    pub can_collect_data: bool,
100    #[serde(skip_serializing_if = "Option::is_none", default)]
101    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
102    /// Info about the git repository state, only present when can_collect_data is true.
103    #[serde(skip_serializing_if = "Option::is_none", default)]
104    pub git_info: Option<PredictEditsGitInfo>,
105    /// The trigger for this request.
106    #[serde(default)]
107    pub trigger: PredictEditsRequestTrigger,
108}
109
110#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
111pub enum PredictEditsRequestTrigger {
112    Testing,
113    Diagnostics,
114    Cli,
115    #[default]
116    Other,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct PredictEditsGitInfo {
121    /// SHA of git HEAD commit at time of prediction.
122    #[serde(skip_serializing_if = "Option::is_none", default)]
123    pub head_sha: Option<String>,
124    /// URL of the remote called `origin`.
125    #[serde(skip_serializing_if = "Option::is_none", default)]
126    pub remote_origin_url: Option<String>,
127    /// URL of the remote called `upstream`.
128    #[serde(skip_serializing_if = "Option::is_none", default)]
129    pub remote_upstream_url: Option<String>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct PredictEditsResponse {
134    pub request_id: String,
135    pub output_excerpt: String,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct AcceptEditPredictionBody {
140    pub request_id: String,
141}
142
143#[derive(Debug, Clone, Deserialize)]
144pub struct RejectEditPredictionsBody {
145    pub rejections: Vec<EditPredictionRejection>,
146}
147
148#[derive(Debug, Clone, Serialize)]
149pub struct RejectEditPredictionsBodyRef<'a> {
150    pub rejections: &'a [EditPredictionRejection],
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154pub struct EditPredictionRejection {
155    pub request_id: String,
156    #[serde(default)]
157    pub reason: EditPredictionRejectReason,
158    pub was_shown: bool,
159}
160
161#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
162pub enum EditPredictionRejectReason {
163    /// New requests were triggered before this one completed
164    Canceled,
165    /// No edits returned
166    Empty,
167    /// Edits returned, but none remained after interpolation
168    InterpolatedEmpty,
169    /// The new prediction was preferred over the current one
170    Replaced,
171    /// The current prediction was preferred over the new one
172    CurrentPreferred,
173    /// The current prediction was discarded
174    #[default]
175    Discarded,
176    /// The current prediction was explicitly rejected by the user
177    Rejected,
178}
179
180#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
181#[serde(rename_all = "snake_case")]
182pub enum CompletionIntent {
183    UserPrompt,
184    ToolResults,
185    ThreadSummarization,
186    ThreadContextSummarization,
187    CreateFile,
188    EditFile,
189    InlineAssist,
190    TerminalInlineAssist,
191    GenerateGitCommitMessage,
192}
193
194#[derive(Debug, Serialize, Deserialize)]
195pub struct CompletionBody {
196    #[serde(skip_serializing_if = "Option::is_none", default)]
197    pub thread_id: Option<String>,
198    #[serde(skip_serializing_if = "Option::is_none", default)]
199    pub prompt_id: Option<String>,
200    #[serde(skip_serializing_if = "Option::is_none", default)]
201    pub intent: Option<CompletionIntent>,
202    pub provider: LanguageModelProvider,
203    pub model: String,
204    pub provider_request: serde_json::Value,
205}
206
207#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
208#[serde(rename_all = "snake_case")]
209pub enum CompletionRequestStatus {
210    Queued {
211        position: usize,
212    },
213    Started,
214    Failed {
215        code: String,
216        message: String,
217        request_id: Uuid,
218        /// Retry duration in seconds.
219        retry_after: Option<f64>,
220    },
221    UsageUpdated {
222        amount: usize,
223        limit: UsageLimit,
224    },
225    ToolUseLimitReached,
226}
227
228#[derive(Serialize, Deserialize)]
229#[serde(rename_all = "snake_case")]
230pub enum CompletionEvent<T> {
231    Status(CompletionRequestStatus),
232    Event(T),
233}
234
235impl<T> CompletionEvent<T> {
236    pub fn into_status(self) -> Option<CompletionRequestStatus> {
237        match self {
238            Self::Status(status) => Some(status),
239            Self::Event(_) => None,
240        }
241    }
242
243    pub fn into_event(self) -> Option<T> {
244        match self {
245            Self::Event(event) => Some(event),
246            Self::Status(_) => None,
247        }
248    }
249}
250
251#[derive(Serialize, Deserialize)]
252pub struct WebSearchBody {
253    pub query: String,
254}
255
256#[derive(Debug, Serialize, Deserialize, Clone)]
257pub struct WebSearchResponse {
258    pub results: Vec<WebSearchResult>,
259}
260
261#[derive(Debug, Serialize, Deserialize, Clone)]
262pub struct WebSearchResult {
263    pub title: String,
264    pub url: String,
265    pub text: String,
266}
267
268#[derive(Serialize, Deserialize)]
269pub struct CountTokensBody {
270    pub provider: LanguageModelProvider,
271    pub model: String,
272    pub provider_request: serde_json::Value,
273}
274
275#[derive(Serialize, Deserialize)]
276pub struct CountTokensResponse {
277    pub tokens: usize,
278}
279
280#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
281pub struct LanguageModelId(pub Arc<str>);
282
283impl std::fmt::Display for LanguageModelId {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        write!(f, "{}", self.0)
286    }
287}
288
289#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
290pub struct LanguageModel {
291    pub provider: LanguageModelProvider,
292    pub id: LanguageModelId,
293    pub display_name: String,
294    #[serde(default)]
295    pub is_latest: bool,
296    pub max_token_count: usize,
297    pub max_token_count_in_max_mode: Option<usize>,
298    pub max_output_tokens: usize,
299    pub supports_tools: bool,
300    pub supports_images: bool,
301    pub supports_thinking: bool,
302    pub supported_effort_levels: Vec<SupportedEffortLevel>,
303    #[serde(default)]
304    pub supports_streaming_tools: bool,
305    /// Only used by OpenAI and xAI.
306    #[serde(default)]
307    pub supports_parallel_tool_calls: bool,
308}
309
310#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
311pub struct SupportedEffortLevel {
312    pub name: Arc<str>,
313    pub value: Arc<str>,
314    #[serde(default, skip_serializing_if = "Option::is_none")]
315    pub is_default: Option<bool>,
316}
317
318#[derive(Debug, Serialize, Deserialize)]
319pub struct ListModelsResponse {
320    pub models: Vec<LanguageModel>,
321    pub default_model: Option<LanguageModelId>,
322    pub default_fast_model: Option<LanguageModelId>,
323    pub recommended_models: Vec<LanguageModelId>,
324}
325
326#[derive(Debug, PartialEq, Serialize, Deserialize)]
327pub struct CurrentUsage {
328    pub edit_predictions: UsageData,
329}
330
331#[derive(Debug, PartialEq, Serialize, Deserialize)]
332pub struct UsageData {
333    pub used: u32,
334    pub limit: UsageLimit,
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` (rather than `assert!(matches!(..))`) reports the actual value on failure.
        assert_eq!(UsageLimit::from_str("unlimited").unwrap(), UsageLimit::Unlimited);
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));
        // Negative values are passed through as-is; `from_str` does not reject them.
        assert_eq!(UsageLimit::from_str("-1").unwrap(), UsageLimit::Limited(-1));

        // Parsing is strict: no trimming, the keyword is case-sensitive, and the number must
        // fit in an `i32`.
        for value in ["not_a_number", "50xyz", "", " 50", "Unlimited", "2147483648"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to fail to parse"
            );
        }
    }
}