cloud_llm_client.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3
  4use anyhow::Context as _;
  5use serde::{Deserialize, Serialize};
  6use strum::{Display, EnumIter, EnumString};
  7use uuid::Uuid;
  8
  9/// The name of the header used to indicate which version of Zed the client is running.
 10pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
 11
 12/// The name of the header used to indicate when a request failed due to an
 13/// expired LLM token.
 14///
 15/// The client may use this as a signal to refresh the token.
 16pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 17
 18/// The name of the header used to indicate what plan the user is currently on.
 19pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
 20
 21/// The name of the header used to indicate the usage limit for model requests.
 22pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
 23
 24/// The name of the header used to indicate the usage amount for model requests.
 25pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
 26
 27/// The name of the header used to indicate the usage limit for edit predictions.
 28pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
 29
 30/// The name of the header used to indicate the usage amount for edit predictions.
 31pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
 32
 33/// The name of the header used to indicate the resource for which the subscription limit has been reached.
 34pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
 35
 36pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
 37pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
 38
 39/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
 40pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
 41
 42/// The name of the header used to indicate the the minimum required Zed version.
 43///
 44/// This can be used to force a Zed upgrade in order to continue communicating
 45/// with the LLM service.
 46pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
 47
 48/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
 49pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 50    "x-zed-client-supports-status-messages";
 51
 52/// The name of the header used by the server to indicate to the client that it supports sending status messages.
 53pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 54    "x-zed-server-supports-status-messages";
 55
 56#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 57#[serde(rename_all = "snake_case")]
 58pub enum UsageLimit {
 59    Limited(i32),
 60    Unlimited,
 61}
 62
 63impl FromStr for UsageLimit {
 64    type Err = anyhow::Error;
 65
 66    fn from_str(value: &str) -> Result<Self, Self::Err> {
 67        match value {
 68            "unlimited" => Ok(Self::Unlimited),
 69            limit => limit
 70                .parse::<i32>()
 71                .map(Self::Limited)
 72                .context("failed to parse limit"),
 73        }
 74    }
 75}
 76
 77#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
 78#[serde(rename_all = "snake_case")]
 79pub enum Plan {
 80    #[default]
 81    #[serde(alias = "Free")]
 82    ZedFree,
 83    ZedFreeV2,
 84    #[serde(alias = "ZedPro")]
 85    ZedPro,
 86    ZedProV2,
 87    #[serde(alias = "ZedProTrial")]
 88    ZedProTrial,
 89    ZedProTrialV2,
 90}
 91
 92impl Plan {
 93    pub fn is_v2(&self) -> bool {
 94        matches!(self, Plan::ZedFreeV2 | Plan::ZedProV2 | Plan::ZedProTrialV2)
 95    }
 96}
 97
 98impl FromStr for Plan {
 99    type Err = anyhow::Error;
100
101    fn from_str(value: &str) -> Result<Self, Self::Err> {
102        match value {
103            "zed_free" => Ok(Plan::ZedFree),
104            "zed_free_v2" => Ok(Plan::ZedFreeV2),
105            "zed_pro" => Ok(Plan::ZedPro),
106            "zed_pro_v2" => Ok(Plan::ZedProV2),
107            "zed_pro_trial" => Ok(Plan::ZedProTrial),
108            "zed_pro_trial_v2" => Ok(Plan::ZedProTrialV2),
109            plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
110        }
111    }
112}
113
114#[derive(
115    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
116)]
117#[serde(rename_all = "snake_case")]
118#[strum(serialize_all = "snake_case")]
119pub enum LanguageModelProvider {
120    Anthropic,
121    OpenAi,
122    Google,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct PredictEditsBody {
127    #[serde(skip_serializing_if = "Option::is_none", default)]
128    pub outline: Option<String>,
129    pub input_events: String,
130    pub input_excerpt: String,
131    #[serde(skip_serializing_if = "Option::is_none", default)]
132    pub speculated_output: Option<String>,
133    /// Whether the user provided consent for sampling this interaction.
134    #[serde(default, alias = "data_collection_permission")]
135    pub can_collect_data: bool,
136    #[serde(skip_serializing_if = "Option::is_none", default)]
137    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
138    /// Info about the git repository state, only present when can_collect_data is true.
139    #[serde(skip_serializing_if = "Option::is_none", default)]
140    pub git_info: Option<PredictEditsGitInfo>,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct PredictEditsGitInfo {
145    /// SHA of git HEAD commit at time of prediction.
146    #[serde(skip_serializing_if = "Option::is_none", default)]
147    pub head_sha: Option<String>,
148    /// URL of the remote called `origin`.
149    #[serde(skip_serializing_if = "Option::is_none", default)]
150    pub remote_origin_url: Option<String>,
151    /// URL of the remote called `upstream`.
152    #[serde(skip_serializing_if = "Option::is_none", default)]
153    pub remote_upstream_url: Option<String>,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct PredictEditsResponse {
158    pub request_id: Uuid,
159    pub output_excerpt: String,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct AcceptEditPredictionBody {
164    pub request_id: Uuid,
165}
166
167#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
168#[serde(rename_all = "snake_case")]
169pub enum CompletionMode {
170    Normal,
171    Max,
172}
173
174#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
175#[serde(rename_all = "snake_case")]
176pub enum CompletionIntent {
177    UserPrompt,
178    ToolResults,
179    ThreadSummarization,
180    ThreadContextSummarization,
181    CreateFile,
182    EditFile,
183    InlineAssist,
184    TerminalInlineAssist,
185    GenerateGitCommitMessage,
186}
187
188#[derive(Debug, Serialize, Deserialize)]
189pub struct CompletionBody {
190    #[serde(skip_serializing_if = "Option::is_none", default)]
191    pub thread_id: Option<String>,
192    #[serde(skip_serializing_if = "Option::is_none", default)]
193    pub prompt_id: Option<String>,
194    #[serde(skip_serializing_if = "Option::is_none", default)]
195    pub intent: Option<CompletionIntent>,
196    #[serde(skip_serializing_if = "Option::is_none", default)]
197    pub mode: Option<CompletionMode>,
198    pub provider: LanguageModelProvider,
199    pub model: String,
200    pub provider_request: serde_json::Value,
201}
202
203#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
204#[serde(rename_all = "snake_case")]
205pub enum CompletionRequestStatus {
206    Queued {
207        position: usize,
208    },
209    Started,
210    Failed {
211        code: String,
212        message: String,
213        request_id: Uuid,
214        /// Retry duration in seconds.
215        retry_after: Option<f64>,
216    },
217    UsageUpdated {
218        amount: usize,
219        limit: UsageLimit,
220    },
221    ToolUseLimitReached,
222}
223
224#[derive(Serialize, Deserialize)]
225#[serde(rename_all = "snake_case")]
226pub enum CompletionEvent<T> {
227    Status(CompletionRequestStatus),
228    Event(T),
229}
230
231impl<T> CompletionEvent<T> {
232    pub fn into_status(self) -> Option<CompletionRequestStatus> {
233        match self {
234            Self::Status(status) => Some(status),
235            Self::Event(_) => None,
236        }
237    }
238
239    pub fn into_event(self) -> Option<T> {
240        match self {
241            Self::Event(event) => Some(event),
242            Self::Status(_) => None,
243        }
244    }
245}
246
247#[derive(Serialize, Deserialize)]
248pub struct WebSearchBody {
249    pub query: String,
250}
251
252#[derive(Debug, Serialize, Deserialize, Clone)]
253pub struct WebSearchResponse {
254    pub results: Vec<WebSearchResult>,
255}
256
257#[derive(Debug, Serialize, Deserialize, Clone)]
258pub struct WebSearchResult {
259    pub title: String,
260    pub url: String,
261    pub text: String,
262}
263
264#[derive(Serialize, Deserialize)]
265pub struct CountTokensBody {
266    pub provider: LanguageModelProvider,
267    pub model: String,
268    pub provider_request: serde_json::Value,
269}
270
271#[derive(Serialize, Deserialize)]
272pub struct CountTokensResponse {
273    pub tokens: usize,
274}
275
276#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
277pub struct LanguageModelId(pub Arc<str>);
278
279impl std::fmt::Display for LanguageModelId {
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        write!(f, "{}", self.0)
282    }
283}
284
285#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
286pub struct LanguageModel {
287    pub provider: LanguageModelProvider,
288    pub id: LanguageModelId,
289    pub display_name: String,
290    pub max_token_count: usize,
291    pub max_token_count_in_max_mode: Option<usize>,
292    pub max_output_tokens: usize,
293    pub supports_tools: bool,
294    pub supports_images: bool,
295    pub supports_thinking: bool,
296    pub supports_max_mode: bool,
297}
298
299#[derive(Debug, Serialize, Deserialize)]
300pub struct ListModelsResponse {
301    pub models: Vec<LanguageModel>,
302    pub default_model: LanguageModelId,
303    pub default_fast_model: LanguageModelId,
304    pub recommended_models: Vec<LanguageModelId>,
305}
306
307#[derive(Debug, Serialize, Deserialize)]
308pub struct GetSubscriptionResponse {
309    pub plan: Plan,
310    pub usage: Option<CurrentUsage>,
311}
312
313#[derive(Debug, PartialEq, Serialize, Deserialize)]
314pub struct CurrentUsage {
315    pub model_requests: UsageData,
316    pub edit_predictions: UsageData,
317}
318
319#[derive(Debug, PartialEq, Serialize, Deserialize)]
320pub struct UsageData {
321    pub used: u32,
322    pub limit: UsageLimit,
323}
324
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Every plan variant, used to keep the tests below exhaustive.
    const ALL_PLANS: [Plan; 6] = [
        Plan::ZedFree,
        Plan::ZedFreeV2,
        Plan::ZedPro,
        Plan::ZedProV2,
        Plan::ZedProTrial,
        Plan::ZedProTrialV2,
    ];

    /// Each plan deserializes from its snake_case name, including `zed_free_v2`.
    #[test]
    fn test_plan_deserialize_snake_case() {
        let cases = [
            ("zed_free", Plan::ZedFree),
            ("zed_free_v2", Plan::ZedFreeV2),
            ("zed_pro", Plan::ZedPro),
            ("zed_pro_v2", Plan::ZedProV2),
            ("zed_pro_trial", Plan::ZedProTrial),
            ("zed_pro_trial_v2", Plan::ZedProTrialV2),
        ];

        for (name, expected) in cases {
            let plan = serde_json::from_value::<Plan>(json!(name)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// The PascalCase aliases are still accepted when deserializing.
    #[test]
    fn test_plan_deserialize_aliases() {
        let cases = [
            ("Free", Plan::ZedFree),
            ("ZedPro", Plan::ZedPro),
            ("ZedProTrial", Plan::ZedProTrial),
        ];

        for (alias, expected) in cases {
            let plan = serde_json::from_value::<Plan>(json!(alias)).unwrap();
            assert_eq!(plan, expected);
        }
    }

    /// `Plan::from_str` accepts exactly the names serde serializes, so the
    /// hand-written parser cannot drift from the derived wire format.
    #[test]
    fn test_plan_from_str_matches_serialized_name() {
        for plan in ALL_PLANS {
            let serialized = serde_json::to_value(plan).unwrap();
            let name = serialized.as_str().unwrap();
            assert_eq!(Plan::from_str(name).unwrap(), plan);
        }

        assert!(Plan::from_str("not_a_plan").is_err());
    }

    /// `UsageLimit::from_str` accepts `unlimited` and integers, and rejects the rest.
    #[test]
    fn test_usage_limit_from_str() {
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz"] {
            assert!(UsageLimit::from_str(value).is_err());
        }
    }
}