// cloud.rs — language-model provider backed by the zed.dev cloud service.

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest,
  6};
  7use anyhow::Result;
  8use client::Client;
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
 10use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 11use settings::{Settings, SettingsStore};
 12use std::{collections::BTreeMap, sync::Arc};
 13use strum::IntoEnumIterator;
 14use ui::prelude::*;
 15
 16use crate::LanguageModelProvider;
 17
 18use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
 19
/// Stable identifier for this provider in the language-model registry.
pub const PROVIDER_ID: &str = "zed.dev";
/// Display name for this provider; identical to the id for zed.dev.
pub const PROVIDER_NAME: &str = "zed.dev";
 22
/// User-facing settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // Models configured in settings; in `provided_models` these override
    // built-in `CloudModel` variants that share the same id.
    pub available_models: Vec<CloudModel>,
}
 27
/// Provider that serves language models through the zed.dev client connection.
pub struct CloudLanguageModelProvider {
    // RPC client shared with every `CloudLanguageModel` this provider creates.
    client: Arc<Client>,
    // Connection status + settings observation, observable by subscribers.
    state: gpui::Model<State>,
    // Background task mirroring client status changes into `state`;
    // held only so it is dropped (cancelled) with the provider.
    _maintain_client_status: Task<()>,
}
 33
/// Observable provider state: the client handle and its last-known status.
struct State {
    client: Arc<Client>,
    // Latest connection status; updated by `_maintain_client_status` in
    // `CloudLanguageModelProvider::new`.
    status: client::Status,
    // Keeps the global-settings observer alive; notifies on settings changes.
    _subscription: Subscription,
}
 39
 40impl State {
 41    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 42        let client = self.client.clone();
 43        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 44    }
 45}
 46
 47impl CloudLanguageModelProvider {
 48    pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
 49        let mut status_rx = client.status();
 50        let status = *status_rx.borrow();
 51
 52        let state = cx.new_model(|cx| State {
 53            client: client.clone(),
 54            status,
 55            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 56                cx.notify();
 57            }),
 58        });
 59
 60        let state_ref = state.downgrade();
 61        let maintain_client_status = cx.spawn(|mut cx| async move {
 62            while let Some(status) = status_rx.next().await {
 63                if let Some(this) = state_ref.upgrade() {
 64                    _ = this.update(&mut cx, |this, cx| {
 65                        this.status = status;
 66                        cx.notify();
 67                    });
 68                } else {
 69                    break;
 70                }
 71            }
 72        });
 73
 74        Self {
 75            client,
 76            state,
 77            _maintain_client_status: maintain_client_status,
 78        }
 79    }
 80}
 81
