//! cloud.rs — completion provider that proxies language-model requests over the RPC client.

  1use crate::{
  2    count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
  3    LanguageModelRequest,
  4};
  5use anyhow::{anyhow, Result};
  6use client::{proto, Client};
  7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
  8use gpui::{AnyView, AppContext, Task};
  9use language_model::CloudModel;
 10use std::{future, sync::Arc};
 11use strum::IntoEnumIterator;
 12use ui::prelude::*;
 13
/// Completion provider backed by a cloud-hosted model service reached
/// through the RPC `client`.
pub struct CloudCompletionProvider {
    /// RPC client used to authenticate and to send token-count / completion requests.
    client: Arc<Client>,
    /// The currently selected cloud model.
    model: CloudModel,
    /// Version of the settings snapshot this configuration was built from.
    settings_version: usize,
    /// Latest known connection status, kept in sync by `_maintain_client_status`.
    status: client::Status,
    /// Background task that mirrors client status changes into `status`;
    /// held only so it stays alive for the provider's lifetime.
    _maintain_client_status: Task<()>,
}
 21
 22impl CloudCompletionProvider {
 23    pub fn new(
 24        model: CloudModel,
 25        client: Arc<Client>,
 26        settings_version: usize,
 27        cx: &mut AppContext,
 28    ) -> Self {
 29        let mut status_rx = client.status();
 30        let status = *status_rx.borrow();
 31        let maintain_client_status = cx.spawn(|mut cx| async move {
 32            while let Some(status) = status_rx.next().await {
 33                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 34                    provider.update_current_as::<_, Self>(|provider| {
 35                        provider.status = status;
 36                    });
 37                });
 38            }
 39        });
 40        Self {
 41            client,
 42            model,
 43            settings_version,
 44            status,
 45            _maintain_client_status: maintain_client_status,
 46        }
 47    }
 48
 49    pub fn update(&mut self, model: CloudModel, settings_version: usize) {
 50        self.model = model;
 51        self.settings_version = settings_version;
 52    }
 53}
 54
impl LanguageModelCompletionProvider for CloudCompletionProvider {
    /// Lists every cloud model variant. The enum's `Custom` placeholder is
    /// replaced by the currently configured custom model, so a custom entry
    /// appears at most once — and only when one is actually selected.
    fn available_models(&self) -> Vec<LanguageModel> {
        let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
            Some(self.model.clone())
        } else {
            None
        };
        CloudModel::iter()
            .filter_map(move |model| {
                if let CloudModel::Custom { .. } = model {
                    // `take()` yields the configured custom model exactly once;
                    // if none is configured this filters the placeholder out.
                    custom_model.take()
                } else {
                    Some(model)
                }
            })
            .map(LanguageModel::Cloud)
            .collect()
    }

    fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// Authenticated here simply means the RPC client is connected.
    fn is_authenticated(&self) -> bool {
        self.status.is_connected()
    }

    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| AuthenticationPrompt).into()
    }

    /// Nothing to reset: credentials live in the shared client, not here.
    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn model(&self) -> LanguageModel {
        LanguageModel::Cloud(self.model.clone())
    }

    /// Estimates the token count of `request` for the model it names.
    /// OpenAI models count locally; Anthropic models reuse the OpenAI
    /// tokenizer as an approximation; other custom models are counted
    /// server-side via RPC. Unknown models resolve to an error.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match &request.model {
            LanguageModel::Cloud(CloudModel::Gpt4)
            | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
            | LanguageModel::Cloud(CloudModel::Gpt4Omni)
            | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
                count_open_ai_tokens(request, cx.background_executor())
            }
            LanguageModel::Cloud(
                CloudModel::Claude3_5Sonnet
                | CloudModel::Claude3Opus
                | CloudModel::Claude3Sonnet
                | CloudModel::Claude3Haiku,
            ) => {
                // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
                count_open_ai_tokens(request, cx.background_executor())
            }
            LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
                if name.starts_with("anthropic/") {
                    // Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation.
                    count_open_ai_tokens(request, cx.background_executor())
                } else {
                    // Note: the RPC request is issued synchronously here;
                    // only awaiting the response is deferred to the future.
                    let request = self.client.request(proto::CountTokensWithLanguageModel {
                        model: name.clone(),
                        messages: request
                            .messages
                            .iter()
                            .map(|message| message.to_proto())
                            .collect(),
                    });
                    async move {
                        let response = request.await?;
                        Ok(response.token_count as usize)
                    }
                    .boxed()
                }
            }
            _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
        }
    }

    /// Streams a completion over RPC, yielding incremental content strings.
    fn stream_completion(
        &self,
        mut request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        request.preprocess();

        let request = proto::CompleteWithLanguageModel {
            model: request.model.id().to_string(),
            messages: request
                .messages
                .iter()
                .map(|message| message.to_proto())
                .collect(),
            stop: request.stop,
            temperature: request.temperature,
            // Tool use is not supported over this path.
            tools: Vec::new(),
            tool_choice: None,
        };

        self.client
            .request_stream(request)
            .map_ok(|stream| {
                stream
                    .filter_map(|response| async move {
                        match response {
                            // Take the last choice's delta content; chunks with
                            // no choice/delta/content are silently dropped by
                            // `filter_map` via the `?` short-circuits.
                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
                            Err(error) => Some(Err(error)),
                        }
                    })
                    .boxed()
            })
            .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
182
/// Stateless view shown when the user is not authenticated; its `Render`
/// impl displays a sign-in button that triggers provider authentication.
struct AuthenticationPrompt;
184
185impl Render for AuthenticationPrompt {
186    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
187        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
188
189        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
190            v_flex()
191                .gap_2()
192                .child(
193                    Button::new("sign_in", "Sign in")
194                        .icon_color(Color::Muted)
195                        .icon(IconName::Github)
196                        .icon_position(IconPosition::Start)
197                        .style(ButtonStyle::Filled)
198                        .full_width()
199                        .on_click(|_, cx| {
200                            CompletionProvider::global(cx)
201                                .authenticate(cx)
202                                .detach_and_log_err(cx);
203                        }),
204                )
205                .child(
206                    div().flex().w_full().items_center().child(
207                        Label::new("Sign in to enable collaboration.")
208                            .color(Color::Muted)
209                            .size(LabelSize::Small),
210                    ),
211                ),
212        )
213    }
214}