1use anthropic::{AnthropicModelMode, parse_prompt_too_long};
2use anyhow::{Result, anyhow};
3use client::{Client, UserStore, zed_urls};
4use collections::BTreeMap;
5use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
6use futures::{
7 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
8};
9use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
10use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
11use language_model::{
12 AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
13 LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
14 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
15 LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
16 ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
17};
18use language_model::{
19 LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
20 MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
21};
22use proto::Plan;
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use settings::{Settings, SettingsStore};
26use smol::Timer;
27use smol::io::{AsyncReadExt, BufReader};
28use std::pin::Pin;
29use std::str::FromStr as _;
30use std::{
31 sync::{Arc, LazyLock},
32 time::Duration,
33};
34use strum::IntoEnumIterator;
35use thiserror::Error;
36use ui::{TintColor, prelude::*};
37use zed_llm_client::{
38 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
39 CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
40 MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
41 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
42 TOOL_USE_LIMIT_REACHED_HEADER_NAME,
43};
44
45use crate::AllLanguageModelSettings;
46use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
47use crate::provider::google::{GoogleEventMapper, into_google};
48use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
49
/// Display name for the Zed-hosted provider, shown in the UI.
pub const PROVIDER_NAME: &str = "Zed";

/// Optional JSON array of extra models baked in at compile time via the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable.
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
54
55fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
56 static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
57 ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
58 .map(|json| serde_json::from_str(json).unwrap())
59 .unwrap_or_default()
60 });
61 ADDITIONAL_MODELS.as_slice()
62}
63
/// User settings for the zed.dev (Zed-hosted) language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Custom models (beyond the built-in ones) exposed through settings.
    pub available_models: Vec<AvailableModel>,
}
68
/// Which upstream provider a custom [`AvailableModel`] is routed through.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
76
/// A custom model definition, deserialized from settings or from the
/// compile-time closed-beta model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
103
/// Operating mode for a custom model; mirrors [`AnthropicModelMode`].
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
114
115impl From<ModelMode> for AnthropicModelMode {
116 fn from(value: ModelMode) -> Self {
117 match value {
118 ModelMode::Default => AnthropicModelMode::Default,
119 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
120 }
121 }
122}
123
/// Language model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    // Background task that mirrors client connection status into `state`.
    _maintain_client_status: Task<()>,
}
129
/// Observable state shared by the provider and its configuration UI.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    // In-flight terms-of-service acceptance, if any; presence disables the UI button.
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
139
impl State {
    /// Builds the provider state and wires up subscriptions: global settings
    /// changes notify observers, and LLM-token refresh events re-fetch the token.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    // Refresh in the background; errors are logged, not surfaced.
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// Whether the client is currently signed out of zed.dev.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Starts sign-in and connection, notifying observers once it completes.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    /// Whether the signed-in user has accepted the terms of service.
    /// Treats "no user available" as not accepted.
    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Accepts the terms of service in the background. The in-flight task is
    /// stored in `accept_terms` so the UI can disable the button meanwhile,
    /// and cleared again when the request finishes.
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}
205
impl CloudLanguageModelProvider {
    /// Creates the provider and spawns a task that keeps `State::status` in
    /// sync with the client's connection status channel.
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        // Hold only a weak reference so the status task doesn't keep the
        // state entity alive; the loop ends once the entity is dropped.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    /// Wraps a [`CloudModel`] in a [`CloudLanguageModel`] handle that shares
    /// this provider's client and LLM API token.
    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}
250
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so observers can react to changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
258
impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    /// Default model: Claude 3.7 Sonnet.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    /// Default fast model: Claude 3.5 Sonnet.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    /// Lists the models offered by this provider: every known model for staff,
    /// a fixed Anthropic subset otherwise, optionally extended by closed-beta
    /// models and overridden by models from settings. Keyed by model id, so
    /// later insertions (settings) win over built-ins.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    // Authentication happens through the Zed client sign-in flow, so there is
    // nothing to do here.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    // Credentials live in the Zed client session; nothing provider-local to reset.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
