cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
  6};
  7use anyhow::{anyhow, Context as _, Result};
  8use client::{Client, UserStore};
  9use collections::BTreeMap;
 10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 11use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::{Settings, SettingsStore};
 15use std::{future, sync::Arc};
 16use strum::IntoEnumIterator;
 17use ui::prelude::*;
 18
 19use crate::{LanguageModelAvailability, LanguageModelProvider};
 20
 21use super::anthropic::count_anthropic_tokens;
 22
/// Stable identifier for the zed.dev-hosted language model provider.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable provider name shown in the UI.
pub const PROVIDER_NAME: &str = "Zed";
 25
/// User-configurable settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // Additional models (beyond the built-in ones) declared in settings;
    // these override built-in models with the same id (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
 30
/// Upstream service that actually hosts a settings-declared model.
/// Serialized in lowercase (e.g. "anthropic", "openai", "google").
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
 38
/// A custom model entry declared in settings (`available_models`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Which upstream service hosts this model.
    provider: AvailableProvider,
    // Provider-side model name/identifier.
    name: String,
    // Maximum context window, in tokens.
    max_tokens: usize,
    // Optional alternate model id to use for tool calls (Anthropic only —
    // the other providers' Custom variants have no such field).
    tool_override: Option<String>,
}
 46
/// Language model provider backed by the zed.dev cloud service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    // Observable connection/auth state shared with configuration UI.
    state: gpui::Model<State>,
    // Keeps the status-mirroring background task alive for the provider's lifetime.
    _maintain_client_status: Task<()>,
}
 52
/// Observable connection/authentication state for the zed.dev provider.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    // Latest client connection status, kept in sync by the provider's
    // `_maintain_client_status` task.
    status: client::Status,
    // Re-notifies observers whenever the global settings store changes.
    _subscription: Subscription,
}
 59
 60impl State {
 61    fn is_signed_out(&self) -> bool {
 62        self.status.is_signed_out()
 63    }
 64
 65    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 66        let client = self.client.clone();
 67        cx.spawn(move |this, mut cx| async move {
 68            client.authenticate_and_connect(true, &cx).await?;
 69            this.update(&mut cx, |_, cx| cx.notify())
 70        })
 71    }
 72}
 73
 74impl CloudLanguageModelProvider {
 75    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
 76        let mut status_rx = client.status();
 77        let status = *status_rx.borrow();
 78
 79        let state = cx.new_model(|cx| State {
 80            client: client.clone(),
 81            user_store,
 82            status,
 83            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 84                cx.notify();
 85            }),
 86        });
 87
 88        let state_ref = state.downgrade();
 89        let maintain_client_status = cx.spawn(|mut cx| async move {
 90            while let Some(status) = status_rx.next().await {
 91                if let Some(this) = state_ref.upgrade() {
 92                    _ = this.update(&mut cx, |this, cx| {
 93                        if this.status != status {
 94                            this.status = status;
 95                            cx.notify();
 96                        }
 97                    });
 98                } else {
 99                    break;
100                }
101            }
102        });
103
104        Self {
105            client,
106            state,
107            _maintain_client_status: maintain_client_status,
108        }
109    }
110}
111
112impl LanguageModelProviderState for CloudLanguageModelProvider {
113    type ObservableEntity = State;
114
115    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
116        Some(self.state.clone())
117    }
118}
119
120impl LanguageModelProvider for CloudLanguageModelProvider {
121    fn id(&self) -> LanguageModelProviderId {
122        LanguageModelProviderId(PROVIDER_ID.into())
123    }
124
125    fn name(&self) -> LanguageModelProviderName {
126        LanguageModelProviderName(PROVIDER_NAME.into())
127    }
128
129    fn icon(&self) -> IconName {
130        IconName::AiZed
131    }
132
133    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
134        let mut models = BTreeMap::default();
135
136        for model in anthropic::Model::iter() {
137            if !matches!(model, anthropic::Model::Custom { .. }) {
138                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
139            }
140        }
141        for model in open_ai::Model::iter() {
142            if !matches!(model, open_ai::Model::Custom { .. }) {
143                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
144            }
145        }
146        for model in google_ai::Model::iter() {
147            if !matches!(model, google_ai::Model::Custom { .. }) {
148                models.insert(model.id().to_string(), CloudModel::Google(model));
149            }
150        }
151        for model in ZedModel::iter() {
152            models.insert(model.id().to_string(), CloudModel::Zed(model));
153        }
154
155        // Override with available models from settings
156        for model in &AllLanguageModelSettings::get_global(cx)
157            .zed_dot_dev
158            .available_models
159        {
160            let model = match model.provider {
161                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
162                    name: model.name.clone(),
163                    max_tokens: model.max_tokens,
164                    tool_override: model.tool_override.clone(),
165                }),
166                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
167                    name: model.name.clone(),
168                    max_tokens: model.max_tokens,
169                }),
170                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
171                    name: model.name.clone(),
172                    max_tokens: model.max_tokens,
173                }),
174            };
175            models.insert(model.id().to_string(), model.clone());
176        }
177
178        models
179            .into_values()
180            .map(|model| {
181                Arc::new(CloudLanguageModel {
182                    id: LanguageModelId::from(model.id().to_string()),
183                    model,
184                    client: self.client.clone(),
185                    request_limiter: RateLimiter::new(4),
186                }) as Arc<dyn LanguageModel>
187            })
188            .collect()
189    }
190
191    fn is_authenticated(&self, cx: &AppContext) -> bool {
192        !self.state.read(cx).is_signed_out()
193    }
194
195    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
196        Task::ready(Ok(()))
197    }
198
199    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
200        cx.new_view(|_cx| ConfigurationView {
201            state: self.state.clone(),
202        })
203        .into()
204    }
205
206    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
207        Task::ready(Ok(()))
208    }
209}
210
/// A single model exposed through the zed.dev cloud provider.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    // Which upstream model this wraps (Anthropic/OpenAI/Google/Zed).
    model: CloudModel,
    client: Arc<Client>,
    // Bounds the number of concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
