cloud_llm_client.rs

  1pub mod predict_edits_v3;
  2
  3use std::str::FromStr;
  4use std::sync::Arc;
  5
  6use anyhow::Context as _;
  7use serde::{Deserialize, Serialize};
  8use strum::{Display, EnumIter, EnumString};
  9use uuid::Uuid;
 10
 11/// The name of the header used to indicate which version of Zed the client is running.
 12pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 13
 14/// The name of the header used to indicate when a request failed due to an
 15/// expired LLM token.
 16///
 17/// The client may use this as a signal to refresh the token.
 18pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 19
 20/// The name of the header used to indicate what plan the user is currently on.
 21pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 22
 23/// The name of the header used to indicate the usage limit for model requests.
 24pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 25
 26/// The name of the header used to indicate the usage amount for model requests.
 27pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 28
 29/// The name of the header used to indicate the usage limit for edit predictions.
 30pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 31
 32/// The name of the header used to indicate the usage amount for edit predictions.
 33pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 34
 35/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 36pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 37
 38pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 39pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 40
 41/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 42pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 43
 44/// The name of the header used to indicate the minimum required Zed version.
 45///
 46/// This can be used to force a Zed upgrade in order to continue communicating
 47/// with the LLM service.
 48pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 49
 50/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 51pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 52    "x-zed-client-supports-status-messages";
 53
 54/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 55pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 56    "x-zed-server-supports-status-messages";
 57
 58#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 59#[serde(rename_all = "snake_case")]
 60pub enum UsageLimit {
 61    Limited(i32),
 62    Unlimited,
 63}
 64
 65impl FromStr for UsageLimit {
 66    type Err = anyhow::Error;
 67
 68    fn from_str(value: &str) -> Result<Self, Self::Err> {
 69        match value {
 70            "unlimited" => Ok(Self::Unlimited),
 71            limit => limit
 72                .parse::<i32>()
 73                .map(Self::Limited)
 74                .context("failed to parse limit"),
 75        }
 76    }
 77}
 78
 79#[derive(Debug, Clone, Copy, PartialEq)]
 80pub enum Plan {
 81    V1(PlanV1),
 82    V2(PlanV2),
 83}
 84
 85impl Plan {
 86    pub fn is_v2(&self) -> bool {
 87        matches!(self, Self::V2(_))
 88    }
 89}
 90
 91#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 92#[serde(rename_all = "snake_case")]
 93pub enum PlanV1 {
 94    #[default]
 95    #[serde(alias = "Free")]
 96    ZedFree,
 97    #[serde(alias = "ZedPro")]
 98    ZedPro,
 99    #[serde(alias = "ZedProTrial")]
100    ZedProTrial,
101}
102
103impl FromStr for PlanV1 {
104    type Err = anyhow::Error;
105
106    fn from_str(value: &str) -> Result<Self, Self::Err> {
107        match value {
108            "zed_free" => Ok(Self::ZedFree),
109            "zed_pro" => Ok(Self::ZedPro),
110            "zed_pro_trial" => Ok(Self::ZedProTrial),
111            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
112        }
113    }
114}
115
116#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
117#[serde(rename_all = "snake_case")]
118pub enum PlanV2 {
119    #[default]
120    ZedFree,
121    ZedPro,
122    ZedProTrial,
123}
124
125impl FromStr for PlanV2 {
126    type Err = anyhow::Error;
127
128    fn from_str(value: &str) -> Result<Self, Self::Err> {
129        match value {
130            "zed_free" => Ok(Self::ZedFree),
131            "zed_pro" => Ok(Self::ZedPro),
132            "zed_pro_trial" => Ok(Self::ZedProTrial),
133            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
134        }
135    }
136}
137
138#[derive(
139    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
140)]
141#[serde(rename_all = "snake_case")]
142#[strum(serialize_all = "snake_case")]
143pub enum LanguageModelProvider {
144    Anthropic,
145    OpenAi,
146    Google,
147    XAi,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct PredictEditsBody {
152    #[serde(skip_serializing_if = "Option::is_none", default)]
153    pub outline: Option<String>,
154    pub input_events: String,
155    pub input_excerpt: String,
156    #[serde(skip_serializing_if = "Option::is_none", default)]
157    pub speculated_output: Option<String>,
158    /// Whether the user provided consent for sampling this interaction.
159    #[serde(default, alias = "data_collection_permission")]
160    pub can_collect_data: bool,
161    #[serde(skip_serializing_if = "Option::is_none", default)]
162    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
163    /// Info about the git repository state, only present when can_collect_data is true.
164    #[serde(skip_serializing_if = "Option::is_none", default)]
165    pub git_info: Option<PredictEditsGitInfo>,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct PredictEditsGitInfo {
170    /// SHA of git HEAD commit at time of prediction.
171    #[serde(skip_serializing_if = "Option::is_none", default)]
172    pub head_sha: Option<String>,
173    /// URL of the remote called `origin`.
174    #[serde(skip_serializing_if = "Option::is_none", default)]
175    pub remote_origin_url: Option<String>,
176    /// URL of the remote called `upstream`.
177    #[serde(skip_serializing_if = "Option::is_none", default)]
178    pub remote_upstream_url: Option<String>,
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct PredictEditsResponse {
183    pub request_id: Uuid,
184    pub output_excerpt: String,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct AcceptEditPredictionBody {
189    pub request_id: Uuid,
190}
191
192#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
193#[serde(rename_all = "snake_case")]
194pub enum CompletionMode {
195    Normal,
196    Max,
197}
198
199#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
200#[serde(rename_all = "snake_case")]
201pub enum CompletionIntent {
202    UserPrompt,
203    ToolResults,
204    ThreadSummarization,
205    ThreadContextSummarization,
206    CreateFile,
207    EditFile,
208    InlineAssist,
209    TerminalInlineAssist,
210    GenerateGitCommitMessage,
211}
212
213#[derive(Debug, Serialize, Deserialize)]
214pub struct CompletionBody {
215    #[serde(skip_serializing_if = "Option::is_none", default)]
216    pub thread_id: Option<String>,
217    #[serde(skip_serializing_if = "Option::is_none", default)]
218    pub prompt_id: Option<String>,
219    #[serde(skip_serializing_if = "Option::is_none", default)]
220    pub intent: Option<CompletionIntent>,
221    #[serde(skip_serializing_if = "Option::is_none", default)]
222    pub mode: Option<CompletionMode>,
223    pub provider: LanguageModelProvider,
224    pub model: String,
225    pub provider_request: serde_json::Value,
226}
227
228#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
229#[serde(rename_all = "snake_case")]
230pub enum CompletionRequestStatus {
231    Queued {
232        position: usize,
233    },
234    Started,
235    Failed {
236        code: String,
237        message: String,
238        request_id: Uuid,
239        /// Retry duration in seconds.
240        retry_after: Option<f64>,
241    },
242    UsageUpdated {
243        amount: usize,
244        limit: UsageLimit,
245    },
246    ToolUseLimitReached,
247}
248
249#[derive(Serialize, Deserialize)]
250#[serde(rename_all = "snake_case")]
251pub enum CompletionEvent<T> {
252    Status(CompletionRequestStatus),
253    Event(T),
254}
255
256impl<T> CompletionEvent<T> {
257    pub fn into_status(self) -> Option<CompletionRequestStatus> {
258        match self {
259            Self::Status(status) => Some(status),
260            Self::Event(_) => None,
261        }
262    }
263
264    pub fn into_event(self) -> Option<T> {
265        match self {
266            Self::Event(event) => Some(event),
267            Self::Status(_) => None,
268        }
269    }
270}
271
272#[derive(Serialize, Deserialize)]
273pub struct WebSearchBody {
274    pub query: String,
275}
276
277#[derive(Debug, Serialize, Deserialize, Clone)]
278pub struct WebSearchResponse {
279    pub results: Vec<WebSearchResult>,
280}
281
282#[derive(Debug, Serialize, Deserialize, Clone)]
283pub struct WebSearchResult {
284    pub title: String,
285    pub url: String,
286    pub text: String,
287}
288
289#[derive(Serialize, Deserialize)]
290pub struct CountTokensBody {
291    pub provider: LanguageModelProvider,
292    pub model: String,
293    pub provider_request: serde_json::Value,
294}
295
296#[derive(Serialize, Deserialize)]
297pub struct CountTokensResponse {
298    pub tokens: usize,
299}
300
301#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
302pub struct LanguageModelId(pub Arc<str>);
303
304impl std::fmt::Display for LanguageModelId {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        write!(f, "{}", self.0)
307    }
308}
309
310#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
311pub struct LanguageModel {
312    pub provider: LanguageModelProvider,
313    pub id: LanguageModelId,
314    pub display_name: String,
315    pub max_token_count: usize,
316    pub max_token_count_in_max_mode: Option<usize>,
317    pub max_output_tokens: usize,
318    pub supports_tools: bool,
319    pub supports_images: bool,
320    pub supports_thinking: bool,
321    pub supports_max_mode: bool,
322}
323
324#[derive(Debug, Serialize, Deserialize)]
325pub struct ListModelsResponse {
326    pub models: Vec<LanguageModel>,
327    pub default_model: Option<LanguageModelId>,
328    pub default_fast_model: Option<LanguageModelId>,
329    pub recommended_models: Vec<LanguageModelId>,
330}
331
332#[derive(Debug, Serialize, Deserialize)]
333pub struct GetSubscriptionResponse {
334    pub plan: PlanV1,
335    pub usage: Option<CurrentUsage>,
336}
337
338#[derive(Debug, PartialEq, Serialize, Deserialize)]
339pub struct CurrentUsage {
340    pub model_requests: UsageData,
341    pub edit_predictions: UsageData,
342}
343
344#[derive(Debug, PartialEq, Serialize, Deserialize)]
345pub struct UsageData {
346    pub used: u32,
347    pub limit: UsageLimit,
348}
349
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// `PlanV1` deserializes from its snake_case names.
    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV1::ZedFree),
            ("zed_pro", PlanV1::ZedPro),
            ("zed_pro_trial", PlanV1::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let plan: PlanV1 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// `PlanV1` also deserializes from its serde aliases.
    #[test]
    fn test_plan_v1_deserialize_aliases() {
        let cases = [
            ("Free", PlanV1::ZedFree),
            ("ZedPro", PlanV1::ZedPro),
            ("ZedProTrial", PlanV1::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let plan: PlanV1 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// `PlanV2` deserializes from its snake_case names.
    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        let cases = [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ];

        for (raw, expected) in cases {
            let plan: PlanV2 = serde_json::from_value(json!(raw)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// `UsageLimit::from_str` accepts `"unlimited"` and integers, and rejects
    /// anything else.
    #[test]
    fn test_usage_limit_from_str() {
        let unlimited = UsageLimit::from_str("unlimited").unwrap();
        assert_eq!(unlimited, UsageLimit::Unlimited);

        for amount in [0_i32, 50] {
            let parsed = UsageLimit::from_str(&amount.to_string()).unwrap();
            assert_eq!(parsed, UsageLimit::Limited(amount));
        }

        for invalid in ["not_a_number", "50xyz"] {
            assert!(UsageLimit::from_str(invalid).is_err());
        }
    }
}