cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
  6};
  7use anyhow::{anyhow, Context as _, Result};
  8use client::{Client, UserStore};
  9use collections::BTreeMap;
 10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 11use gpui::{
 12    AnyView, AppContext, AsyncAppContext, FocusHandle, Model, ModelContext, Subscription, Task,
 13};
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use settings::{Settings, SettingsStore};
 17use std::{future, sync::Arc};
 18use strum::IntoEnumIterator;
 19use ui::prelude::*;
 20
 21use crate::LanguageModelProvider;
 22
 23use super::anthropic::count_anthropic_tokens;
 24
/// Stable identifier for the zed.dev language-model provider.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable provider name shown in the UI.
pub const PROVIDER_NAME: &str = "Zed AI";
 27
/// User-configurable settings for the zed.dev ("Zed AI") provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // Custom models declared in settings; these override built-in models
    // with the same id (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
 32
/// Which upstream provider serves a settings-declared model.
/// Serialized in lowercase: "anthropic", "openai", "google".
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
 40
/// A custom model declared in settings, routed through one of the
/// supported upstream providers.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Upstream provider that serves this model.
    provider: AvailableProvider,
    // Provider-side model name/id.
    name: String,
    // Maximum context size, in tokens.
    max_tokens: usize,
    // Alternate model id to use for tool calls; only consumed for
    // Anthropic models (see `provided_models`).
    tool_override: Option<String>,
}
 48
/// Language-model provider backed by the zed.dev cloud service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    // Shared connection state, observed by the configuration UI.
    state: gpui::Model<State>,
    // Background task keeping `state.status` in sync with the client;
    // held only so it is dropped (cancelled) with the provider.
    _maintain_client_status: Task<()>,
}
 54