impl LanguageModelProviderState for CloudLanguageModelProvider {
    /// Forward notifications from the provider's `State` model (connection
    /// status and settings changes) to the subscribing entity.
    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
        Some(cx.observe(&self.state, |_, _, cx| {
            cx.notify();
        }))
    }
}
 89
 90impl LanguageModelProvider for CloudLanguageModelProvider {
 91    fn id(&self) -> LanguageModelProviderId {
 92        LanguageModelProviderId(PROVIDER_ID.into())
 93    }
 94
 95    fn name(&self) -> LanguageModelProviderName {
 96        LanguageModelProviderName(PROVIDER_NAME.into())
 97    }
 98
 99    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
100        let mut models = BTreeMap::default();
101
102        // Add base models from CloudModel::iter()
103        for model in CloudModel::iter() {
104            if !matches!(model, CloudModel::Custom { .. }) {
105                models.insert(model.id().to_string(), model);
106            }
107        }
108
109        // Override with available models from settings
110        for model in &AllLanguageModelSettings::get_global(cx)
111            .zed_dot_dev
112            .available_models
113        {
114            models.insert(model.id().to_string(), model.clone());
115        }
116
117        models
118            .into_values()
119            .map(|model| {
120                Arc::new(CloudLanguageModel {
121                    id: LanguageModelId::from(model.id().to_string()),
122                    model,
123                    client: self.client.clone(),
124                }) as Arc<dyn LanguageModel>
125            })
126            .collect()
127    }
128
129    fn is_authenticated(&self, cx: &AppContext) -> bool {
130        self.state.read(cx).status.is_connected()
131    }
132
133    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
134        self.state.read(cx).authenticate(cx)
135    }
136
137    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
138        cx.new_view(|_cx| AuthenticationPrompt {
139            state: self.state.clone(),
140        })
141        .into()
142    }
143
144    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
145        Task::ready(Ok(()))
146    }
147}
148
/// A single model served through the zed.dev cloud connection.
pub struct CloudLanguageModel {
    // Registry id; `provided_models` derives it from `model.id()`.
    id: LanguageModelId,
    model: CloudModel,
    // Client used for token-counting and completion RPCs.
    client: Arc<Client>,
}
154
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Count the tokens in `request` for this model.
    ///
    /// OpenAI- and Anthropic-family models are counted locally via
    /// `count_open_ai_tokens` / `count_anthropic_tokens`; anything else
    /// round-trips through the server's `CountTokensWithLanguageModel` RPC.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match &self.model {
            CloudModel::Gpt3Point5Turbo => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
            CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
            CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
            CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
            CloudModel::Gpt4OmniMini => {
                count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
            }
            CloudModel::Claude3_5Sonnet
            | CloudModel::Claude3Opus
            | CloudModel::Claude3Sonnet
            | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
            // Custom models routed to Anthropic by naming convention.
            // This guarded arm must stay above the `_` catch-all.
            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
                count_anthropic_tokens(request, cx)
            }
            _ => {
                // Fall back to server-side counting for unknown models.
                let request = self.client.request(proto::CountTokensWithLanguageModel {
                    model: self.model.id().to_string(),
                    messages: request
                        .messages
                        .iter()
                        .map(|message| message.to_proto())
                        .collect(),
                });
                async move {
                    let response = request.await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

    /// Stream a completion for `request` through the zed.dev server.
    ///
    /// Anthropic-family models (including `Custom` models named
    /// "anthropic/…") have their request preprocessed first; all models then
    /// share the same `CompleteWithLanguageModel` streaming RPC.
    fn stream_completion(
        &self,
        mut request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Claude3Opus
            | CloudModel::Claude3Sonnet
            | CloudModel::Claude3Haiku
            | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
                preprocess_anthropic_request(&mut request)
            }
            _ => {}
        }

        let request = proto::CompleteWithLanguageModel {
            // NOTE(review): this sends `self.id` while `count_tokens` sends
            // `self.model.id()`. `provided_models` derives the former from the
            // latter, so they normally match — confirm before unifying.
            model: self.id.0.to_string(),
            messages: request
                .messages
                .iter()
                .map(|message| message.to_proto())
                .collect(),
            stop: request.stop,
            temperature: request.temperature,
            // Tool use is not wired up for cloud completions here.
            tools: Vec::new(),
            tool_choice: None,
        };

        self.client
            .request_stream(request)
            .map_ok(|stream| {
                stream
                    .filter_map(|response| async move {
                        // Surface only the delta text of the last choice;
                        // responses without content are skipped.
                        match response {
                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
                            Err(error) => Some(Err(error)),
                        }
                    })
                    .boxed()
            })
            .boxed()
    }
}
264
/// View shown when the user is not signed in; offers a sign-in button that
/// triggers `State::authenticate`.
struct AuthenticationPrompt {
    state: gpui::Model<State>,
}
268
269impl Render for AuthenticationPrompt {
270    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
271        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
272
273        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
274            v_flex()
275                .gap_2()
276                .child(
277                    Button::new("sign_in", "Sign in")
278                        .icon_color(Color::Muted)
279                        .icon(IconName::Github)
280                        .icon_position(IconPosition::Start)
281                        .style(ButtonStyle::Filled)
282                        .full_width()
283                        .on_click(cx.listener(move |this, _, cx| {
284                            this.state.update(cx, |provider, cx| {
285                                provider.authenticate(cx).detach_and_log_err(cx);
286                                cx.notify();
287                            });
288                        })),
289                )
290                .child(
291                    div().flex().w_full().items_center().child(
292                        Label::new("Sign in to enable collaboration.")
293                            .color(Color::Muted)
294                            .size(LabelSize::Small),
295                    ),
296                ),
297        )
298    }
299}