//! cloud_llm_client.rs

  1#[cfg(feature = "predict-edits")]
  2pub mod predict_edits_v3;
  3
  4use std::str::FromStr;
  5use std::sync::Arc;
  6
  7use anyhow::Context as _;
  8use serde::{Deserialize, Serialize};
  9use strum::{Display, EnumIter, EnumString};
 10use uuid::Uuid;
 11
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate which edit prediction experiment should be used.
pub const PREFERRED_EXPERIMENT_HEADER_NAME: &str = "x-zed-preferred-experiment";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate when a request failed due to an outdated LLM token.
///
/// A token is considered "outdated" when we can't parse the claims (e.g., after adding a new required claim).
///
/// This is distinct from [`EXPIRED_LLM_TOKEN_HEADER_NAME`] which indicates the token's time-based validity has passed.
/// An outdated token means the token's structure is incompatible with the current server expectations.
pub const OUTDATED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-outdated-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The header value used to refer to the edit predictions resource
/// (presumably paired with a resource header — confirm against server usage).
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
    "x-zed-client-supports-stream-ended-request-completion-status";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 63
/// A usage limit for a metered resource.
///
/// Serialized in snake_case (e.g. `{"limited": 50}` or `"unlimited"`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// Usage is capped at the given count.
    Limited(i32),
    /// Usage has no cap.
    Unlimited,
}
 70
 71impl FromStr for UsageLimit {
 72    type Err = anyhow::Error;
 73
 74    fn from_str(value: &str) -> Result<Self, Self::Err> {
 75        match value {
 76            "unlimited" => Ok(Self::Unlimited),
 77            limit => limit
 78                .parse::<i32>()
 79                .map(Self::Limited)
 80                .context("failed to parse limit"),
 81        }
 82    }
 83}
 84
/// An upstream LLM provider the service can route requests to.
///
/// Both serde and strum render variants in snake_case (e.g. `open_ai`, `x_ai`).
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
 96
/// Request body for an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    // Presumably an outline of the source file — TODO confirm with caller.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    // NOTE(review): presumably a serialized record of recent editor input events — confirm.
    pub input_events: String,
    // The buffer excerpt used as prediction input.
    pub input_excerpt: String,
    // A speculated output for the prediction, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    // The `data_collection_permission` alias accepts the older field name on deserialization.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    // Pairs of (string key, JSON payload) per diagnostic group — TODO confirm what the key denotes.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// The trigger for this request.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
117
/// What caused an edit-prediction request to be issued.
///
/// NOTE(review): unlike the other enums in this file there is no
/// `#[serde(rename_all = "snake_case")]`, so serde uses the PascalCase variant
/// names on the wire while strum's `AsRefStr` renders snake_case — confirm
/// this asymmetry is intentional before changing it (it would alter the wire format).
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, strum::AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum PredictEditsRequestTrigger {
    Testing,
    Diagnostics,
    Cli,
    #[default]
    Other,
}
127
/// Git repository state attached to an edit-prediction request
/// (see [`PredictEditsBody::git_info`]: only sent when data collection is permitted).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
140
/// Response body for an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    // Identifier for this prediction request; echoed back in accept/reject bodies.
    pub request_id: String,
    // The predicted output excerpt.
    pub output_excerpt: String,
}
146
/// Request body reporting that the user accepted an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    // The `request_id` from the corresponding [`PredictEditsResponse`].
    pub request_id: String,
    // Model version that produced the prediction, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    // End-to-end latency in milliseconds, when measured.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
