cloud_llm_client.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3
  4use anyhow::Context as _;
  5use serde::{Deserialize, Deserializer, Serialize, Serializer};
  6use serde_json::value::RawValue;
  7use std::marker::PhantomData;
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
 11/// The name of the header used to indicate which version of Zed the client is running.
 12pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 13
 14/// The name of the header used to indicate when a request failed due to an
 15/// expired LLM token.
 16///
 17/// The client may use this as a signal to refresh the token.
 18pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 19
 20/// The name of the header used to indicate what plan the user is currently on.
 21pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 22
 23/// The name of the header used to indicate the usage limit for model requests.
 24pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 25
 26/// The name of the header used to indicate the usage amount for model requests.
 27pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 28
 29/// The name of the header used to indicate the usage limit for edit predictions.
 30pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 31
 32/// The name of the header used to indicate the usage amount for edit predictions.
 33pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 34
 35/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 36pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 37
 38pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 39pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 40
 41/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 42pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 43
 44/// The name of the header used to indicate the the minimum required Zed version.
 45///
 46/// This can be used to force a Zed upgrade in order to continue communicating
 47/// with the LLM service.
 48pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 49
 50/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 51pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 52    "x-zed-client-supports-status-messages";
 53
 54/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 55pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 56    "x-zed-server-supports-status-messages";
 57
 58#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 59#[serde(rename_all = "snake_case")]
 60pub enum UsageLimit {
 61    Limited(i32),
 62    Unlimited,
 63}
 64
 65impl FromStr for UsageLimit {
 66    type Err = anyhow::Error;
 67
 68    fn from_str(value: &str) -> Result<Self, Self::Err> {
 69        match value {
 70            "unlimited" => Ok(Self::Unlimited),
 71            limit => limit
 72                .parse::<i32>()
 73                .map(Self::Limited)
 74                .context("failed to parse limit"),
 75        }
 76    }
 77}
 78
 79#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 80#[serde(rename_all = "snake_case")]
 81pub enum Plan {
 82    #[default]
 83    #[serde(alias = "Free")]
 84    ZedFree,
 85    #[serde(alias = "ZedPro")]
 86    ZedPro,
 87    #[serde(alias = "ZedProTrial")]
 88    ZedProTrial,
 89}
 90
 91impl Plan {
 92    pub fn as_str(&self) -> &'static str {
 93        match self {
 94            Plan::ZedFree => "zed_free",
 95            Plan::ZedPro => "zed_pro",
 96            Plan::ZedProTrial => "zed_pro_trial",
 97        }
 98    }
 99
