1use crate::{
2 assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
3 LanguageModelCompletionProvider, LanguageModelRequest,
4};
5use anyhow::{anyhow, Result};
6use client::{proto, Client};
7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
8use gpui::{AnyView, AppContext, Task};
9use std::{future, sync::Arc};
10use strum::IntoEnumIterator;
11use ui::prelude::*;
12
/// Completion provider backed by the Zed cloud service, proxying model
/// requests over the `Client` RPC connection.
pub struct CloudCompletionProvider {
    client: Arc<Client>,
    // Currently selected cloud-hosted model.
    model: CloudModel,
    // Version counter of the settings snapshot this provider reflects.
    settings_version: usize,
    // Latest known connection status of `client`; kept fresh by
    // `_maintain_client_status`.
    status: client::Status,
    // Background task that mirrors `client.status()` changes into the
    // globally registered `CompletionProvider`.
    _maintain_client_status: Task<()>,
}
20
21impl CloudCompletionProvider {
22 pub fn new(
23 model: CloudModel,
24 client: Arc<Client>,
25 settings_version: usize,
26 cx: &mut AppContext,
27 ) -> Self {
28 let mut status_rx = client.status();
29 let status = *status_rx.borrow();
30 let maintain_client_status = cx.spawn(|mut cx| async move {
31 while let Some(status) = status_rx.next().await {
32 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
33 provider.update_current_as::<_, Self>(|provider| {
34 provider.status = status;
35 });
36 });
37 }
38 });
39 Self {
40 client,
41 model,
42 settings_version,
43 status,
44 _maintain_client_status: maintain_client_status,
45 }
46 }
47
48 pub fn update(&mut self, model: CloudModel, settings_version: usize) {
49 self.model = model;
50 self.settings_version = settings_version;
51 }
52}
53
54impl LanguageModelCompletionProvider for CloudCompletionProvider {
55 fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
56 let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
57 Some(custom_model)
58 } else {
59 None
60 };
61 CloudModel::iter()
62 .filter_map(move |model| {
63 if let CloudModel::Custom(_) = model {
64 Some(CloudModel::Custom(custom_model.take()?))
65 } else {
66 Some(model)
67 }
68 })
69 .map(LanguageModel::Cloud)
70 .collect()
71 }
72
73 fn settings_version(&self) -> usize {
74 self.settings_version
75 }
76
77 fn is_authenticated(&self) -> bool {
78 self.status.is_connected()
79 }
80
81 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
82 let client = self.client.clone();
83 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
84 }
85
86 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
87 cx.new_view(|_cx| AuthenticationPrompt).into()
88 }
89
90 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
91 Task::ready(Ok(()))
92 }
93
94 fn model(&self) -> LanguageModel {
95 LanguageModel::Cloud(self.model.clone())
96 }
97
98 fn count_tokens(
99 &self,
100 request: LanguageModelRequest,
101 cx: &AppContext,
102 ) -> BoxFuture<'static, Result<usize>> {
103 match request.model {
104 LanguageModel::Cloud(CloudModel::Gpt4)
105 | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
106 | LanguageModel::Cloud(CloudModel::Gpt4Omni)
107 | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
108 count_open_ai_tokens(request, cx.background_executor())
109 }
110 LanguageModel::Cloud(
111 CloudModel::Claude3_5Sonnet
112 | CloudModel::Claude3Opus
113 | CloudModel::Claude3Sonnet
114 | CloudModel::Claude3Haiku,
115 ) => {
116 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
117 count_open_ai_tokens(request, cx.background_executor())
118 }
119 LanguageModel::Cloud(CloudModel::Custom(model)) => {
120 let request = self.client.request(proto::CountTokensWithLanguageModel {
121 model,
122 messages: request
123 .messages
124 .iter()
125 .map(|message| message.to_proto())
126 .collect(),
127 });
128 async move {
129 let response = request.await?;
130 Ok(response.token_count as usize)
131 }
132 .boxed()
133 }
134 _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
135 }
136 }
137
138 fn complete(
139 &self,
140 mut request: LanguageModelRequest,
141 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
142 request.preprocess();
143
144 let request = proto::CompleteWithLanguageModel {
145 model: request.model.id().to_string(),
146 messages: request
147 .messages
148 .iter()
149 .map(|message| message.to_proto())
150 .collect(),
151 stop: request.stop,
152 temperature: request.temperature,
153 tools: Vec::new(),
154 tool_choice: None,
155 };
156
157 self.client
158 .request_stream(request)
159 .map_ok(|stream| {
160 stream
161 .filter_map(|response| async move {
162 match response {
163 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
164 Err(error) => Some(Err(error)),
165 }
166 })
167 .boxed()
168 })
169 .boxed()
170 }
171
172 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
173 self
174 }
175}
176
/// Zero-sized view shown when the cloud provider is unauthenticated,
/// rendering the sign-in call to action.
struct AuthenticationPrompt;
178
179impl Render for AuthenticationPrompt {
180 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
181 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
182
183 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
184 v_flex()
185 .gap_2()
186 .child(
187 Button::new("sign_in", "Sign in")
188 .icon_color(Color::Muted)
189 .icon(IconName::Github)
190 .icon_position(IconPosition::Start)
191 .style(ButtonStyle::Filled)
192 .full_width()
193 .on_click(|_, cx| {
194 CompletionProvider::global(cx)
195 .authenticate(cx)
196 .detach_and_log_err(cx);
197 }),
198 )
199 .child(
200 div().flex().w_full().items_center().child(
201 Label::new("Sign in to enable collaboration.")
202 .color(Color::Muted)
203 .size(LabelSize::Small),
204 ),
205 ),
206 )
207 }
208}