155
/// Request body reporting a batch of rejected edit predictions.
///
/// Deserialize-only: the sending side serializes via
/// [`RejectEditPredictionsBodyRef`] to avoid cloning the rejections.
#[derive(Debug, Clone, Deserialize)]
pub struct RejectEditPredictionsBody {
    pub rejections: Vec<EditPredictionRejection>,
}
160
/// Borrowing, serialize-only counterpart of [`RejectEditPredictionsBody`];
/// lets the sender serialize a slice of rejections without taking ownership.
#[derive(Debug, Clone, Serialize)]
pub struct RejectEditPredictionsBodyRef<'a> {
    pub rejections: &'a [EditPredictionRejection],
}
165
/// A single rejected edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EditPredictionRejection {
    // The `request_id` from the corresponding [`PredictEditsResponse`].
    pub request_id: String,
    // Why the prediction was rejected; defaults to `Discarded` when absent.
    #[serde(default)]
    pub reason: EditPredictionRejectReason,
    // Whether the prediction was ever displayed to the user.
    pub was_shown: bool,
    // Model version that produced the prediction, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    // End-to-end latency in milliseconds, when measured.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub e2e_latency_ms: Option<u128>,
}
177
/// Why an edit prediction was rejected.
///
/// NOTE(review): no `#[serde(rename_all = "snake_case")]` here, so serde uses
/// PascalCase variant names on the wire while strum's `AsRefStr` renders
/// snake_case — confirm this is intentional (changing it alters the wire format).
#[derive(
    Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, strum::AsRefStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum EditPredictionRejectReason {
    /// New requests were triggered before this one completed
    Canceled,
    /// No edits returned
    Empty,
    /// Edits returned, but none remained after interpolation
    InterpolatedEmpty,
    /// The new prediction was preferred over the current one
    Replaced,
    /// The current prediction was preferred over the new one
    CurrentPreferred,
    /// The current prediction was discarded
    #[default]
    Discarded,
    /// The current prediction was explicitly rejected by the user
    Rejected,
}
199
/// Request body for a model completion, wrapping a provider-specific payload.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    // Identifier of the thread this completion belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    // Identifier of the prompt, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    // Which upstream provider should handle the request.
    pub provider: LanguageModelProvider,
    pub model: String,
    // Opaque request payload forwarded to the provider's API.
    pub provider_request: serde_json::Value,
}
210
/// Status updates for a completion request, streamed alongside provider events.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The request has started.
    Started,
    /// The request failed.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
    StreamEnded,
    /// Fallback for status kinds this client version doesn't recognize
    /// (via `#[serde(other)]`), keeping deserialization forward-compatible.
    #[serde(other)]
    Unknown,
}
230
/// An item in a completion stream: either a request status update or a
/// provider event payload of type `T`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
237
238impl<T> CompletionEvent<T> {
239    pub fn into_status(self) -> Option<CompletionRequestStatus> {
240        match self {
241            Self::Status(status) => Some(status),
242            Self::Event(_) => None,
243        }
244    }
245
246    pub fn into_event(self) -> Option<T> {
247        match self {
248            Self::Event(event) => Some(event),
249            Self::Status(_) => None,
250        }
251    }
252}
253
/// Request body for a web search.
#[derive(Serialize, Deserialize)]
pub struct WebSearchBody {
    pub query: String,
}
258
/// Response body for a web search.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    pub results: Vec<WebSearchResult>,
}
263
/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    pub title: String,
    pub url: String,
    // Text content of the result (presumably an extracted snippet — confirm).
    pub text: String,
}
270
/// Request body for counting tokens, mirroring [`CompletionBody`]'s
/// provider/model/payload shape.
#[derive(Serialize, Deserialize)]
pub struct CountTokensBody {
    pub provider: LanguageModelProvider,
    pub model: String,
    // Opaque request payload forwarded to the provider's API.
    pub provider_request: serde_json::Value,
}
277
/// Response body for a token count request.
#[derive(Serialize, Deserialize)]
pub struct CountTokensResponse {
    pub tokens: usize,
}
282
/// Newtype identifier for a language model; `Arc<str>` makes clones cheap.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
285
286impl std::fmt::Display for LanguageModelId {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        write!(f, "{}", self.0)
289    }
290}
291
/// Metadata and capability flags for a language model offered by the service.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    pub provider: LanguageModelProvider,
    pub id: LanguageModelId,
    pub display_name: String,
    // Whether this is the latest version of the model; defaults to false when absent.
    #[serde(default)]
    pub is_latest: bool,
    // Maximum token count (presumably the context-window size — confirm).
    pub max_token_count: usize,
    // Token limit when "max mode" is active, if the model supports it.
    pub max_token_count_in_max_mode: Option<usize>,
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    #[serde(default)]
    pub supports_fast_mode: bool,
    // Effort/reasoning levels the model accepts; see [`SupportedEffortLevel`].
    pub supported_effort_levels: Vec<SupportedEffortLevel>,
    #[serde(default)]
    pub supports_streaming_tools: bool,
    /// Only used by OpenAI and xAI.
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
314
/// An effort level supported by a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SupportedEffortLevel {
    // Human-readable name of the level.
    pub name: Arc<str>,
    // Value sent to the provider (presumably an API parameter — confirm).
    pub value: Arc<str>,
    // Whether this is the default level; omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub is_default: Option<bool>,
}
322
/// Response body listing the models available to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    // Server-selected default model, if any.
    pub default_model: Option<LanguageModelId>,
    // Server-selected default "fast" model, if any.
    pub default_fast_model: Option<LanguageModelId>,
    pub recommended_models: Vec<LanguageModelId>,
}
330
/// The client's current usage, per metered resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub edit_predictions: UsageData,
}
335
/// Amount used so far against a [`UsageLimit`].
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    pub used: u32,
    pub limit: UsageLimit,
}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `UsageLimit::from_str` across its full input space:
    /// the `"unlimited"` literal, valid i32 values (including zero,
    /// negatives, and the extremes), and malformed input.
    #[test]
    fn test_usage_limit_from_str() {
        let limit = UsageLimit::from_str("unlimited").unwrap();
        assert!(matches!(limit, UsageLimit::Unlimited));

        // Any value parseable as i32 becomes a finite limit — including,
        // as currently implemented, negative values and i32 extremes.
        for value in [0, 50, -1, i32::MIN, i32::MAX] {
            let limit = UsageLimit::from_str(&value.to_string()).unwrap();
            assert!(matches!(limit, UsageLimit::Limited(v) if v == value));
        }

        // Non-numeric input, trailing garbage, i32 overflow, and the
        // empty string are all rejected.
        for value in ["not_a_number", "50xyz", "2147483648", ""] {
            let limit = UsageLimit::from_str(value);
            assert!(limit.is_err());
        }
    }
}
362}