cloud_llm_client.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3
  4use anyhow::Context as _;
  5use serde::{Deserialize, Serialize};
  6use strum::{Display, EnumIter, EnumString};
  7use uuid::Uuid;
  8
  9/// The name of the header used to indicate which version of Zed the client is running.
 10pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 11
 12/// The name of the header used to indicate when a request failed due to an
 13/// expired LLM token.
 14///
 15/// The client may use this as a signal to refresh the token.
 16pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 17
 18/// The name of the header used to indicate what plan the user is currently on.
 19pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 20
 21/// The name of the header used to indicate the usage limit for model requests.
 22pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 23
 24/// The name of the header used to indicate the usage amount for model requests.
 25pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 26
 27/// The name of the header used to indicate the usage limit for edit predictions.
 28pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 29
 30/// The name of the header used to indicate the usage amount for edit predictions.
 31pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 32
 33/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 34pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 35
 36pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 37pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 38
 39/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 40pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 41
 42/// The name of the header used to indicate the minimum required Zed version.
 43///
 44/// This can be used to force a Zed upgrade in order to continue communicating
 45/// with the LLM service.
 46pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 47
 48/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 49pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 50    "x-zed-client-supports-status-messages";
 51
 52/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 53pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 54    "x-zed-server-supports-status-messages";
 55
 56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 57#[serde(rename_all = "snake_case")]
 58pub enum UsageLimit {
 59    Limited(i32),
 60    Unlimited,
 61}
 62
 63impl FromStr for UsageLimit {
 64    type Err = anyhow::Error;
 65
 66    fn from_str(value: &str) -> Result<Self, Self::Err> {
 67        match value {
 68            "unlimited" => Ok(Self::Unlimited),
 69            limit => limit
 70                .parse::<i32>()
 71                .map(Self::Limited)
 72                .context("failed to parse limit"),
 73        }
 74    }
 75}
 76
 77#[derive(Debug, Clone, Copy, PartialEq)]
 78pub enum Plan {
 79    V1(PlanV1),
 80    V2(PlanV2),
 81}
 82
 83impl Plan {
 84    pub fn is_v2(&self) -> bool {
 85        matches!(self, Self::V2(_))
 86    }
 87}
 88
 89#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 90#[serde(rename_all = "snake_case")]
 91pub enum PlanV1 {
 92    #[default]
 93    #[serde(alias = "Free")]
 94    ZedFree,
 95    #[serde(alias = "ZedPro")]
 96    ZedPro,
 97    #[serde(alias = "ZedProTrial")]
 98    ZedProTrial,
 99}
