// cloud_llm_client.rs

  1#[cfg(feature = "predict-edits")]
  2pub mod predict_edits_v3;
  3
  4use std::str::FromStr;
  5use std::sync::Arc;
  6
  7use anyhow::Context as _;
  8use serde::{Deserialize, Serialize};
  9use strum::{Display, EnumIter, EnumString};
 10use uuid::Uuid;
 11
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
/// An outdated token means the token's structure is incompatible with the current server expectations.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// Header value identifying the edit-predictions resource.
// NOTE(review): the header this value is sent under is not visible in this file — confirm at call sites.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
    "x-zed-client-supports-stream-ended-request-completion-status";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 60
/// A usage allowance for a metered resource: either a finite count or unlimited.
///
/// Also parseable from header text via its [`FromStr`] impl ("unlimited" or an integer).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// A finite limit of the given number of uses.
    Limited(i32),
    /// No limit.
    Unlimited,
}
 67
 68impl FromStr for UsageLimit {
 69    type Err = anyhow::Error;
 70
 71    fn from_str(value: &str) -> Result<Self, Self::Err> {
 72        match value {
 73            "unlimited" => Ok(Self::Unlimited),
 74            limit => limit
 75                .parse::<i32>()
 76                .map(Self::Limited)
 77                .context("failed to parse limit"),
 78        }
 79    }
 80}
 81
/// The upstream LLM providers the service can route completion requests to.
///
/// Serialized (serde and strum) in snake_case, e.g. `open_ai`, `x_ai`.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
 93
/// Request body for an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    /// Optional outline of the file being edited.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    /// Serialized recent input events leading up to this request.
    pub input_events: String,
    /// The excerpt of the buffer surrounding the cursor.
    pub input_excerpt: String,
    /// An output the client speculates the model may produce.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    // The `alias` keeps compatibility with the older `data_collection_permission` field name.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Diagnostic groups as (language server name, serialized diagnostics) pairs.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// The trigger for this request.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
114
/// What caused an edit-prediction request to be issued.
///
// NOTE(review): unlike most enums in this file there is no `#[serde(rename_all = "snake_case")]`,
// so serde uses the default PascalCase variant names while strum's `AsRefStr` yields snake_case —
// confirm this asymmetry is intentional before changing wire formats.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, strum::AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum PredictEditsRequestTrigger {
    /// Issued by automated tests.
    Testing,
    /// Issued in response to diagnostics changing.
    Diagnostics,
    /// Issued from the command-line interface.
    Cli,
    /// Any other trigger (the default when unspecified).
    #[default]
    Other,
}
124
/// Git repository state attached to an edit-prediction request when the user
/// has consented to data collection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
137
/// Response body for an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Server-assigned id for this prediction; echoed back in accept/reject reports.
    pub request_id: String,
    /// The predicted replacement for the input excerpt.
    pub output_excerpt: String,
}
143
/// Request body reporting that the user accepted an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    /// The id of the accepted prediction, from [`PredictEditsResponse::request_id`].
    pub request_id: String,
    /// The model version that produced the prediction, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// End-to-end latency of the prediction in milliseconds, if measured.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
152
/// Request body reporting rejected edit predictions (deserialize-only; the
/// serialize-only counterpart is [`RejectEditPredictionsBodyRef`]).
#[derive(Debug, Clone, Deserialize)]
pub struct RejectEditPredictionsBody {
    /// The batch of rejections; see [`MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`].
    pub rejections: Vec<EditPredictionRejection>,
}
157
/// Borrowed, serialize-only form of [`RejectEditPredictionsBody`], letting the
/// client serialize a rejection batch without cloning it.
#[derive(Debug, Clone, Serialize)]
pub struct RejectEditPredictionsBodyRef<'a> {
    /// The batch of rejections to report.
    pub rejections: &'a [EditPredictionRejection],
}
162
/// A single rejected edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EditPredictionRejection {
    /// The id of the rejected prediction, from [`PredictEditsResponse::request_id`].
    pub request_id: String,
    /// Why the prediction was rejected; defaults to [`EditPredictionRejectReason::Discarded`].
    #[serde(default)]
    pub reason: EditPredictionRejectReason,
    /// Whether the prediction was ever displayed to the user.
    pub was_shown: bool,
    /// The model version that produced the prediction, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// End-to-end latency of the prediction in milliseconds, if measured.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