/// Observable provider state: connection status plus the handles needed
/// to authenticate and to read the user's current plan.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    // Re-notifies observers whenever global settings change.
    _subscription: Subscription,
}
 61
 62impl State {
 63    fn is_connected(&self) -> bool {
 64        self.status.is_connected()
 65    }
 66
 67    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 68        let client = self.client.clone();
 69        cx.spawn(move |this, mut cx| async move {
 70            client.authenticate_and_connect(true, &cx).await?;
 71            this.update(&mut cx, |_, cx| cx.notify())
 72        })
 73    }
 74}
 75
 76impl CloudLanguageModelProvider {
 77    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
 78        let mut status_rx = client.status();
 79        let status = *status_rx.borrow();
 80
 81        let state = cx.new_model(|cx| State {
 82            client: client.clone(),
 83            user_store,
 84            status,
 85            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 86                cx.notify();
 87            }),
 88        });
 89
 90        let state_ref = state.downgrade();
 91        let maintain_client_status = cx.spawn(|mut cx| async move {
 92            while let Some(status) = status_rx.next().await {
 93                if let Some(this) = state_ref.upgrade() {
 94                    _ = this.update(&mut cx, |this, cx| {
 95                        if this.status != status {
 96                            this.status = status;
 97                            cx.notify();
 98                        }
 99                    });
100                } else {
101                    break;
102                }
103            }
104        });
105
106        Self {
107            client,
108            state,
109            _maintain_client_status: maintain_client_status,
110        }
111    }
112}
113
114impl LanguageModelProviderState for CloudLanguageModelProvider {
115    type ObservableEntity = State;
116
117    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
118        Some(self.state.clone())
119    }
120}
121
122impl LanguageModelProvider for CloudLanguageModelProvider {
123    fn id(&self) -> LanguageModelProviderId {
124        LanguageModelProviderId(PROVIDER_ID.into())
125    }
126
127    fn name(&self) -> LanguageModelProviderName {
128        LanguageModelProviderName(PROVIDER_NAME.into())
129    }
130
131    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
132        let mut models = BTreeMap::default();
133
134        for model in anthropic::Model::iter() {
135            if !matches!(model, anthropic::Model::Custom { .. }) {
136                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
137            }
138        }
139        for model in open_ai::Model::iter() {
140            if !matches!(model, open_ai::Model::Custom { .. }) {
141                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
142            }
143        }
144        for model in google_ai::Model::iter() {
145            if !matches!(model, google_ai::Model::Custom { .. }) {
146                models.insert(model.id().to_string(), CloudModel::Google(model));
147            }
148        }
149        for model in ZedModel::iter() {
150            models.insert(model.id().to_string(), CloudModel::Zed(model));
151        }
152
153        // Override with available models from settings
154        for model in &AllLanguageModelSettings::get_global(cx)
155            .zed_dot_dev
156            .available_models
157        {
158            let model = match model.provider {
159                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
160                    name: model.name.clone(),
161                    max_tokens: model.max_tokens,
162                    tool_override: model.tool_override.clone(),
163                }),
164                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
165                    name: model.name.clone(),
166                    max_tokens: model.max_tokens,
167                }),
168                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
169                    name: model.name.clone(),
170                    max_tokens: model.max_tokens,
171                }),
172            };
173            models.insert(model.id().to_string(), model.clone());
174        }
175
176        models
177            .into_values()
178            .map(|model| {
179                Arc::new(CloudLanguageModel {
180                    id: LanguageModelId::from(model.id().to_string()),
181                    model,
182                    client: self.client.clone(),
183                    request_limiter: RateLimiter::new(4),
184                }) as Arc<dyn LanguageModel>
185            })
186            .collect()
187    }
188
189    fn is_authenticated(&self, cx: &AppContext) -> bool {
190        self.state.read(cx).status.is_connected()
191    }
192
193    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
194        Task::ready(Ok(()))
195    }
196
197    fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
198        let view = cx
199            .new_view(|_cx| ConfigurationView {
200                state: self.state.clone(),
201            })
202            .into();
203        (view, None)
204    }
205
206    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
207        Task::ready(Ok(()))
208    }
209}
210
/// A single hosted model, addressed through the zed.dev server.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    client: Arc<Client>,
    // Caps concurrent requests to this model (configured in `provided_models`).
    request_limiter: RateLimiter,
}
217
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Telemetry identifier of the form "zed.dev/<model-id>".
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Counts the tokens in `request` for this model.
    ///
    /// Anthropic and OpenAI counts are computed locally; Google counts are
    /// obtained via a `CountLanguageModelTokens` RPC to the server.
    /// Zed-hosted models reuse the OpenAI GPT-3.5 Turbo tokenizer —
    /// presumably an approximation; TODO confirm it matches the hosted
    /// model's actual tokenizer.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    // Serialize the Google-format request and let the server
                    // perform the token count.
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    /// Streams a completion through the zed.dev server.
    ///
    /// Every arm follows the same shape: convert the generic request into
    /// the provider-specific payload, serialize it to JSON, open a
    /// `StreamCompleteWithLanguageModel` RPC stream (gated by this model's
    /// rate limiter), then extract plain text from the provider's event
    /// stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let request = request.into_anthropic(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Anthropic as i32,
                            request,
                        })
                        .await?;
                    // Each stream item carries a JSON-encoded provider event.
                    Ok(anthropic::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::OpenAi as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(google_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                // Zed-hosted models speak the OpenAI wire format; the output
                // is capped at 4000 tokens here.
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Zed as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    /// Forces the model to invoke the named tool and returns the tool's
    /// input as JSON.
    ///
    /// Only implemented for Anthropic models: the request pins
    /// `tool_choice` to `tool_name`, and the response content is scanned
    /// for the matching `ToolUse` block. All other providers return an
    /// error.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                // Note: tool calls may target a different model id than
                // regular completions (`tool_model_id`).
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request(proto::CompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?;
                        let response: anthropic::Response =
                            serde_json::from_str(&response.completion)?;
                        response
                            .content
                            .into_iter()
                            .find_map(|content| {
                                if let anthropic::Content::ToolUse { name, input, .. } = content {
                                    if name == tool_name {
                                        Some(input)
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            })
                            .context("tool not used")
                    })
                    .boxed()
            }
            CloudModel::OpenAi(_) => {
                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
            }
        }
    }
}
415
/// Configuration-panel UI for the Zed AI provider: shows plan info when
/// connected, or a sign-in prompt otherwise.
struct ConfigurationView {
    state: gpui::Model<State>,
}
419
420impl ConfigurationView {
421    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
422        self.state.update(cx, |state, cx| {
423            state.authenticate(cx).detach_and_log_err(cx);
424        });
425        cx.notify();
426    }
427}
428
429impl Render for ConfigurationView {
430    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
431        const ZED_AI_URL: &str = "https://zed.dev/ai";
432        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
433
434        let is_connected = self.state.read(cx).is_connected();
435        let plan = self.state.read(cx).user_store.read(cx).current_plan();
436
437        let is_pro = plan == Some(proto::Plan::ZedPro);
438
439        if is_connected {
440            v_flex()
441                .gap_3()
442                .max_w_4_5()
443                .child(Label::new(
444                    if is_pro {
445                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
446                    } else {
447                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
448                    }))
449                .child(
450                    if is_pro {
451                        h_flex().child(
452                        Button::new("manage_settings", "Manage Subscription")
453                            .style(ButtonStyle::Filled)
454                            .on_click(cx.listener(|_, _, cx| {
455                                cx.open_url(ACCOUNT_SETTINGS_URL)
456                            })))
457                    } else {
458                        h_flex()
459                            .gap_2()
460                            .child(
461                        Button::new("learn_more", "Learn more")
462                            .style(ButtonStyle::Subtle)
463                            .on_click(cx.listener(|_, _, cx| {
464                                cx.open_url(ZED_AI_URL)
465                            })))
466                            .child(
467                        Button::new("upgrade", "Upgrade")
468                            .style(ButtonStyle::Subtle)
469                            .color(Color::Accent)
470                            .on_click(cx.listener(|_, _, cx| {
471                                cx.open_url(ACCOUNT_SETTINGS_URL)
472                            })))
473                    },
474                )
475        } else {
476            v_flex()
477                .gap_6()
478                .child(Label::new("Use the zed.dev to access language models."))
479                .child(
480                    v_flex()
481                        .gap_2()
482                        .child(
483                            Button::new("sign_in", "Sign in")
484                                .icon_color(Color::Muted)
485                                .icon(IconName::Github)
486                                .icon_position(IconPosition::Start)
487                                .style(ButtonStyle::Filled)
488                                .full_width()
489                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
490                        )
491                        .child(
492                            div().flex().w_full().items_center().child(
493                                Label::new("Sign in to enable collaboration.")
494                                    .color(Color::Muted)
495                                    .size(LabelSize::Small),
496                            ),
497                        ),
498                )
499        }
500    }
501}