415
/// Renders the terms-of-service acceptance UI, or `None` once the user has
/// already accepted. Layout varies by which surface (`view_kind`) hosts it.
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    // Disable the accept button while an acceptance request is in flight.
    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    // NOTE: `ThreadtEmptyState` is the upstream variant's spelling.
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    // Compact single-row layout for the empty-thread state; stacked layout
    // with view-specific button alignment everywhere else.
    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}
507
/// A single Zed-hosted model, routed to Anthropic, OpenAI, or Google.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Bounds concurrent in-flight requests for this model handle.
    request_limiter: RateLimiter,
}
515
/// Successful response from the completions endpoint, plus flags decoded
/// from response headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    // Usage parsed from headers; `None` when the server streams status
    // messages in-band instead.
    usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}
522
523impl CloudLanguageModel {
524 const MAX_RETRIES: usize = 3;
525
526 async fn perform_llm_completion(
527 client: Arc<Client>,
528 llm_api_token: LlmApiToken,
529 body: CompletionBody,
530 ) -> Result<PerformLlmCompletionResponse> {
531 let http_client = &client.http_client();
532
533 let mut token = llm_api_token.acquire(&client).await?;
534 let mut retries_remaining = Self::MAX_RETRIES;
535 let mut retry_delay = Duration::from_secs(1);
536
537 loop {
538 let request_builder = http_client::Request::builder().method(Method::POST);
539 let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
540 {
541 request_builder.uri(completions_url)
542 } else {
543 request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
544 };
545 let request = request_builder
546 .header("Content-Type", "application/json")
547 .header("Authorization", format!("Bearer {token}"))
548 .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
549 .body(serde_json::to_string(&body)?.into())?;
550 let mut response = http_client.send(request).await?;
551 let status = response.status();
552 if status.is_success() {
553 let includes_status_messages = response
554 .headers()
555 .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
556 .is_some();
557
558 let tool_use_limit_reached = response
559 .headers()
560 .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
561 .is_some();
562
563 let usage = if includes_status_messages {
564 None
565 } else {
566 RequestUsage::from_headers(response.headers()).ok()
567 };
568
569 return Ok(PerformLlmCompletionResponse {
570 response,
571 usage,
572 includes_status_messages,
573 tool_use_limit_reached,
574 });
575 } else if response
576 .headers()
577 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
578 .is_some()
579 {
580 retries_remaining -= 1;
581 token = llm_api_token.refresh(&client).await?;
582 } else if status == StatusCode::FORBIDDEN
583 && response
584 .headers()
585 .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
586 .is_some()
587 {
588 return Err(anyhow!(MaxMonthlySpendReachedError));
589 } else if status == StatusCode::FORBIDDEN
590 && response
591 .headers()
592 .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
593 .is_some()
594 {
595 if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
596 .headers()
597 .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
598 .and_then(|resource| resource.to_str().ok())
599 {
600 if let Some(plan) = response
601 .headers()
602 .get(CURRENT_PLAN_HEADER_NAME)
603 .and_then(|plan| plan.to_str().ok())
604 .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
605 {
606 let plan = match plan {
607 zed_llm_client::Plan::Free => Plan::Free,
608 zed_llm_client::Plan::ZedPro => Plan::ZedPro,
609 zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
610 };
611 return Err(anyhow!(ModelRequestLimitReachedError { plan }));
612 }
613 }
614
615 return Err(anyhow!("Forbidden"));
616 } else if status.as_u16() >= 500 && status.as_u16() < 600 {
617 // If we encounter an error in the 500 range, retry after a delay.
618 // We've seen at least these in the wild from API providers:
619 // * 500 Internal Server Error
620 // * 502 Bad Gateway
621 // * 529 Service Overloaded
622
623 if retries_remaining == 0 {
624 let mut body = String::new();
625 response.body_mut().read_to_string(&mut body).await?;
626 return Err(anyhow!(
627 "cloud language model completion failed after {} retries with status {status}: {body}",
628 Self::MAX_RETRIES
629 ));
630 }
631
632 Timer::after(retry_delay).await;
633
634 retries_remaining -= 1;
635 retry_delay *= 2; // If it fails again, wait longer.
636 } else if status == StatusCode::PAYMENT_REQUIRED {
637 return Err(anyhow!(PaymentRequiredError));
638 } else {
639 let mut body = String::new();
640 response.body_mut().read_to_string(&mut body).await?;
641 return Err(anyhow!(ApiError { status, body }));
642 }
643 }
644 }
645}
646
/// Non-retryable HTTP error from the completions endpoint, carrying the
/// status code and raw response body for diagnostics.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}
653
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    // All hosted providers support tool use.
    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Prompt-caching configuration; only Anthropic models support it.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    /// Counts request tokens. Anthropic/OpenAI are counted locally; Google
    /// requires a round trip to the service's `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let request = into_google(request, model.id().into());
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_builder = http_client::Request::builder().method(Method::POST);
                    // `ZED_COUNT_TOKENS_URL` lets development builds target a local server.
                    let request_builder =
                        if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
                            request_builder.uri(completions_url)
                        } else {
                            request_builder.uri(
                                http_client
                                    .build_zed_llm_url("/count_tokens", &[])?
                                    .as_ref(),
                            )
                        };
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model.id().into(),
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            contents: request.contents,
                        })?,
                    };
                    let request = request_builder
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion through the Zed LLM service, converting the
    /// request into the upstream provider's wire format and mapping the
    /// newline-delimited response events (plus header-derived usage and
    /// tool-use-limit events) back into `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let mode = request.mode;
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    // Anthropic's "prompt too long" 400 is surfaced as a typed
                    // context-window error so callers can react specifically.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
