1use crate::{
2 count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
3 LanguageModelRequest,
4};
5use anyhow::{anyhow, Result};
6use client::{proto, Client};
7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
8use gpui::{AnyView, AppContext, Task};
9use language_model::CloudModel;
10use std::{future, sync::Arc};
11use strum::IntoEnumIterator;
12use ui::prelude::*;
13
14pub struct CloudCompletionProvider {
15 client: Arc<Client>,
16 model: CloudModel,
17 settings_version: usize,
18 status: client::Status,
19 _maintain_client_status: Task<()>,
20}
21
22impl CloudCompletionProvider {
23 pub fn new(
24 model: CloudModel,
25 client: Arc<Client>,
26 settings_version: usize,
27 cx: &mut AppContext,
28 ) -> Self {
29 let mut status_rx = client.status();
30 let status = *status_rx.borrow();
31 let maintain_client_status = cx.spawn(|mut cx| async move {
32 while let Some(status) = status_rx.next().await {
33 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
34 provider.update_current_as::<_, Self>(|provider| {
35 provider.status = status;
36 });
37 });
38 }
39 });
40 Self {
41 client,
42 model,
43 settings_version,
44 status,
45 _maintain_client_status: maintain_client_status,
46 }
47 }
48
49 pub fn update(&mut self, model: CloudModel, settings_version: usize) {
50 self.model = model;
51 self.settings_version = settings_version;
52 }
53}
54
55impl LanguageModelCompletionProvider for CloudCompletionProvider {
56 fn available_models(&self) -> Vec<LanguageModel> {
57 let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
58 Some(self.model.clone())
59 } else {
60 None
61 };
62 CloudModel::iter()
63 .filter_map(move |model| {
64 if let CloudModel::Custom { .. } = model {
65 custom_model.take()
66 } else {
67 Some(model)
68 }
69 })
70 .map(LanguageModel::Cloud)
71 .collect()
72 }
73
74 fn settings_version(&self) -> usize {
75 self.settings_version
76 }
77
78 fn is_authenticated(&self) -> bool {
79 self.status.is_connected()
80 }
81
82 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
83 let client = self.client.clone();
84 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
85 }
86
87 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
88 cx.new_view(|_cx| AuthenticationPrompt).into()
89 }
90
91 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
92 Task::ready(Ok(()))
93 }
94
95 fn model(&self) -> LanguageModel {
96 LanguageModel::Cloud(self.model.clone())
97 }
98
99 fn count_tokens(
100 &self,
101 request: LanguageModelRequest,
102 cx: &AppContext,
103 ) -> BoxFuture<'static, Result<usize>> {
104 match &request.model {
105 LanguageModel::Cloud(CloudModel::Gpt4)
106 | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
107 | LanguageModel::Cloud(CloudModel::Gpt4Omni)
108 | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
109 count_open_ai_tokens(request, cx.background_executor())
110 }
111 LanguageModel::Cloud(
112 CloudModel::Claude3_5Sonnet
113 | CloudModel::Claude3Opus
114 | CloudModel::Claude3Sonnet
115 | CloudModel::Claude3Haiku,
116 ) => {
117 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
118 count_open_ai_tokens(request, cx.background_executor())
119 }
120 LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
121 if name.starts_with("anthropic/") {
122 // Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation.
123 count_open_ai_tokens(request, cx.background_executor())
124 } else {
125 let request = self.client.request(proto::CountTokensWithLanguageModel {
126 model: name.clone(),
127 messages: request
128 .messages
129 .iter()
130 .map(|message| message.to_proto())
131 .collect(),
132 });
133 async move {
134 let response = request.await?;
135 Ok(response.token_count as usize)
136 }
137 .boxed()
138 }
139 }
140 _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
141 }
142 }
143
144 fn stream_completion(
145 &self,
146 mut request: LanguageModelRequest,
147 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
148 request.preprocess();
149
150 let request = proto::CompleteWithLanguageModel {
151 model: request.model.id().to_string(),
152 messages: request
153 .messages
154 .iter()
155 .map(|message| message.to_proto())
156 .collect(),
157 stop: request.stop,
158 temperature: request.temperature,
159 tools: Vec::new(),
160 tool_choice: None,
161 };
162
163 self.client
164 .request_stream(request)
165 .map_ok(|stream| {
166 stream
167 .filter_map(|response| async move {
168 match response {
169 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
170 Err(error) => Some(Err(error)),
171 }
172 })
173 .boxed()
174 })
175 .boxed()
176 }
177
178 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
179 self
180 }
181}
182
183struct AuthenticationPrompt;
184
185impl Render for AuthenticationPrompt {
186 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
187 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
188
189 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
190 v_flex()
191 .gap_2()
192 .child(
193 Button::new("sign_in", "Sign in")
194 .icon_color(Color::Muted)
195 .icon(IconName::Github)
196 .icon_position(IconPosition::Start)
197 .style(ButtonStyle::Filled)
198 .full_width()
199 .on_click(|_, cx| {
200 CompletionProvider::global(cx)
201 .authenticate(cx)
202 .detach_and_log_err(cx);
203 }),
204 )
205 .child(
206 div().flex().w_full().items_center().child(
207 Label::new("Sign in to enable collaboration.")
208 .color(Color::Muted)
209 .size(LabelSize::Small),
210 ),
211 ),
212 )
213 }
214}