use super::anthropic::count_anthropic_tokens;
use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelAvailability,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter,
    ZedModel,
};
use anyhow::{anyhow, bail, Result};
use client::{proto, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LanguageModels};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::BufReader,
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

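/// Settings for the zed.dev provider, including any extra models the user has
/// made available through their settings file.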
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

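/// The upstream providers whose models can be added via settings.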
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

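/// A model configured in settings, layered on top of the built-in model list.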
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    provider: AvailableProvider,
    name: String,
    max_tokens: usize,
    tool_override: Option<String>,
}

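/// The zed.dev language model provider, which proxies completions for
/// Anthropic, OpenAI, Google, and Zed-hosted models through the zed.dev
/// LLM service.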
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

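/// Shared provider state, primarily tracking the sign-in status of the
/// zed.dev client connection.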
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

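        // Users without the `LanguageModels` feature flag only see Claude 3.5
        // Sonnet; flagged users get the full model list plus settings overrides.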
        let is_user = !cx.has_flag::<LanguageModels>();
        if is_user {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        } else {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }

            // Override with available models from settings
            for model in &AllLanguageModelSettings::get_global(cx)
                .zed_dot_dev
                .available_models
            {
                let model = match model.provider {
                    AvailableProvider::Anthropic => {
                        CloudModel::Anthropic(anthropic::Model::Custom {
                            name: model.name.clone(),
                            max_tokens: model.max_tokens,
                            tool_override: model.tool_override.clone(),
                        })
                    }
                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                        name: model.name.clone(),
                        max_tokens: model.max_tokens,
                    }),
                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                        name: model.name.clone(),
                        max_tokens: model.max_tokens,
                    }),
                };
                models.insert(model.id().to_string(), model.clone());
            }
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

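/// A cached LLM API token, shared across all models from this provider and
/// refreshed when the server reports it has expired.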
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
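    /// Sends a completion request to the zed.dev LLM endpoint, refreshing the
    /// LLM token and retrying once if the server reports it has expired.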
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

        let response = loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                return Err(anyhow!(
                    "cloud language model completion failed with status {}",
                    response.status()
                ));
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
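        // Each arm below follows the same shape: serialize the provider-specific
        // request, send it via `perform_llm_completion`, then decode the
        // newline-delimited response body into provider events and extract text.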
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(anthropic::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: open_ai::ResponseStreamEvent =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(open_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: google_ai::GenerateContentResponse =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(google_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: open_ai::ResponseStreamEvent =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(open_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                let llm_api_token = self.llm_api_token.clone();
                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

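                        // Accumulate the streamed JSON input for the requested
                        // tool, keyed by the content block index that named it.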
                        let mut tool_use_index = None;
                        let mut tool_input = String::new();
                        let mut body = BufReader::new(response.into_body());
                        let mut line = String::new();
                        while body.read_line(&mut line).await? > 0 {
                            let event: anthropic::Event = serde_json::from_str(&line)?;
                            line.clear();

                            match event {
                                anthropic::Event::ContentBlockStart {
                                    content_block,
                                    index,
                                } => {
                                    if let anthropic::Content::ToolUse { name, .. } = content_block
                                    {
                                        if name == tool_name {
                                            tool_use_index = Some(index);
                                        }
                                    }
                                }
                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
                                {
                                    anthropic::ContentDelta::TextDelta { .. } => {}
                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
                                        if Some(index) == tool_use_index {
                                            tool_input.push_str(&partial_json);
                                        }
                                    }
                                },
                                anthropic::Event::ContentBlockStop { index } => {
                                    if Some(index) == tool_use_index {
                                        return Ok(serde_json::from_str(&tool_input)?);
                                    }
                                }
                                _ => {}
                            }
                        }

                        if tool_use_index.is_some() {
                            Err(anyhow!("tool content incomplete"))
                        } else {
                            Err(anyhow!("tool not used"))
                        }
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request = request.into_open_ai(model.id().into());
                let client = self.client.clone();
                let mut function = open_ai::FunctionDefinition {
                    name: tool_name.clone(),
                    description: None,
                    parameters: None,
                };
                let func = open_ai::ToolDefinition::Function {
                    function: function.clone(),
                };
                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
                // Fill in the description and params separately, as they're not needed for the tool_choice field.
                function.description = Some(tool_description);
                function.parameters = Some(input_schema);
                request.tools = vec![open_ai::ToolDefinition::Function { function }];

                let llm_api_token = self.llm_api_token.clone();
                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        let mut body = BufReader::new(response.into_body());
                        let mut line = String::new();
                        let mut load_state = None;

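                        // Accumulate the arguments of the matching tool call as
                        // they arrive across streamed deltas.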
                        while body.read_line(&mut line).await? > 0 {
                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
                            line.clear();

                            for choice in part.choices {
                                let Some(tool_calls) = choice.delta.tool_calls else {
                                    continue;
                                };

                                for call in tool_calls {
                                    if let Some(func) = call.function {
                                        if func.name.as_deref() == Some(tool_name.as_str()) {
                                            load_state = Some((String::default(), call.index));
                                        }
                                        if let Some((arguments, (output, index))) =
                                            func.arguments.zip(load_state.as_mut())
                                        {
                                            if call.index == *index {
                                                output.push_str(&arguments);
                                            }
                                        }
                                    }
                                }
                            }
                        }

                        if let Some((arguments, _)) = load_state {
                            return Ok(serde_json::from_str(&arguments)?);
                        } else {
                            bail!("tool not used");
                        }
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into());
                let client = self.client.clone();
                let mut function = open_ai::FunctionDefinition {
                    name: tool_name.clone(),
                    description: None,
                    parameters: None,
                };
                let func = open_ai::ToolDefinition::Function {
                    function: function.clone(),
                };
                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
                // Fill in the description and params separately, as they're not needed for the tool_choice field.
                function.description = Some(tool_description);
                function.parameters = Some(input_schema);
                request.tools = vec![open_ai::ToolDefinition::Function { function }];

                let llm_api_token = self.llm_api_token.clone();
                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        let mut body = BufReader::new(response.into_body());
                        let mut line = String::new();
                        let mut load_state = None;

                        // Accumulate the arguments of the matching tool call as
                        // they arrive across streamed deltas.
                        while body.read_line(&mut line).await? > 0 {
                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
                            line.clear();

                            for choice in part.choices {
                                let Some(tool_calls) = choice.delta.tool_calls else {
                                    continue;
                                };

                                for call in tool_calls {
                                    if let Some(func) = call.function {
                                        if func.name.as_deref() == Some(tool_name.as_str()) {
                                            load_state = Some((String::default(), call.index));
                                        }
                                        if let Some((arguments, (output, index))) =
                                            func.arguments.zip(load_state.as_mut())
                                        {
                                            if call.index == *index {
                                                output.push_str(&arguments);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        if let Some((arguments, _)) = load_state {
                            return Ok(serde_json::from_str(&arguments)?);
                        } else {
                            bail!("tool not used");
                        }
                    })
                    .boxed()
            }
        }
    }
}

impl LlmApiToken {
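    /// Returns the cached token, fetching a fresh one from the server if none
    /// has been cached yet.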
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

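    /// Fetches a fresh token unconditionally, replacing any cached value.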
    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

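    /// Fetches a token from the server while holding the write lock, so that
    /// concurrent callers observe the new value.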
    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

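/// The configuration view for the zed.dev provider, prompting the user to sign
/// in or to manage their plan.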
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();

        let is_pro = plan == Some(proto::Plan::ZedPro);

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .child(Label::new(
                    if is_pro {
                        "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
                    } else {
                        "You have basic access to models from Anthropic, OpenAI, Google, and more through the Zed AI Free plan."
                    }))
                .child(
                    if is_pro {
                        h_flex().child(
                            Button::new("manage_settings", "Manage Subscription")
                                .style(ButtonStyle::Filled)
                                .on_click(cx.listener(|_, _, cx| {
                                    cx.open_url(ACCOUNT_SETTINGS_URL)
                                })))
                    } else {
                        h_flex()
                            .gap_2()
                            .child(
                                Button::new("learn_more", "Learn more")
                                    .style(ButtonStyle::Subtle)
                                    .on_click(cx.listener(|_, _, cx| {
                                        cx.open_url(ZED_AI_URL)
                                    })))
                            .child(
                                Button::new("upgrade", "Upgrade")
                                    .style(ButtonStyle::Subtle)
                                    .color(Color::Accent)
                                    .on_click(cx.listener(|_, _, cx| {
                                        cx.open_url(ACCOUNT_SETTINGS_URL)
                                    })))
                    },
                )
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}