100
101impl FromStr for PlanV1 {
102    type Err = anyhow::Error;
103
104    fn from_str(value: &str) -> Result<Self, Self::Err> {
105        match value {
106            "zed_free" => Ok(Self::ZedFree),
107            "zed_pro" => Ok(Self::ZedPro),
108            "zed_pro_trial" => Ok(Self::ZedProTrial),
109            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
110        }
111    }
112}
113
114#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum PlanV2 {
117    #[default]
118    ZedFree,
119    ZedPro,
120    ZedProTrial,
121}
122
123impl FromStr for PlanV2 {
124    type Err = anyhow::Error;
125
126    fn from_str(value: &str) -> Result<Self, Self::Err> {
127        match value {
128            "zed_free" => Ok(Self::ZedFree),
129            "zed_pro" => Ok(Self::ZedPro),
130            "zed_pro_trial" => Ok(Self::ZedProTrial),
131            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
132        }
133    }
134}
135
136#[derive(
137    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
138)]
139#[serde(rename_all = "snake_case")]
140#[strum(serialize_all = "snake_case")]
141pub enum LanguageModelProvider {
142    Anthropic,
143    OpenAi,
144    Google,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct PredictEditsBody {
149    #[serde(skip_serializing_if = "Option::is_none", default)]
150    pub outline: Option<String>,
151    pub input_events: String,
152    pub input_excerpt: String,
153    #[serde(skip_serializing_if = "Option::is_none", default)]
154    pub speculated_output: Option<String>,
155    /// Whether the user provided consent for sampling this interaction.
156    #[serde(default, alias = "data_collection_permission")]
157    pub can_collect_data: bool,
158    #[serde(skip_serializing_if = "Option::is_none", default)]
159    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
160    /// Info about the git repository state, only present when can_collect_data is true.
161    #[serde(skip_serializing_if = "Option::is_none", default)]
162    pub git_info: Option<PredictEditsGitInfo>,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct PredictEditsGitInfo {
167    /// SHA of git HEAD commit at time of prediction.
168    #[serde(skip_serializing_if = "Option::is_none", default)]
169    pub head_sha: Option<String>,
170    /// URL of the remote called `origin`.
171    #[serde(skip_serializing_if = "Option::is_none", default)]
172    pub remote_origin_url: Option<String>,
173    /// URL of the remote called `upstream`.
174    #[serde(skip_serializing_if = "Option::is_none", default)]
175    pub remote_upstream_url: Option<String>,
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct PredictEditsResponse {
180    pub request_id: Uuid,
181    pub output_excerpt: String,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct AcceptEditPredictionBody {
186    pub request_id: Uuid,
187}
188
189#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
190#[serde(rename_all = "snake_case")]
191pub enum CompletionMode {
192    Normal,
193    Max,
194}
195
196#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
197#[serde(rename_all = "snake_case")]
198pub enum CompletionIntent {
199    UserPrompt,
200    ToolResults,
201    ThreadSummarization,
202    ThreadContextSummarization,
203    CreateFile,
204    EditFile,
205    InlineAssist,
206    TerminalInlineAssist,
207    GenerateGitCommitMessage,
208}
209
210#[derive(Debug, Serialize, Deserialize)]
211pub struct CompletionBody {
212    #[serde(skip_serializing_if = "Option::is_none", default)]
213    pub thread_id: Option<String>,
214    #[serde(skip_serializing_if = "Option::is_none", default)]
215    pub prompt_id: Option<String>,
216    #[serde(skip_serializing_if = "Option::is_none", default)]
217    pub intent: Option<CompletionIntent>,
218    #[serde(skip_serializing_if = "Option::is_none", default)]
219    pub mode: Option<CompletionMode>,
220    pub provider: LanguageModelProvider,
221    pub model: String,
222    pub provider_request: serde_json::Value,
223}
224
225#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
226#[serde(rename_all = "snake_case")]
227pub enum CompletionRequestStatus {
228    Queued {
229        position: usize,
230    },
231    Started,
232    Failed {
233        code: String,
234        message: String,
235        request_id: Uuid,
236        /// Retry duration in seconds.
237        retry_after: Option<f64>,
238    },
239    UsageUpdated {
240        amount: usize,
241        limit: UsageLimit,
242    },
243    ToolUseLimitReached,
244}
245
246#[derive(Serialize, Deserialize)]
247#[serde(rename_all = "snake_case")]
248pub enum CompletionEvent<T> {
249    Status(CompletionRequestStatus),
250    Event(T),
251}
252
253impl<T> CompletionEvent<T> {
254    pub fn into_status(self) -> Option<CompletionRequestStatus> {
255        match self {
256            Self::Status(status) => Some(status),
257            Self::Event(_) => None,
258        }
259    }
260
261    pub fn into_event(self) -> Option<T> {
262        match self {
263            Self::Event(event) => Some(event),
264            Self::Status(_) => None,
265        }
266    }
267}
268
269#[derive(Serialize, Deserialize)]
270pub struct WebSearchBody {
271    pub query: String,
272}
273
274#[derive(Debug, Serialize, Deserialize, Clone)]
275pub struct WebSearchResponse {
276    pub results: Vec<WebSearchResult>,
277}
278
279#[derive(Debug, Serialize, Deserialize, Clone)]
280pub struct WebSearchResult {
281    pub title: String,
282    pub url: String,
283    pub text: String,
284}
285
286#[derive(Serialize, Deserialize)]
287pub struct CountTokensBody {
288    pub provider: LanguageModelProvider,
289    pub model: String,
290    pub provider_request: serde_json::Value,
291}
292
293#[derive(Serialize, Deserialize)]
294pub struct CountTokensResponse {
295    pub tokens: usize,
296}
297
298#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
299pub struct LanguageModelId(pub Arc<str>);
300
301impl std::fmt::Display for LanguageModelId {
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        write!(f, "{}", self.0)
304    }
305}
306
307#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
308pub struct LanguageModel {
309    pub provider: LanguageModelProvider,
310    pub id: LanguageModelId,
311    pub display_name: String,
312    pub max_token_count: usize,
313    pub max_token_count_in_max_mode: Option<usize>,
314    pub max_output_tokens: usize,
315    pub supports_tools: bool,
316    pub supports_images: bool,
317    pub supports_thinking: bool,
318    pub supports_max_mode: bool,
319}
320
321#[derive(Debug, Serialize, Deserialize)]
322pub struct ListModelsResponse {
323    pub models: Vec<LanguageModel>,
324    pub default_model: Option<LanguageModelId>,
325    pub default_fast_model: Option<LanguageModelId>,
326    pub recommended_models: Vec<LanguageModelId>,
327}
328
329#[derive(Debug, Serialize, Deserialize)]
330pub struct GetSubscriptionResponse {
331    pub plan: PlanV1,
332    pub usage: Option<CurrentUsage>,
333}
334
335#[derive(Debug, PartialEq, Serialize, Deserialize)]
336pub struct CurrentUsage {
337    pub model_requests: UsageData,
338    pub edit_predictions: UsageData,
339}
340
341#[derive(Debug, PartialEq, Serialize, Deserialize)]
342pub struct UsageData {
343    pub used: u32,
344    pub limit: UsageLimit,
345}
346
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_v1_deserialize_snake_case() {
        for (input, expected) in [
            ("zed_free", PlanV1::ZedFree),
            ("zed_pro", PlanV1::ZedPro),
            ("zed_pro_trial", PlanV1::ZedProTrial),
        ] {
            let plan: PlanV1 = serde_json::from_value(json!(input)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    #[test]
    fn test_plan_v1_deserialize_aliases() {
        // Alternate spellings accepted through `#[serde(alias = ...)]`.
        for (input, expected) in [
            ("Free", PlanV1::ZedFree),
            ("ZedPro", PlanV1::ZedPro),
            ("ZedProTrial", PlanV1::ZedProTrial),
        ] {
            let plan: PlanV1 = serde_json::from_value(json!(input)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    #[test]
    fn test_plan_v2_deserialize_snake_case() {
        for (input, expected) in [
            ("zed_free", PlanV2::ZedFree),
            ("zed_pro", PlanV2::ZedPro),
            ("zed_pro_trial", PlanV2::ZedProTrial),
        ] {
            let plan: PlanV2 = serde_json::from_value(json!(input)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        assert!(matches!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        ));

        for amount in [0, 50] {
            let limit = UsageLimit::from_str(&amount.to_string()).unwrap();
            assert!(matches!(limit, UsageLimit::Limited(n) if n == amount));
        }

        for value in ["not_a_number", "50xyz"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }
}