cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
  6};
  7use anyhow::{anyhow, bail, Result};
  8use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
  9use collections::BTreeMap;
 10use feature_flags::{FeatureFlagAppExt, LanguageModels};
 11use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 12use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
 13use http_client::{AsyncBody, HttpClient, Method, Response};
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use serde_json::value::RawValue;
 17use settings::{Settings, SettingsStore};
 18use smol::{
 19    io::BufReader,
 20    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 21};
 22use std::{future, sync::Arc};
 23use strum::IntoEnumIterator;
 24use ui::prelude::*;
 25
 26use crate::{LanguageModelAvailability, LanguageModelProvider};
 27
 28use super::anthropic::count_anthropic_tokens;
 29
 30pub const PROVIDER_ID: &str = "zed.dev";
 31pub const PROVIDER_NAME: &str = "Zed";
 32
 33#[derive(Default, Clone, Debug, PartialEq)]
 34pub struct ZedDotDevSettings {
 35    pub available_models: Vec<AvailableModel>,
 36}
 37
 38#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 39#[serde(rename_all = "lowercase")]
 40pub enum AvailableProvider {
 41    Anthropic,
 42    OpenAi,
 43    Google,
 44}
 45
 46#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 47pub struct AvailableModel {
 48    provider: AvailableProvider,
 49    name: String,
 50    max_tokens: usize,
 51    tool_override: Option<String>,
 52}
 53
 54pub struct CloudLanguageModelProvider {
 55    client: Arc<Client>,
 56    llm_api_token: LlmApiToken,
 57    state: gpui::Model<State>,
 58    _maintain_client_status: Task<()>,
 59}
 60
 61pub struct State {
 62    client: Arc<Client>,
 63    user_store: Model<UserStore>,
 64    status: client::Status,
 65    _subscription: Subscription,
 66}
 67
 68impl State {
 69    fn is_signed_out(&self) -> bool {
 70        self.status.is_signed_out()
 71    }
 72
 73    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 74        let client = self.client.clone();
 75        cx.spawn(move |this, mut cx| async move {
 76            client.authenticate_and_connect(true, &cx).await?;
 77            this.update(&mut cx, |_, cx| cx.notify())
 78        })
 79    }
 80}
 81
 82impl CloudLanguageModelProvider {
 83    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
 84        let mut status_rx = client.status();
 85        let status = *status_rx.borrow();
 86
 87        let state = cx.new_model(|cx| State {
 88            client: client.clone(),
 89            user_store,
 90            status,
 91            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 92                cx.notify();
 93            }),
 94        });
 95
 96        let state_ref = state.downgrade();
 97        let maintain_client_status = cx.spawn(|mut cx| async move {
 98            while let Some(status) = status_rx.next().await {
 99                if let Some(this) = state_ref.upgrade() {
100                    _ = this.update(&mut cx, |this, cx| {
101                        if this.status != status {
102                            this.status = status;
103                            cx.notify();
104                        }
105                    });
106                } else {
107                    break;
108                }
109            }
110        });
111
112        Self {
113            client,
114            state,
115            llm_api_token: LlmApiToken::default(),
116            _maintain_client_status: maintain_client_status,
117        }
118    }
119}
120
121impl LanguageModelProviderState for CloudLanguageModelProvider {
122    type ObservableEntity = State;
123
124    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
125        Some(self.state.clone())
126    }
127}
128
129impl LanguageModelProvider for CloudLanguageModelProvider {
130    fn id(&self) -> LanguageModelProviderId {
131        LanguageModelProviderId(PROVIDER_ID.into())
132    }
133
134    fn name(&self) -> LanguageModelProviderName {
135        LanguageModelProviderName(PROVIDER_NAME.into())
136    }
137
138    fn icon(&self) -> IconName {
139        IconName::AiZed
140    }
141
142    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
143        let mut models = BTreeMap::default();
144
145        let is_user = !cx.has_flag::<LanguageModels>();
146        if is_user {
147            models.insert(
148                anthropic::Model::Claude3_5Sonnet.id().to_string(),
149                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
150            );
151        } else {
152            for model in anthropic::Model::iter() {
153                if !matches!(model, anthropic::Model::Custom { .. }) {
154                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
155                }
156            }
157            for model in open_ai::Model::iter() {
158                if !matches!(model, open_ai::Model::Custom { .. }) {
159                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
160                }
161            }
162            for model in google_ai::Model::iter() {
163                if !matches!(model, google_ai::Model::Custom { .. }) {
164                    models.insert(model.id().to_string(), CloudModel::Google(model));
165                }
166            }
167            for model in ZedModel::iter() {
168                models.insert(model.id().to_string(), CloudModel::Zed(model));
169            }
170
171            // Override with available models from settings
172            for model in &AllLanguageModelSettings::get_global(cx)
173                .zed_dot_dev
174                .available_models
175            {
176                let model = match model.provider {
177                    AvailableProvider::Anthropic => {
178                        CloudModel::Anthropic(anthropic::Model::Custom {
179                            name: model.name.clone(),
180                            max_tokens: model.max_tokens,
181                            tool_override: model.tool_override.clone(),
182                        })
183                    }
184                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
185                        name: model.name.clone(),
186                        max_tokens: model.max_tokens,
187                    }),
188                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
189                        name: model.name.clone(),
190                        max_tokens: model.max_tokens,
191                    }),
192                };
193                models.insert(model.id().to_string(), model.clone());
194            }
195        }
196
197        models
198            .into_values()
199            .map(|model| {
200                Arc::new(CloudLanguageModel {
201                    id: LanguageModelId::from(model.id().to_string()),
202                    model,
203                    llm_api_token: self.llm_api_token.clone(),
204                    client: self.client.clone(),
205                    request_limiter: RateLimiter::new(4),
206                }) as Arc<dyn LanguageModel>
207            })
208            .collect()
209    }
210
211    fn is_authenticated(&self, cx: &AppContext) -> bool {
212        !self.state.read(cx).is_signed_out()
213    }
214
215    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
216        Task::ready(Ok(()))
217    }
218
219    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
220        cx.new_view(|_cx| ConfigurationView {
221            state: self.state.clone(),
222        })
223        .into()
224    }
225
226    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
227        Task::ready(Ok(()))
228    }
229}
230
231pub struct CloudLanguageModel {
232    id: LanguageModelId,
233    model: CloudModel,
234    llm_api_token: LlmApiToken,
235    client: Arc<Client>,
236    request_limiter: RateLimiter,
237}
238
239#[derive(Clone, Default)]
240struct LlmApiToken(Arc<RwLock<Option<String>>>);
241
242impl CloudLanguageModel {
243    async fn perform_llm_completion(
244        client: Arc<Client>,
245        llm_api_token: LlmApiToken,
246        body: PerformCompletionParams,
247    ) -> Result<Response<AsyncBody>> {
248        let http_client = &client.http_client();
249
250        let mut token = llm_api_token.acquire(&client).await?;
251        let mut did_retry = false;
252
253        let response = loop {
254            let request = http_client::Request::builder()
255                .method(Method::POST)
256                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
257                .header("Content-Type", "application/json")
258                .header("Authorization", format!("Bearer {token}"))
259                .body(serde_json::to_string(&body)?.into())?;
260            let response = http_client.send(request).await?;
261            if response.status().is_success() {
262                break response;
263            } else if !did_retry
264                && response
265                    .headers()
266                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
267                    .is_some()
268            {
269                did_retry = true;
270                token = llm_api_token.refresh(&client).await?;
271            } else {
272                break Err(anyhow!(
273                    "cloud language model completion failed with status {}",
274                    response.status()
275                ))?;
276            }
277        };
278
279        Ok(response)
280    }
281}
282
283impl LanguageModel for CloudLanguageModel {
284    fn id(&self) -> LanguageModelId {
285        self.id.clone()
286    }
287
288    fn name(&self) -> LanguageModelName {
289        LanguageModelName::from(self.model.display_name().to_string())
290    }
291
292    fn provider_id(&self) -> LanguageModelProviderId {
293        LanguageModelProviderId(PROVIDER_ID.into())
294    }
295
296    fn provider_name(&self) -> LanguageModelProviderName {
297        LanguageModelProviderName(PROVIDER_NAME.into())
298    }
299
300    fn telemetry_id(&self) -> String {
301        format!("zed.dev/{}", self.model.id())
302    }
303
304    fn availability(&self) -> LanguageModelAvailability {
305        self.model.availability()
306    }
307
308    fn max_token_count(&self) -> usize {
309        self.model.max_token_count()
310    }
311
312    fn count_tokens(
313        &self,
314        request: LanguageModelRequest,
315        cx: &AppContext,
316    ) -> BoxFuture<'static, Result<usize>> {
317        match self.model.clone() {
318            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
319            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
320            CloudModel::Google(model) => {
321                let client = self.client.clone();
322                let request = request.into_google(model.id().into());
323                let request = google_ai::CountTokensRequest {
324                    contents: request.contents,
325                };
326                async move {
327                    let request = serde_json::to_string(&request)?;
328                    let response = client
329                        .request(proto::CountLanguageModelTokens {
330                            provider: proto::LanguageModelProvider::Google as i32,
331                            request,
332                        })
333                        .await?;
334                    Ok(response.token_count as usize)
335                }
336                .boxed()
337            }
338            CloudModel::Zed(_) => {
339                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
340            }
341        }
342    }
343
344    fn stream_completion(
345        &self,
346        request: LanguageModelRequest,
347        _cx: &AsyncAppContext,
348    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
349        match &self.model {
350            CloudModel::Anthropic(model) => {
351                let request = request.into_anthropic(model.id().into());
352                let client = self.client.clone();
353                let llm_api_token = self.llm_api_token.clone();
354                let future = self.request_limiter.stream(async move {
355                    let response = Self::perform_llm_completion(
356                        client.clone(),
357                        llm_api_token,
358                        PerformCompletionParams {
359                            provider: client::LanguageModelProvider::Anthropic,
360                            model: request.model.clone(),
361                            provider_request: RawValue::from_string(serde_json::to_string(
362                                &request,
363                            )?)?,
364                        },
365                    )
366                    .await?;
367                    let body = BufReader::new(response.into_body());
368                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
369                        let mut buffer = String::new();
370                        match body.read_line(&mut buffer).await {
371                            Ok(0) => Ok(None),
372                            Ok(_) => {
373                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
374                                Ok(Some((event, body)))
375                            }
376                            Err(e) => Err(e.into()),
377                        }
378                    });
379
380                    Ok(anthropic::extract_text_from_events(stream))
381                });
382                async move { Ok(future.await?.boxed()) }.boxed()
383            }
384            CloudModel::OpenAi(model) => {
385                let client = self.client.clone();
386                let request = request.into_open_ai(model.id().into());
387                let llm_api_token = self.llm_api_token.clone();
388                let future = self.request_limiter.stream(async move {
389                    let response = Self::perform_llm_completion(
390                        client.clone(),
391                        llm_api_token,
392                        PerformCompletionParams {
393                            provider: client::LanguageModelProvider::OpenAi,
394                            model: request.model.clone(),
395                            provider_request: RawValue::from_string(serde_json::to_string(
396                                &request,
397                            )?)?,
398                        },
399                    )
400                    .await?;
401                    let body = BufReader::new(response.into_body());
402                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
403                        let mut buffer = String::new();
404                        match body.read_line(&mut buffer).await {
405                            Ok(0) => Ok(None),
406                            Ok(_) => {
407                                let event: open_ai::ResponseStreamEvent =
408                                    serde_json::from_str(&buffer)?;
409                                Ok(Some((event, body)))
410                            }
411                            Err(e) => Err(e.into()),
412                        }
413                    });
414
415                    Ok(open_ai::extract_text_from_events(stream))
416                });
417                async move { Ok(future.await?.boxed()) }.boxed()
418            }
419            CloudModel::Google(model) => {
420                let client = self.client.clone();
421                let request = request.into_google(model.id().into());
422                let llm_api_token = self.llm_api_token.clone();
423                let future = self.request_limiter.stream(async move {
424                    let response = Self::perform_llm_completion(
425                        client.clone(),
426                        llm_api_token,
427                        PerformCompletionParams {
428                            provider: client::LanguageModelProvider::Google,
429                            model: request.model.clone(),
430                            provider_request: RawValue::from_string(serde_json::to_string(
431                                &request,
432                            )?)?,
433                        },
434                    )
435                    .await?;
436                    let body = BufReader::new(response.into_body());
437                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
438                        let mut buffer = String::new();
439                        match body.read_line(&mut buffer).await {
440                            Ok(0) => Ok(None),
441                            Ok(_) => {
442                                let event: google_ai::GenerateContentResponse =
443                                    serde_json::from_str(&buffer)?;
444                                Ok(Some((event, body)))
445                            }
446                            Err(e) => Err(e.into()),
447                        }
448                    });
449
450                    Ok(google_ai::extract_text_from_events(stream))
451                });
452                async move { Ok(future.await?.boxed()) }.boxed()
453            }
454            CloudModel::Zed(model) => {
455                let client = self.client.clone();
456                let mut request = request.into_open_ai(model.id().into());
457                request.max_tokens = Some(4000);
458                let llm_api_token = self.llm_api_token.clone();
459                let future = self.request_limiter.stream(async move {
460                    let response = Self::perform_llm_completion(
461                        client.clone(),
462                        llm_api_token,
463                        PerformCompletionParams {
464                            provider: client::LanguageModelProvider::Zed,
465                            model: request.model.clone(),
466                            provider_request: RawValue::from_string(serde_json::to_string(
467                                &request,
468                            )?)?,
469                        },
470                    )
471                    .await?;
472                    let body = BufReader::new(response.into_body());
473                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
474                        let mut buffer = String::new();
475                        match body.read_line(&mut buffer).await {
476                            Ok(0) => Ok(None),
477                            Ok(_) => {
478                                let event: open_ai::ResponseStreamEvent =
479                                    serde_json::from_str(&buffer)?;
480                                Ok(Some((event, body)))
481                            }
482                            Err(e) => Err(e.into()),
483                        }
484                    });
485
486                    Ok(open_ai::extract_text_from_events(stream))
487                });
488                async move { Ok(future.await?.boxed()) }.boxed()
489            }
490        }
491    }
492
493    fn use_any_tool(
494        &self,
495        request: LanguageModelRequest,
496        tool_name: String,
497        tool_description: String,
498        input_schema: serde_json::Value,
499        _cx: &AsyncAppContext,
500    ) -> BoxFuture<'static, Result<serde_json::Value>> {
501        match &self.model {
502            CloudModel::Anthropic(model) => {
503                let client = self.client.clone();
504                let mut request = request.into_anthropic(model.tool_model_id().into());
505                request.tool_choice = Some(anthropic::ToolChoice::Tool {
506                    name: tool_name.clone(),
507                });
508                request.tools = vec![anthropic::Tool {
509                    name: tool_name.clone(),
510                    description: tool_description,
511                    input_schema,
512                }];
513
514                let llm_api_token = self.llm_api_token.clone();
515                self.request_limiter
516                    .run(async move {
517                        let response = Self::perform_llm_completion(
518                            client.clone(),
519                            llm_api_token,
520                            PerformCompletionParams {
521                                provider: client::LanguageModelProvider::Anthropic,
522                                model: request.model.clone(),
523                                provider_request: RawValue::from_string(serde_json::to_string(
524                                    &request,
525                                )?)?,
526                            },
527                        )
528                        .await?;
529
530                        let mut tool_use_index = None;
531                        let mut tool_input = String::new();
532                        let mut body = BufReader::new(response.into_body());
533                        let mut line = String::new();
534                        while body.read_line(&mut line).await? > 0 {
535                            let event: anthropic::Event = serde_json::from_str(&line)?;
536                            line.clear();
537
538                            match event {
539                                anthropic::Event::ContentBlockStart {
540                                    content_block,
541                                    index,
542                                } => {
543                                    if let anthropic::Content::ToolUse { name, .. } = content_block
544                                    {
545                                        if name == tool_name {
546                                            tool_use_index = Some(index);
547                                        }
548                                    }
549                                }
550                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
551                                {
552                                    anthropic::ContentDelta::TextDelta { .. } => {}
553                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
554                                        if Some(index) == tool_use_index {
555                                            tool_input.push_str(&partial_json);
556                                        }
557                                    }
558                                },
559                                anthropic::Event::ContentBlockStop { index } => {
560                                    if Some(index) == tool_use_index {
561                                        return Ok(serde_json::from_str(&tool_input)?);
562                                    }
563                                }
564                                _ => {}
565                            }
566                        }
567
568                        if tool_use_index.is_some() {
569                            Err(anyhow!("tool content incomplete"))
570                        } else {
571                            Err(anyhow!("tool not used"))
572                        }
573                    })
574                    .boxed()
575            }
576            CloudModel::OpenAi(model) => {
577                let mut request = request.into_open_ai(model.id().into());
578                let client = self.client.clone();
579                let mut function = open_ai::FunctionDefinition {
580                    name: tool_name.clone(),
581                    description: None,
582                    parameters: None,
583                };
584                let func = open_ai::ToolDefinition::Function {
585                    function: function.clone(),
586                };
587                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
588                // Fill in description and params separately, as they're not needed for tool_choice field.
589                function.description = Some(tool_description);
590                function.parameters = Some(input_schema);
591                request.tools = vec![open_ai::ToolDefinition::Function { function }];
592
593                let llm_api_token = self.llm_api_token.clone();
594                self.request_limiter
595                    .run(async move {
596                        let response = Self::perform_llm_completion(
597                            client.clone(),
598                            llm_api_token,
599                            PerformCompletionParams {
600                                provider: client::LanguageModelProvider::OpenAi,
601                                model: request.model.clone(),
602                                provider_request: RawValue::from_string(serde_json::to_string(
603                                    &request,
604                                )?)?,
605                            },
606                        )
607                        .await?;
608
609                        let mut body = BufReader::new(response.into_body());
610                        let mut line = String::new();
611                        let mut load_state = None;
612
613                        while body.read_line(&mut line).await? > 0 {
614                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
615                            line.clear();
616
617                            for choice in part.choices {
618                                let Some(tool_calls) = choice.delta.tool_calls else {
619                                    continue;
620                                };
621
622                                for call in tool_calls {
623                                    if let Some(func) = call.function {
624                                        if func.name.as_deref() == Some(tool_name.as_str()) {
625                                            load_state = Some((String::default(), call.index));
626                                        }
627                                        if let Some((arguments, (output, index))) =
628                                            func.arguments.zip(load_state.as_mut())
629                                        {
630                                            if call.index == *index {
631                                                output.push_str(&arguments);
632                                            }
633                                        }
634                                    }
635                                }
636                            }
637                        }
638
639                        if let Some((arguments, _)) = load_state {
640                            return Ok(serde_json::from_str(&arguments)?);
641                        } else {
642                            bail!("tool not used");
643                        }
644                    })
645                    .boxed()
646            }
647            CloudModel::Google(_) => {
648                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
649            }
650            CloudModel::Zed(model) => {
651                // All Zed models are OpenAI-based at the time of writing.
652                let mut request = request.into_open_ai(model.id().into());
653                let client = self.client.clone();
654                let mut function = open_ai::FunctionDefinition {
655                    name: tool_name.clone(),
656                    description: None,
657                    parameters: None,
658                };
659                let func = open_ai::ToolDefinition::Function {
660                    function: function.clone(),
661                };
662                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
663                // Fill in description and params separately, as they're not needed for tool_choice field.
664                function.description = Some(tool_description);
665                function.parameters = Some(input_schema);
666                request.tools = vec![open_ai::ToolDefinition::Function { function }];
667
668                let llm_api_token = self.llm_api_token.clone();
669                self.request_limiter
670                    .run(async move {
671                        let response = Self::perform_llm_completion(
672                            client.clone(),
673                            llm_api_token,
674                            PerformCompletionParams {
675                                provider: client::LanguageModelProvider::Zed,
676                                model: request.model.clone(),
677                                provider_request: RawValue::from_string(serde_json::to_string(
678                                    &request,
679                                )?)?,
680                            },
681                        )
682                        .await?;
683
684                        let mut body = BufReader::new(response.into_body());
685                        let mut line = String::new();
686                        let mut load_state = None;
687
688                        while body.read_line(&mut line).await? > 0 {
689                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
690                            line.clear();
691
692                            for choice in part.choices {
693                                let Some(tool_calls) = choice.delta.tool_calls else {
694                                    continue;
695                                };
696
697                                for call in tool_calls {
698                                    if let Some(func) = call.function {
699                                        if func.name.as_deref() == Some(tool_name.as_str()) {
700                                            load_state = Some((String::default(), call.index));
701                                        }
702                                        if let Some((arguments, (output, index))) =
703                                            func.arguments.zip(load_state.as_mut())
704                                        {
705                                            if call.index == *index {
706                                                output.push_str(&arguments);
707                                            }
708                                        }
709                                    }
710                                }
711                            }
712                        }
713                        if let Some((arguments, _)) = load_state {
714                            return Ok(serde_json::from_str(&arguments)?);
715                        } else {
716                            bail!("tool not used");
717                        }
718                    })
719                    .boxed()
720            }
721        }
722    }
723}
724
725impl LlmApiToken {
726    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
727        let lock = self.0.upgradable_read().await;
728        if let Some(token) = lock.as_ref() {
729            Ok(token.to_string())
730        } else {
731            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
732        }
733    }
734
735    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
736        Self::fetch(self.0.write().await, &client).await
737    }
738
739    async fn fetch<'a>(
740        mut lock: RwLockWriteGuard<'a, Option<String>>,
741        client: &Arc<Client>,
742    ) -> Result<String> {
743        let response = client.request(proto::GetLlmToken {}).await?;
744        *lock = Some(response.token.clone());
745        Ok(response.token.clone())
746    }
747}
748
749struct ConfigurationView {
750    state: gpui::Model<State>,
751}
752
753impl ConfigurationView {
754    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
755        self.state.update(cx, |state, cx| {
756            state.authenticate(cx).detach_and_log_err(cx);
757        });
758        cx.notify();
759    }
760}
761
762impl Render for ConfigurationView {
763    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
764        const ZED_AI_URL: &str = "https://zed.dev/ai";
765        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
766
767        let is_connected = !self.state.read(cx).is_signed_out();
768        let plan = self.state.read(cx).user_store.read(cx).current_plan();
769
770        let is_pro = plan == Some(proto::Plan::ZedPro);
771
772        if is_connected {
773            v_flex()
774                .gap_3()
775                .max_w_4_5()
776                .child(Label::new(
777                    if is_pro {
778                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
779                    } else {
780                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
781                    }))
782                .child(
783                    if is_pro {
784                        h_flex().child(
785                        Button::new("manage_settings", "Manage Subscription")
786                            .style(ButtonStyle::Filled)
787                            .on_click(cx.listener(|_, _, cx| {
788                                cx.open_url(ACCOUNT_SETTINGS_URL)
789                            })))
790                    } else {
791                        h_flex()
792                            .gap_2()
793                            .child(
794                        Button::new("learn_more", "Learn more")
795                            .style(ButtonStyle::Subtle)
796                            .on_click(cx.listener(|_, _, cx| {
797                                cx.open_url(ZED_AI_URL)
798                            })))
799                            .child(
800                        Button::new("upgrade", "Upgrade")
801                            .style(ButtonStyle::Subtle)
802                            .color(Color::Accent)
803                            .on_click(cx.listener(|_, _, cx| {
804                                cx.open_url(ACCOUNT_SETTINGS_URL)
805                            })))
806                    },
807                )
808        } else {
809            v_flex()
810                .gap_6()
811                .child(Label::new("Use the zed.dev to access language models."))
812                .child(
813                    v_flex()
814                        .gap_2()
815                        .child(
816                            Button::new("sign_in", "Sign in")
817                                .icon_color(Color::Muted)
818                                .icon(IconName::Github)
819                                .icon_position(IconPosition::Start)
820                                .style(ButtonStyle::Filled)
821                                .full_width()
822                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
823                        )
824                        .child(
825                            div().flex().w_full().items_center().child(
826                                Label::new("Sign in to enable collaboration.")
827                                    .color(Color::Muted)
828                                    .size(LabelSize::Small),
829                            ),
830                        ),
831                )
832        }
833    }
834}