100    pub fn model_requests_limit(&self) -> UsageLimit {
101        match self {
102            Plan::ZedPro => UsageLimit::Limited(500),
103            Plan::ZedProTrial => UsageLimit::Limited(150),
104            Plan::ZedFree => UsageLimit::Limited(50),
105        }
106    }
107
108    pub fn edit_predictions_limit(&self) -> UsageLimit {
109        match self {
110            Plan::ZedPro => UsageLimit::Unlimited,
111            Plan::ZedProTrial => UsageLimit::Unlimited,
112            Plan::ZedFree => UsageLimit::Limited(2_000),
113        }
114    }
115}
116
117impl FromStr for Plan {
118    type Err = anyhow::Error;
119
120    fn from_str(value: &str) -> Result<Self, Self::Err> {
121        match value {
122            "zed_free" => Ok(Plan::ZedFree),
123            "zed_pro" => Ok(Plan::ZedPro),
124            "zed_pro_trial" => Ok(Plan::ZedProTrial),
125            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
126        }
127    }
128}
129
130#[derive(
131    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
132)]
133#[serde(rename_all = "snake_case")]
134#[strum(serialize_all = "snake_case")]
135pub enum LanguageModelProvider {
136    Anthropic,
137    OpenAi,
138    Google,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct PredictEditsBody {
143    pub input_events: String,
144    pub input_excerpt: String,
145    /// Whether the user provided consent for sampling this interaction.
146    #[serde(default, alias = "data_collection_permission")]
147    pub can_collect_data: bool,
148    /// Note that this is no longer sent, in favor of `PredictEditsAdditionalContext`.
149    #[serde(skip_serializing_if = "Option::is_none", default)]
150    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
151    /// Info about the git repository state, only present when can_collect_data is true. Note that
152    /// this is no longer sent, in favor of `PredictEditsAdditionalContext`.
153    #[serde(skip_serializing_if = "Option::is_none", default)]
154    pub git_info: Option<PredictEditsGitInfo>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct PredictEditsGitInfo {
159    /// full_path to the repo (worktree name + relative path to repo)
160    pub worktree_path: Option<String>,
161    /// SHA of git HEAD commit at time of prediction.
162    #[serde(skip_serializing_if = "Option::is_none", default)]
163    pub head_sha: Option<String>,
164    /// URL of the remote called `origin`.
165    #[serde(skip_serializing_if = "Option::is_none", default)]
166    pub remote_origin_url: Option<String>,
167    /// URL of the remote called `upstream`.
168    #[serde(skip_serializing_if = "Option::is_none", default)]
169    pub remote_upstream_url: Option<String>,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct PredictEditsResponse {
174    pub request_id: Uuid,
175    pub output_excerpt: String,
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct AcceptEditPredictionBody {
180    pub request_id: Uuid,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct PredictEditsTrainingData {
185    pub request_id: Uuid,
186    /// When true, `request_id` is an ID that corresponds to an edit prediction.
187    pub has_prediction: bool,
188    /// State that `events` is based on. Initially this is `GitHead` and subsequent uploads will
189    /// then be based on the previous upload.
190    pub diff_base: PredictEditsDiffBase,
191    /// Fine-grained edit events atop `diff_base`.
192    #[serde(skip_serializing_if = "Vec::is_empty", default)]
193    pub events: Vec<SerializedJson<PredictEditsEvent>>,
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
197#[serde(rename_all = "snake_case")]
198pub enum PredictEditsDiffBase {
199    GitHead { git_info: PredictEditsGitInfo },
200    PreviousUpload { request_id: Uuid },
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct PredictEditsEvent {
205    pub entry_id: usize,
206    #[serde(skip_serializing_if = "Option::is_none", default)]
207    pub path: Option<String>,
208    pub timestamp_ms: u64,
209    pub data: PredictEditsEventData,
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
213#[serde(rename_all = "snake_case")]
214pub enum PredictEditsEventData {
215    MoveCursor {
216        offset: usize,
217        #[serde(skip_serializing_if = "Vec::is_empty", default)]
218        diagnostic_groups: Vec<(String, Box<RawValue>)>,
219        #[serde(skip_serializing_if = "is_default", default)]
220        diagnostic_groups_truncated: bool,
221    },
222    Create {
223        content: String,
224    },
225    Delete,
226    Edit {
227        unified_diff: String,
228    },
229    MarkDiffTooLarge,
230}
231
232#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
233#[serde(rename_all = "snake_case")]
234pub enum CompletionMode {
235    Normal,
236    Max,
237}
238
239#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
240#[serde(rename_all = "snake_case")]
241pub enum CompletionIntent {
242    UserPrompt,
243    ToolResults,
244    ThreadSummarization,
245    ThreadContextSummarization,
246    CreateFile,
247    EditFile,
248    InlineAssist,
249    TerminalInlineAssist,
250    GenerateGitCommitMessage,
251}
252
253#[derive(Debug, Serialize, Deserialize)]
254pub struct CompletionBody {
255    #[serde(skip_serializing_if = "Option::is_none", default)]
256    pub thread_id: Option<String>,
257    #[serde(skip_serializing_if = "Option::is_none", default)]
258    pub prompt_id: Option<String>,
259    #[serde(skip_serializing_if = "Option::is_none", default)]
260    pub intent: Option<CompletionIntent>,
261    #[serde(skip_serializing_if = "Option::is_none", default)]
262    pub mode: Option<CompletionMode>,
263    pub provider: LanguageModelProvider,
264    pub model: String,
265    pub provider_request: serde_json::Value,
266}
267
268#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
269#[serde(rename_all = "snake_case")]
270pub enum CompletionRequestStatus {
271    Queued {
272        position: usize,
273    },
274    Started,
275    Failed {
276        code: String,
277        message: String,
278        request_id: Uuid,
279        /// Retry duration in seconds.
280        retry_after: Option<f64>,
281    },
282    UsageUpdated {
283        amount: usize,
284        limit: UsageLimit,
285    },
286    ToolUseLimitReached,
287}
288
289#[derive(Serialize, Deserialize)]
290#[serde(rename_all = "snake_case")]
291pub enum CompletionEvent<T> {
292    Status(CompletionRequestStatus),
293    Event(T),
294}
295
296impl<T> CompletionEvent<T> {
297    pub fn into_status(self) -> Option<CompletionRequestStatus> {
298        match self {
299            Self::Status(status) => Some(status),
300            Self::Event(_) => None,
301        }
302    }
303
304    pub fn into_event(self) -> Option<T> {
305        match self {
306            Self::Event(event) => Some(event),
307            Self::Status(_) => None,
308        }
309    }
310}
311
312#[derive(Serialize, Deserialize)]
313pub struct WebSearchBody {
314    pub query: String,
315}
316
317#[derive(Debug, Serialize, Deserialize, Clone)]
318pub struct WebSearchResponse {
319    pub results: Vec<WebSearchResult>,
320}
321
322#[derive(Debug, Serialize, Deserialize, Clone)]
323pub struct WebSearchResult {
324    pub title: String,
325    pub url: String,
326    pub text: String,
327}
328
329#[derive(Serialize, Deserialize)]
330pub struct CountTokensBody {
331    pub provider: LanguageModelProvider,
332    pub model: String,
333    pub provider_request: serde_json::Value,
334}
335
336#[derive(Serialize, Deserialize)]
337pub struct CountTokensResponse {
338    pub tokens: usize,
339}
340
341#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
342pub struct LanguageModelId(pub Arc<str>);
343
344impl std::fmt::Display for LanguageModelId {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        write!(f, "{}", self.0)
347    }
348}
349
350#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
351pub struct LanguageModel {
352    pub provider: LanguageModelProvider,
353    pub id: LanguageModelId,
354    pub display_name: String,
355    pub max_token_count: usize,
356    pub max_token_count_in_max_mode: Option<usize>,
357    pub max_output_tokens: usize,
358    pub supports_tools: bool,
359    pub supports_images: bool,
360    pub supports_thinking: bool,
361    pub supports_max_mode: bool,
362}
363
364#[derive(Debug, Serialize, Deserialize)]
365pub struct ListModelsResponse {
366    pub models: Vec<LanguageModel>,
367    pub default_model: LanguageModelId,
368    pub default_fast_model: LanguageModelId,
369    pub recommended_models: Vec<LanguageModelId>,
370}
371
372#[derive(Debug, Serialize, Deserialize)]
373pub struct GetSubscriptionResponse {
374    pub plan: Plan,
375    pub usage: Option<CurrentUsage>,
376}
377
378#[derive(Debug, PartialEq, Serialize, Deserialize)]
379pub struct CurrentUsage {
380    pub model_requests: UsageData,
381    pub edit_predictions: UsageData,
382}
383
384#[derive(Debug, PartialEq, Serialize, Deserialize)]
385pub struct UsageData {
386    pub used: u32,
387    pub limit: UsageLimit,
388}
389
390#[derive(Debug, Clone)]
391pub struct SerializedJson<T> {
392    raw: Box<RawValue>,
393    _phantom: PhantomData<T>,
394}
395
396impl<T> SerializedJson<T>
397where
398    T: Serialize + for<'de> Deserialize<'de>,
399{
400    pub fn new(value: &T) -> Result<Self, serde_json::Error> {
401        Ok(SerializedJson {
402            raw: serde_json::value::to_raw_value(value)?,
403            _phantom: PhantomData,
404        })
405    }
406
407    pub fn deserialize(&self) -> Result<T, serde_json::Error> {
408        serde_json::from_str(self.raw.get())
409    }
410
411    pub fn as_raw(&self) -> &RawValue {
412        &self.raw
413    }
414
415    pub fn into_raw(self) -> Box<RawValue> {
416        self.raw
417    }
418}
419
420impl<T> Serialize for SerializedJson<T> {
421    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
422    where
423        S: Serializer,
424    {
425        self.raw.serialize(serializer)
426    }
427}
428
429impl<'de, T> Deserialize<'de> for SerializedJson<T> {
430    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
431    where
432        D: Deserializer<'de>,
433    {
434        let raw = Box::<RawValue>::deserialize(deserializer)?;
435        Ok(SerializedJson {
436            raw,
437            _phantom: PhantomData,
438        })
439    }
440}
441
/// Returns `true` when `value` equals `T::default()`.
///
/// Referenced by name from a `skip_serializing_if` attribute in this file.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    *value == T::default()
}
445
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// The snake_case names are the primary serde representation of `Plan`.
    #[test]
    fn test_plan_deserialize_snake_case() {
        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// The aliases declared on `Plan` must also deserialize.
    #[test]
    fn test_plan_deserialize_aliases() {
        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// `as_str` and `from_str` must agree on every plan identifier.
    #[test]
    fn test_plan_as_str_round_trips_through_from_str() {
        for plan in [Plan::ZedFree, Plan::ZedPro, Plan::ZedProTrial] {
            assert_eq!(Plan::from_str(plan.as_str()).unwrap(), plan);
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        // `UsageLimit` is `PartialEq + Debug`, so compare directly and get a
        // readable diff on failure.
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );
        assert_eq!(
            UsageLimit::from_str(&0.to_string()).unwrap(),
            UsageLimit::Limited(0)
        );
        assert_eq!(
            UsageLimit::from_str(&50.to_string()).unwrap(),
            UsageLimit::Limited(50)
        );

        for value in ["not_a_number", "50xyz"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "{value:?} should not parse as a usage limit"
            );
        }
    }
}