174
/// Why an edit prediction was rejected.
///
// NOTE(review): no `#[serde(rename_all = "snake_case")]` here (matching
// `PredictEditsRequestTrigger`), so serde uses PascalCase variant names while
// strum's `AsRefStr` yields snake_case — confirm intentional.
#[derive(
    Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, strum::AsRefStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum EditPredictionRejectReason {
    /// New requests were triggered before this one completed
    Canceled,
    /// No edits returned
    Empty,
    /// Edits returned, but none remained after interpolation
    InterpolatedEmpty,
    /// The new prediction was preferred over the current one
    Replaced,
    /// The current prediction was preferred over the new one
    CurrentPreferred,
    /// The current prediction was discarded
    #[default]
    Discarded,
    /// The current prediction was explicitly rejected by the user
    Rejected,
}
196
/// Request body for a completion request, wrapping a provider-specific payload.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    /// Id of the thread this completion belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    /// Id of the prompt within the thread, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    /// Which upstream provider should serve this request.
    pub provider: LanguageModelProvider,
    /// The provider-specific model name.
    pub model: String,
    /// The raw request payload, in the provider's own schema.
    pub provider_request: serde_json::Value,
}
207
/// Status messages the server interleaves into a completion stream.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is waiting at the given queue position.
    Queued {
        position: usize,
    },
    /// The request has started streaming.
    Started,
    /// The request failed.
    Failed {
        /// Machine-readable error code.
        code: String,
        /// Human-readable error message.
        message: String,
        /// Id identifying the failed request.
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
    StreamEnded,
    /// Any status this client version doesn't recognize (forward compatibility).
    #[serde(other)]
    Unknown,
}
227
/// A single item in a completion stream: either a server status update or a
/// provider event payload of type `T`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
    /// A status message from the server.
    Status(CompletionRequestStatus),
    /// An event from the underlying model provider.
    Event(T),
}
234
235impl<T> CompletionEvent<T> {
236    pub fn into_status(self) -> Option<CompletionRequestStatus> {
237        match self {
238            Self::Status(status) => Some(status),
239            Self::Event(_) => None,
240        }
241    }
242
243    pub fn into_event(self) -> Option<T> {
244        match self {
245            Self::Event(event) => Some(event),
246            Self::Status(_) => None,
247        }
248    }
249}
250
/// Request body for a web search.
#[derive(Serialize, Deserialize)]
pub struct WebSearchBody {
    /// The search query string.
    pub query: String,
}
255
/// Response body for a web search.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    /// The matched results, if any.
    pub results: Vec<WebSearchResult>,
}
260
/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    /// Title of the result page.
    pub title: String,
    /// URL of the result page.
    pub url: String,
    /// Extracted text content of the result.
    pub text: String,
}
267
/// Request body for counting the tokens of a provider-specific request payload.
#[derive(Serialize, Deserialize)]
pub struct CountTokensBody {
    /// Which upstream provider's tokenizer to use.
    pub provider: LanguageModelProvider,
    /// The provider-specific model name.
    pub model: String,
    /// The raw request payload, in the provider's own schema.
    pub provider_request: serde_json::Value,
}
274
/// Response body for a token-count request.
#[derive(Serialize, Deserialize)]
pub struct CountTokensResponse {
    /// Number of tokens in the submitted payload.
    pub tokens: usize,
}
279
/// Newtype identifier for a language model; `Arc<str>` keeps clones cheap.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
282
283impl std::fmt::Display for LanguageModelId {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        write!(f, "{}", self.0)
286    }
287}
288
/// Metadata describing a language model available through the service.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    /// The upstream provider that serves this model.
    pub provider: LanguageModelProvider,
    /// Provider-specific model identifier.
    pub id: LanguageModelId,
    /// Human-readable name shown in the UI.
    pub display_name: String,
    /// Whether this is the latest version of the model family.
    #[serde(default)]
    pub is_latest: bool,
    /// Maximum context window, in tokens.
    pub max_token_count: usize,
    /// Context window when max mode is enabled, if the model supports one.
    pub max_token_count_in_max_mode: Option<usize>,
    /// Maximum number of output tokens per completion.
    pub max_output_tokens: usize,
    /// Whether the model supports tool use.
    pub supports_tools: bool,
    /// Whether the model accepts image inputs.
    pub supports_images: bool,
    /// Whether the model supports extended thinking.
    pub supports_thinking: bool,
    /// Whether the model offers a fast mode.
    #[serde(default)]
    pub supports_fast_mode: bool,
    /// Reasoning-effort levels the model accepts.
    pub supported_effort_levels: Vec<SupportedEffortLevel>,
    /// Whether tool calls can be streamed incrementally.
    #[serde(default)]
    pub supports_streaming_tools: bool,
    /// Only used by OpenAI and xAI.
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
311
/// A reasoning-effort level supported by a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SupportedEffortLevel {
    /// Human-readable name of the level.
    pub name: Arc<str>,
    /// The value sent to the provider for this level.
    pub value: Arc<str>,
    /// Whether this level is the model's default, if specified.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub is_default: Option<bool>,
}
319
/// Response body listing the models available to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    /// All available models.
    pub models: Vec<LanguageModel>,
    /// The default model, if one is designated.
    pub default_model: Option<LanguageModelId>,
    /// The default fast model, if one is designated.
    pub default_fast_model: Option<LanguageModelId>,
    /// Models the service recommends surfacing prominently.
    pub recommended_models: Vec<LanguageModelId>,
}
327
/// The user's current metered usage.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    /// Usage and limit for edit predictions.
    pub edit_predictions: UsageData,
}
332
/// Usage of a single metered resource against its limit.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    /// Amount consumed so far.
    pub used: u32,
    /// The applicable limit.
    pub limit: UsageLimit,
}
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_limit_from_str() {
        // The literal "unlimited" maps to the unlimited variant.
        let parsed = UsageLimit::from_str("unlimited").unwrap();
        assert!(matches!(parsed, UsageLimit::Unlimited));

        // Plain integers parse into `Limited` with the same value.
        for expected in [0, 50] {
            let parsed = UsageLimit::from_str(&expected.to_string()).unwrap();
            assert!(matches!(parsed, UsageLimit::Limited(n) if n == expected));
        }

        // Anything else is rejected.
        assert!(UsageLimit::from_str("not_a_number").is_err());
        assert!(UsageLimit::from_str("50xyz").is_err());
    }
}
359}