cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The header value that names the edit predictions resource.
///
/// NOTE(review): the header this value is paired with is not defined in this
/// file — confirm against the callers.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";

/// The maximum number of edit predictions that can be rejected per request.
pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
 47
/// A usage cap that is either a concrete number or no cap at all.
///
/// Serde uses snake_case variant names for this type.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// Usage is capped at the contained value.
    Limited(i32),
    /// Usage is not capped.
    Unlimited,
}
 54
 55impl FromStr for UsageLimit {
 56    type Err = anyhow::Error;
 57
 58    fn from_str(value: &str) -> Result<Self, Self::Err> {
 59        match value {
 60            "unlimited" => Ok(Self::Unlimited),
 61            limit => limit
 62                .parse::<i32>()
 63                .map(Self::Limited)
 64                .context("failed to parse limit"),
 65        }
 66    }
 67}
 68
 69#[derive(Debug, Clone, Copy, PartialEq)]
 70pub enum Plan {
 71    V2(PlanV2),
 72}
 73
 74impl Plan {
 75    pub fn is_v2(&self) -> bool {
 76        matches!(self, Self::V2(_))
 77    }
 78}
 79
/// The plans that make up the V2 plan generation.
///
/// Serde uses snake_case variant names; [`PlanV2::ZedFree`] is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanV2 {
    /// Serialized as `zed_free`.
    #[default]
    ZedFree,
    /// Serialized as `zed_pro`.
    ZedPro,
    /// Serialized as `zed_pro_trial`.
    ZedProTrial,
}
 88
 89impl FromStr for PlanV2 {
 90    type Err = anyhow::Error;
 91
 92    fn from_str(value: &str) -> Result<Self, Self::Err> {
 93        match value {
 94            "zed_free" => Ok(Self::ZedFree),
 95            "zed_pro" => Ok(Self::ZedPro),
 96            "zed_pro_trial" => Ok(Self::ZedProTrial),
 97            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
 98        }
 99    }
100}
101
/// The provider behind a language model.
///
/// Both the serde representation and the strum string conversions
/// (`Display` / `FromStr`) use snake_case variant names.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}
113
/// Payload of a predict-edits request.
///
/// `Option` fields are left out of the serialized form when `None` and fall
/// back to `None` when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    /// Optional outline text.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    /// NOTE(review): presumably a textual encoding of recent editor events —
    /// confirm the format with the producer.
    pub input_events: String,
    /// NOTE(review): presumably the buffer excerpt the prediction is based on —
    /// confirm with the producer.
    pub input_excerpt: String,
    /// Optional speculated output text.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Also deserializes from the key `data_collection_permission`; defaults
    /// to `false` when absent.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Optional diagnostic groups, each a `(String, JSON value)` pair.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// The trigger for this request.
    ///
    /// Falls back to the trigger's default variant when absent.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