919
/// One line of the streamed completion response: either a service-level
/// status message or a provider-specific event payload.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
926
927fn map_cloud_completion_events<T, F>(
928 stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
929 mut map_callback: F,
930) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
931where
932 T: DeserializeOwned + 'static,
933 F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
934 + Send
935 + 'static,
936{
937 stream
938 .flat_map(move |event| {
939 futures::stream::iter(match event {
940 Err(error) => {
941 vec![Err(LanguageModelCompletionError::Other(error))]
942 }
943 Ok(CloudCompletionEvent::Status(event)) => {
944 vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
945 }
946 Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
947 })
948 })
949 .boxed()
950}
951
952fn usage_updated_event<T>(
953 usage: Option<RequestUsage>,
954) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
955 futures::stream::iter(usage.map(|usage| {
956 Ok(CloudCompletionEvent::Status(
957 CompletionRequestStatus::UsageUpdated {
958 amount: usage.amount as usize,
959 limit: usage.limit,
960 },
961 ))
962 }))
963}
964
965fn tool_use_limit_reached_event<T>(
966 tool_use_limit_reached: bool,
967) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
968 futures::stream::iter(tool_use_limit_reached.then(|| {
969 Ok(CloudCompletionEvent::Status(
970 CompletionRequestStatus::ToolUseLimitReached,
971 ))
972 }))
973}
974
/// Turns the newline-delimited JSON response body into a stream of parsed
/// events. When `includes_status_messages` is set, each line is a full
/// [`CloudCompletionEvent`]; otherwise each line is a bare provider event
/// and is wrapped in `CloudCompletionEvent::Event`.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    // The unfold state is (line buffer, reader); the buffer is cleared and
    // reused between lines to avoid reallocating per line.
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // 0 bytes read means end of stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
999
/// Settings-panel view for the Zed provider (sign-in, ToS, subscription info).
struct ConfigurationView {
    state: gpui::Entity<State>,
}
1003
impl ConfigurationView {
    /// Starts the sign-in flow (errors are logged) and re-renders this view.
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}
1012
impl Render for ConfigurationView {
    /// Renders sign-in UI when disconnected; otherwise the ToS prompt and,
    /// once terms are accepted, plan details with subscription actions.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        // Pro users manage their subscription; free users behind the Zed Pro
        // flag see upgrade actions; everyone else gets no button row.
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedProFeatureFlag>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}