217
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Counts the tokens `request` would consume for the wrapped model.
    ///
    /// Anthropic and OpenAI models are counted locally; Google models are
    /// counted server-side via a `CountLanguageModelTokens` RPC. Zed-hosted
    /// models reuse the OpenAI gpt-3.5-turbo tokenizer, presumably as an
    /// approximation — TODO confirm against the backend's actual tokenizer.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                // Only the contents are needed for counting.
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    /// Streams a completion through the zed.dev backend.
    ///
    /// Every arm follows the same shape: serialize the provider-specific
    /// request to JSON, open a `StreamCompleteWithLanguageModel` RPC stream
    /// (gated by `request_limiter`), and extract plain text from the
    /// provider's event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let request = request.into_anthropic(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Anthropic as i32,
                            request,
                        })
                        .await?;
                    Ok(anthropic::extract_text_from_events(
                        // Each streamed item carries a JSON-encoded event.
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::OpenAi as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(google_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                // Zed-hosted models speak the OpenAI wire format.
                let mut request = request.into_open_ai(model.id().into());
                // Output is capped at 4000 tokens regardless of the request.
                request.max_tokens = Some(4000);
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Zed as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    /// Forces the model to invoke a single named tool and returns the tool's
    /// JSON input from the response.
    ///
    /// Only implemented for Anthropic models: `tool_choice` is pinned to the
    /// given tool, the completion runs non-streaming, and the response
    /// content is scanned for the matching `ToolUse` block. All other
    /// providers return an error.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                // Use the tool-specific model id if one is configured.
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request(proto::CompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?;
                        let response: anthropic::Response =
                            serde_json::from_str(&response.completion)?;
                        // Find the ToolUse block whose name matches the
                        // requested tool; anything else is an error.
                        response
                            .content
                            .into_iter()
                            .find_map(|content| {
                                if let anthropic::Content::ToolUse { name, input, .. } = content {
                                    if name == tool_name {
                                        Some(input)
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            })
                            .context("tool not used")
                    })
                    .boxed()
            }
            CloudModel::OpenAi(_) => {
                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
            }
        }
    }
}
419
/// Settings-panel view for signing in to and managing the zed.dev provider.
struct ConfigurationView {
    state: gpui::Model<State>,
}
423
impl ConfigurationView {
    /// Starts the sign-in flow on the shared provider state. Errors are
    /// logged rather than surfaced, and the view re-renders immediately.
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}
432
433impl Render for ConfigurationView {
434    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
435        const ZED_AI_URL: &str = "https://zed.dev/ai";
436        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
437
438        let is_connected = !self.state.read(cx).is_signed_out();
439        let plan = self.state.read(cx).user_store.read(cx).current_plan();
440
441        let is_pro = plan == Some(proto::Plan::ZedPro);
442
443        if is_connected {
444            v_flex()
445                .gap_3()
446                .max_w_4_5()
447                .child(Label::new(
448                    if is_pro {
449                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
450                    } else {
451                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
452                    }))
453                .child(
454                    if is_pro {
455                        h_flex().child(
456                        Button::new("manage_settings", "Manage Subscription")
457                            .style(ButtonStyle::Filled)
458                            .on_click(cx.listener(|_, _, cx| {
459                                cx.open_url(ACCOUNT_SETTINGS_URL)
460                            })))
461                    } else {
462                        h_flex()
463                            .gap_2()
464                            .child(
465                        Button::new("learn_more", "Learn more")
466                            .style(ButtonStyle::Subtle)
467                            .on_click(cx.listener(|_, _, cx| {
468                                cx.open_url(ZED_AI_URL)
469                            })))
470                            .child(
471                        Button::new("upgrade", "Upgrade")
472                            .style(ButtonStyle::Subtle)
473                            .color(Color::Accent)
474                            .on_click(cx.listener(|_, _, cx| {
475                                cx.open_url(ACCOUNT_SETTINGS_URL)
476                            })))
477                    },
478                )
479        } else {
480            v_flex()
481                .gap_6()
482                .child(Label::new("Use the zed.dev to access language models."))
483                .child(
484                    v_flex()
485                        .gap_2()
486                        .child(
487                            Button::new("sign_in", "Sign in")
488                                .icon_color(Color::Muted)
489                                .icon(IconName::Github)
490                                .icon_position(IconPosition::Start)
491                                .style(ButtonStyle::Filled)
492                                .full_width()
493                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
494                        )
495                        .child(
496                            div().flex().w_full().items_center().child(
497                                Label::new("Sign in to enable collaboration.")
498                                    .color(Color::Muted)
499                                    .size(LabelSize::Small),
500                            ),
501                        ),
502                )
503        }
504    }
505}