cloud.rs

  1use crate::{
  2    assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
  3    LanguageModelCompletionProvider, LanguageModelRequest,
  4};
  5use anyhow::{anyhow, Result};
  6use client::{proto, Client};
  7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
  8use gpui::{AnyView, AppContext, Task};
  9use std::{future, sync::Arc};
 10use strum::IntoEnumIterator;
 11use ui::prelude::*;
 12
 13pub struct CloudCompletionProvider {
 14    client: Arc<Client>,
 15    model: CloudModel,
 16    settings_version: usize,
 17    status: client::Status,
 18    _maintain_client_status: Task<()>,
 19}
 20
 21impl CloudCompletionProvider {
 22    pub fn new(
 23        model: CloudModel,
 24        client: Arc<Client>,
 25        settings_version: usize,
 26        cx: &mut AppContext,
 27    ) -> Self {
 28        let mut status_rx = client.status();
 29        let status = *status_rx.borrow();
 30        let maintain_client_status = cx.spawn(|mut cx| async move {
 31            while let Some(status) = status_rx.next().await {
 32                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 33                    provider.update_current_as::<_, Self>(|provider| {
 34                        provider.status = status;
 35                    });
 36                });
 37            }
 38        });
 39        Self {
 40            client,
 41            model,
 42            settings_version,
 43            status,
 44            _maintain_client_status: maintain_client_status,
 45        }
 46    }
 47
 48    pub fn update(&mut self, model: CloudModel, settings_version: usize) {
 49        self.model = model;
 50        self.settings_version = settings_version;
 51    }
 52}
 53
 54impl LanguageModelCompletionProvider for CloudCompletionProvider {
 55    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
 56        let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
 57            Some(custom_model)
 58        } else {
 59            None
 60        };
 61        CloudModel::iter()
 62            .filter_map(move |model| {
 63                if let CloudModel::Custom(_) = model {
 64                    Some(CloudModel::Custom(custom_model.take()?))
 65                } else {
 66                    Some(model)
 67                }
 68            })
 69            .map(LanguageModel::Cloud)
 70            .collect()
 71    }
 72
 73    fn settings_version(&self) -> usize {
 74        self.settings_version
 75    }
 76
 77    fn is_authenticated(&self) -> bool {
 78        self.status.is_connected()
 79    }
 80
 81    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 82        let client = self.client.clone();
 83        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 84    }
 85
 86    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 87        cx.new_view(|_cx| AuthenticationPrompt).into()
 88    }
 89
 90    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
 91        Task::ready(Ok(()))
 92    }
 93
 94    fn model(&self) -> LanguageModel {
 95        LanguageModel::Cloud(self.model.clone())
 96    }
 97
 98    fn count_tokens(
 99        &self,
100        request: LanguageModelRequest,
101        cx: &AppContext,
102    ) -> BoxFuture<'static, Result<usize>> {
103        match request.model {
104            LanguageModel::Cloud(CloudModel::Gpt4)
105            | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
106            | LanguageModel::Cloud(CloudModel::Gpt4Omni)
107            | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
108                count_open_ai_tokens(request, cx.background_executor())
109            }
110            LanguageModel::Cloud(
111                CloudModel::Claude3_5Sonnet
112                | CloudModel::Claude3Opus
113                | CloudModel::Claude3Sonnet
114                | CloudModel::Claude3Haiku,
115            ) => {
116                // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
117                count_open_ai_tokens(request, cx.background_executor())
118            }
119            LanguageModel::Cloud(CloudModel::Custom(model)) => {
120                let request = self.client.request(proto::CountTokensWithLanguageModel {
121                    model,
122                    messages: request
123                        .messages
124                        .iter()
125                        .map(|message| message.to_proto())
126                        .collect(),
127                });
128                async move {
129                    let response = request.await?;
130                    Ok(response.token_count as usize)
131                }
132                .boxed()
133            }
134            _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
135        }
136    }
137
138    fn complete(
139        &self,
140        mut request: LanguageModelRequest,
141    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
142        request.preprocess();
143
144        let request = proto::CompleteWithLanguageModel {
145            model: request.model.id().to_string(),
146            messages: request
147                .messages
148                .iter()
149                .map(|message| message.to_proto())
150                .collect(),
151            stop: request.stop,
152            temperature: request.temperature,
153            tools: Vec::new(),
154            tool_choice: None,
155        };
156
157        self.client
158            .request_stream(request)
159            .map_ok(|stream| {
160                stream
161                    .filter_map(|response| async move {
162                        match response {
163                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
164                            Err(error) => Some(Err(error)),
165                        }
166                    })
167                    .boxed()
168            })
169            .boxed()
170    }
171
172    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
173        self
174    }
175}
176
177struct AuthenticationPrompt;
178
179impl Render for AuthenticationPrompt {
180    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
181        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
182
183        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
184            v_flex()
185                .gap_2()
186                .child(
187                    Button::new("sign_in", "Sign in")
188                        .icon_color(Color::Muted)
189                        .icon(IconName::Github)
190                        .icon_position(IconPosition::Start)
191                        .style(ButtonStyle::Filled)
192                        .full_width()
193                        .on_click(|_, cx| {
194                            CompletionProvider::global(cx)
195                                .authenticate(cx)
196                                .detach_and_log_err(cx);
197                        }),
198                )
199                .child(
200                    div().flex().w_full().items_center().child(
201                        Label::new("Sign in to enable collaboration.")
202                            .color(Color::Muted)
203                            .size(LabelSize::Small),
204                    ),
205                ),
206        )
207    }
208}