134
/// What triggered a predict-edits request.
///
/// No serde rename is applied, so variants are serialized under their Rust
/// names. [`PredictEditsRequestTrigger::Other`] is the default.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PredictEditsRequestTrigger {
    Testing,
    Diagnostics,
    Cli,
    #[default]
    Other,
}
143
/// Git repository state attached to a predict-edits request.
///
/// Every field is optional: it is omitted from the serialized form when
/// `None` and defaults to `None` when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
    /// SHA of git HEAD commit at time of prediction.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub head_sha: Option<String>,
    /// URL of the remote called `origin`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_origin_url: Option<String>,
    /// URL of the remote called `upstream`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub remote_upstream_url: Option<String>,
}
156
/// Response payload of a predict-edits request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request, as a string.
    pub request_id: String,
    /// The output excerpt text.
    ///
    /// NOTE(review): presumably the rewritten form of the request's
    /// `input_excerpt` — confirm with the service.
    pub output_excerpt: String,
}
162
/// Payload reporting that an edit prediction was accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    /// Identifier of the request whose prediction was accepted.
    pub request_id: String,
}
167
/// Owned payload listing rejected edit predictions.
///
/// This type only derives `Deserialize`; [`RejectEditPredictionsBodyRef`] is
/// the borrowed, `Serialize`-only counterpart with the same field name.
#[derive(Debug, Clone, Deserialize)]
pub struct RejectEditPredictionsBody {
    /// The rejected predictions.
    pub rejections: Vec<EditPredictionRejection>,
}
172
/// Borrowed payload listing rejected edit predictions.
///
/// This type only derives `Serialize`, so a slice of rejections can be sent
/// without first copying it into an owned [`RejectEditPredictionsBody`].
#[derive(Debug, Clone, Serialize)]
pub struct RejectEditPredictionsBodyRef<'a> {
    /// The rejected predictions, borrowed from the caller.
    pub rejections: &'a [EditPredictionRejection],
}
177
/// A single rejected edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EditPredictionRejection {
    /// Identifier of the request whose prediction was rejected.
    pub request_id: String,
    /// Why the prediction was rejected.
    ///
    /// Falls back to the reason's default variant when absent from the input.
    #[serde(default)]
    pub reason: EditPredictionRejectReason,
    /// NOTE(review): presumably whether the prediction was displayed to the
    /// user before being rejected — confirm with the producer.
    pub was_shown: bool,
}
185
/// The reason an edit prediction was rejected.
///
/// No serde rename is applied, so variants are serialized under their Rust
/// names. [`EditPredictionRejectReason::Discarded`] is the default.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum EditPredictionRejectReason {
    /// New requests were triggered before this one completed
    Canceled,
    /// No edits returned
    Empty,
    /// Edits returned, but none remained after interpolation
    InterpolatedEmpty,
    /// The new prediction was preferred over the current one
    Replaced,
    /// The current prediction was preferred over the new one
    CurrentPreferred,
    /// The current prediction was discarded
    #[default]
    Discarded,
}
202
/// The purpose a completion request is made for.
///
/// Serde uses snake_case variant names. The derived ordering follows the
/// declaration order of the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
216
/// Payload of a completion request.
///
/// `Option` fields are left out of the serialized form when `None` and fall
/// back to `None` when missing from the input.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    /// Optional thread identifier.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    /// Optional prompt identifier.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    /// Optional intent of the request.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    /// The provider of the requested model.
    pub provider: LanguageModelProvider,
    /// The model name, as a string.
    pub model: String,
    /// The provider-specific request, carried as untyped JSON.
    pub provider_request: serde_json::Value,
}
229
/// A status update about a completion request.
///
/// Serde uses snake_case variant names.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is queued.
    Queued {
        /// NOTE(review): presumably the request's position in the queue —
        /// confirm whether it is zero- or one-based.
        position: usize,
    },
    /// The request has started.
    Started,
    /// The request failed.
    Failed {
        /// Failure code, as a string.
        code: String,
        /// Failure message.
        message: String,
        /// Identifier of the failed request.
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// Usage figures were updated.
    UsageUpdated {
        /// The usage amount.
        amount: usize,
        /// The usage limit.
        limit: UsageLimit,
    },
    /// The tool use limit was reached.
    ToolUseLimitReached,
}
250
251#[derive(Serialize, Deserialize)]
252#[serde(rename_all = "snake_case")]
253pub enum CompletionEvent<T> {
254    Status(CompletionRequestStatus),
255    Event(T),
256}
257
258impl<T> CompletionEvent<T> {
259    pub fn into_status(self) -> Option<CompletionRequestStatus> {
260        match self {
261            Self::Status(status) => Some(status),
262            Self::Event(_) => None,
263        }
264    }
265
266    pub fn into_event(self) -> Option<T> {
267        match self {
268            Self::Event(event) => Some(event),
269            Self::Status(_) => None,
270        }
271    }
272}
273
274#[derive(Serialize, Deserialize)]
275pub struct WebSearchBody {
276    pub query: String,
277}
278
/// Response payload of a web search request.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    /// The search results.
    pub results: Vec<WebSearchResult>,
}
283
/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    /// Title of the result.
    pub title: String,
    /// URL of the result.
    pub url: String,
    /// Text of the result.
    pub text: String,
}
290
291#[derive(Serialize, Deserialize)]
292pub struct CountTokensBody {
293    pub provider: LanguageModelProvider,
294    pub model: String,
295    pub provider_request: serde_json::Value,
296}
297
298#[derive(Serialize, Deserialize)]
299pub struct CountTokensResponse {
300    pub tokens: usize,
301}
302
/// Identifier of a language model, stored as a shared string.
///
/// Cloning only bumps the `Arc` reference count; the string is not copied.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
305
306impl std::fmt::Display for LanguageModelId {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        write!(f, "{}", self.0)
309    }
310}
311
/// Description of a language model and its capabilities.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    /// The provider of the model.
    pub provider: LanguageModelProvider,
    /// Identifier of the model.
    pub id: LanguageModelId,
    /// Name of the model for display.
    pub display_name: String,
    /// Maximum token count.
    pub max_token_count: usize,
    /// Maximum token count in max mode, if any.
    pub max_token_count_in_max_mode: Option<usize>,
    /// Maximum number of output tokens.
    pub max_output_tokens: usize,
    /// Whether the model supports tools.
    pub supports_tools: bool,
    /// Whether the model supports images.
    pub supports_images: bool,
    /// Whether the model supports thinking.
    pub supports_thinking: bool,
    /// Whether the model supports streaming tools.
    ///
    /// Defaults to `false` when absent from the input.
    #[serde(default)]
    pub supports_streaming_tools: bool,
    /// Only used by OpenAI and xAI.
    ///
    /// Defaults to `false` when absent from the input.
    #[serde(default)]
    pub supports_parallel_tool_calls: bool,
}
329
/// Response payload listing the available language models.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    /// The available models.
    pub models: Vec<LanguageModel>,
    /// Identifier of the default model, if any.
    pub default_model: Option<LanguageModelId>,
    /// Identifier of the default fast model, if any.
    pub default_fast_model: Option<LanguageModelId>,
    /// Identifiers of the recommended models.
    pub recommended_models: Vec<LanguageModelId>,
}
337
/// Current usage figures.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    /// Usage of edit predictions.
    pub edit_predictions: UsageData,
}
342
/// An amount used, paired with the limit it counts against.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    /// The amount used so far.
    pub used: u32,
    /// The limit that applies.
    pub limit: UsageLimit,
}
348
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Every snake_case JSON string deserializes to its `PlanV2` variant.
    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let parsed: PlanV2 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// `"unlimited"` and integer strings parse; anything else is an error.
    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );

        for amount in [0, 50] {
            let parsed = UsageLimit::from_str(&amount.to_string()).unwrap();
            assert_eq!(parsed, UsageLimit::Limited(amount));
        }

        for invalid in ["not_a_number", "50xyz"] {
            assert!(UsageLimit::from_str(invalid).is_err());
        }
    }
}