1use anyhow::{anyhow, Result};
2use collections::BTreeMap;
3use editor::{Editor, EditorElement, EditorStyle};
4use futures::{future::BoxFuture, FutureExt, StreamExt};
5use google_ai::stream_generate_content;
6use gpui::{
7 AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext,
8 Subscription, Task, TextStyle, View, WhiteSpace,
9};
10use http_client::HttpClient;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use settings::{Settings, SettingsStore};
14use std::{future, sync::Arc, time::Duration};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::{prelude::*, Indicator};
18use util::ResultExt;
19
20use crate::{
21 settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
22 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
23 LanguageModelProviderState, LanguageModelRequest, RateLimiter,
24};
25
/// Identifier reported by `id()` / `provider_id()` for this provider and its models.
const PROVIDER_ID: &str = "google";

/// Display name reported by `name()` / `provider_name()` for this provider and its models.
const PROVIDER_NAME: &str = "Google AI";
28
29#[derive(Default, Clone, Debug, PartialEq)]
30pub struct GoogleSettings {
31 pub api_url: String,
32 pub low_speed_timeout: Option<Duration>,
33 pub available_models: Vec<AvailableModel>,
34}
35
36#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
37pub struct AvailableModel {
38 name: String,
39 max_tokens: usize,
40}
41
42pub struct GoogleLanguageModelProvider {
43 http_client: Arc<dyn HttpClient>,
44 state: gpui::Model<State>,
45}
46
47pub struct State {
48 api_key: Option<String>,
49 _subscription: Subscription,
50}
51
52impl State {
53 fn is_authenticated(&self) -> bool {
54 self.api_key.is_some()
55 }
56
57 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
58 let delete_credentials =
59 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
60 cx.spawn(|this, mut cx| async move {
61 delete_credentials.await.ok();
62 this.update(&mut cx, |this, cx| {
63 this.api_key = None;
64 cx.notify();
65 })
66 })
67 }
68}
69
70impl GoogleLanguageModelProvider {
71 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
72 let state = cx.new_model(|cx| State {
73 api_key: None,
74 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
75 cx.notify();
76 }),
77 });
78
79 Self { http_client, state }
80 }
81}
82
83impl LanguageModelProviderState for GoogleLanguageModelProvider {
84 type ObservableEntity = State;
85
86 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
87 Some(self.state.clone())
88 }
89}
90
91impl LanguageModelProvider for GoogleLanguageModelProvider {
92 fn id(&self) -> LanguageModelProviderId {
93 LanguageModelProviderId(PROVIDER_ID.into())
94 }
95
96 fn name(&self) -> LanguageModelProviderName {
97 LanguageModelProviderName(PROVIDER_NAME.into())
98 }
99
100 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
101 let mut models = BTreeMap::default();
102
103 // Add base models from google_ai::Model::iter()
104 for model in google_ai::Model::iter() {
105 if !matches!(model, google_ai::Model::Custom { .. }) {
106 models.insert(model.id().to_string(), model);
107 }
108 }
109
110 // Override with available models from settings
111 for model in &AllLanguageModelSettings::get_global(cx)
112 .google
113 .available_models
114 {
115 models.insert(
116 model.name.clone(),
117 google_ai::Model::Custom {
118 name: model.name.clone(),
119 max_tokens: model.max_tokens,
120 },
121 );
122 }
123
124 models
125 .into_values()
126 .map(|model| {
127 Arc::new(GoogleLanguageModel {
128 id: LanguageModelId::from(model.id().to_string()),
129 model,
130 state: self.state.clone(),
131 http_client: self.http_client.clone(),
132 rate_limiter: RateLimiter::new(4),
133 }) as Arc<dyn LanguageModel>
134 })
135 .collect()
136 }
137
138 fn is_authenticated(&self, cx: &AppContext) -> bool {
139 self.state.read(cx).is_authenticated()
140 }
141
142 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
143 if self.is_authenticated(cx) {
144 Task::ready(Ok(()))
145 } else {
146 let api_url = AllLanguageModelSettings::get_global(cx)
147 .google
148 .api_url
149 .clone();
150 let state = self.state.clone();
151 cx.spawn(|mut cx| async move {
152 let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
153 api_key
154 } else {
155 let (_, api_key) = cx
156 .update(|cx| cx.read_credentials(&api_url))?
157 .await?
158 .ok_or_else(|| anyhow!("credentials not found"))?;
159 String::from_utf8(api_key)?
160 };
161
162 state.update(&mut cx, |this, cx| {
163 this.api_key = Some(api_key);
164 cx.notify();
165 })
166 })
167 }
168 }
169
170 fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
171 let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx));
172
173 let focus_handle = view.focus_handle(cx);
174 (view.into(), Some(focus_handle))
175 }
176
177 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
178 let state = self.state.clone();
179 let delete_credentials =
180 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
181 cx.spawn(|mut cx| async move {
182 delete_credentials.await.log_err();
183 state.update(&mut cx, |this, cx| {
184 this.api_key = None;
185 cx.notify();
186 })
187 })
188 }
189}
190
191pub struct GoogleLanguageModel {
192 id: LanguageModelId,
193 model: google_ai::Model,
194 state: gpui::Model<State>,
195 http_client: Arc<dyn HttpClient>,
196 rate_limiter: RateLimiter,
197}
198
199impl LanguageModel for GoogleLanguageModel {
200 fn id(&self) -> LanguageModelId {
201 self.id.clone()
202 }
203
204 fn name(&self) -> LanguageModelName {
205 LanguageModelName::from(self.model.display_name().to_string())
206 }
207
208 fn provider_id(&self) -> LanguageModelProviderId {
209 LanguageModelProviderId(PROVIDER_ID.into())
210 }
211
212 fn provider_name(&self) -> LanguageModelProviderName {
213 LanguageModelProviderName(PROVIDER_NAME.into())
214 }
215
216 fn telemetry_id(&self) -> String {
217 format!("google/{}", self.model.id())
218 }
219
220 fn max_token_count(&self) -> usize {
221 self.model.max_token_count()
222 }
223
224 fn count_tokens(
225 &self,
226 request: LanguageModelRequest,
227 cx: &AppContext,
228 ) -> BoxFuture<'static, Result<usize>> {
229 let request = request.into_google(self.model.id().to_string());
230 let http_client = self.http_client.clone();
231 let api_key = self.state.read(cx).api_key.clone();
232 let api_url = AllLanguageModelSettings::get_global(cx)
233 .google
234 .api_url
235 .clone();
236
237 async move {
238 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
239 let response = google_ai::count_tokens(
240 http_client.as_ref(),
241 &api_url,
242 &api_key,
243 google_ai::CountTokensRequest {
244 contents: request.contents,
245 },
246 )
247 .await?;
248 Ok(response.total_tokens)
249 }
250 .boxed()
251 }
252
253 fn stream_completion(
254 &self,
255 request: LanguageModelRequest,
256 cx: &AsyncAppContext,
257 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
258 let request = request.into_google(self.model.id().to_string());
259
260 let http_client = self.http_client.clone();
261 let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
262 let settings = &AllLanguageModelSettings::get_global(cx).google;
263 (state.api_key.clone(), settings.api_url.clone())
264 }) else {
265 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
266 };
267
268 let future = self.rate_limiter.stream(async move {
269 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
270 let response =
271 stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
272 let events = response.await?;
273 Ok(google_ai::extract_text_from_events(events).boxed())
274 });
275 async move { Ok(future.await?.boxed()) }.boxed()
276 }
277
278 fn use_any_tool(
279 &self,
280 _request: LanguageModelRequest,
281 _name: String,
282 _description: String,
283 _schema: serde_json::Value,
284 _cx: &AsyncAppContext,
285 ) -> BoxFuture<'static, Result<serde_json::Value>> {
286 future::ready(Err(anyhow!("not implemented"))).boxed()
287 }
288}
289
290struct ConfigurationView {
291 api_key_editor: View<Editor>,
292 state: gpui::Model<State>,
293}
294
295impl ConfigurationView {
296 fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
297 Self {
298 api_key_editor: cx.new_view(|cx| {
299 let mut editor = Editor::single_line(cx);
300 editor.set_placeholder_text("AIzaSy...", cx);
301 editor
302 }),
303 state,
304 }
305 }
306
307 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
308 let api_key = self.api_key_editor.read(cx).text(cx);
309 if api_key.is_empty() {
310 return;
311 }
312
313 let settings = &AllLanguageModelSettings::get_global(cx).google;
314 let write_credentials =
315 cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
316 let state = self.state.clone();
317 cx.spawn(|_, mut cx| async move {
318 write_credentials.await?;
319 state.update(&mut cx, |this, cx| {
320 this.api_key = Some(api_key);
321 cx.notify();
322 })
323 })
324 .detach_and_log_err(cx);
325 }
326
327 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
328 self.api_key_editor
329 .update(cx, |editor, cx| editor.set_text("", cx));
330 self.state
331 .update(cx, |state, cx| state.reset_api_key(cx))
332 .detach_and_log_err(cx);
333 }
334
335 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
336 let settings = ThemeSettings::get_global(cx);
337 let text_style = TextStyle {
338 color: cx.theme().colors().text,
339 font_family: settings.ui_font.family.clone(),
340 font_features: settings.ui_font.features.clone(),
341 font_fallbacks: settings.ui_font.fallbacks.clone(),
342 font_size: rems(0.875).into(),
343 font_weight: settings.ui_font.weight,
344 font_style: FontStyle::Normal,
345 line_height: relative(1.3),
346 background_color: None,
347 underline: None,
348 strikethrough: None,
349 white_space: WhiteSpace::Normal,
350 };
351 EditorElement::new(
352 &self.api_key_editor,
353 EditorStyle {
354 background: cx.theme().colors().editor_background,
355 local_player: cx.theme().players().local(),
356 text: text_style,
357 ..Default::default()
358 },
359 )
360 }
361}
362
363impl FocusableView for ConfigurationView {
364 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
365 self.api_key_editor.read(cx).focus_handle(cx)
366 }
367}
368
369impl Render for ConfigurationView {
370 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
371 const INSTRUCTIONS: [&str; 4] = [
372 "To use the Google AI assistant, you need to add your Google AI API key.",
373 "You can create an API key at: https://makersuite.google.com/app/apikey",
374 "",
375 "Paste your Google AI API key below and hit enter to use the assistant:",
376 ];
377
378 if self.state.read(cx).is_authenticated() {
379 h_flex()
380 .size_full()
381 .justify_between()
382 .child(
383 h_flex()
384 .gap_2()
385 .child(Indicator::dot().color(Color::Success))
386 .child(Label::new("API Key configured").size(LabelSize::Small)),
387 )
388 .child(
389 Button::new("reset-key", "Reset key")
390 .icon(Some(IconName::Trash))
391 .icon_size(IconSize::Small)
392 .icon_position(IconPosition::Start)
393 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
394 )
395 .into_any()
396 } else {
397 v_flex()
398 .size_full()
399 .on_action(cx.listener(Self::save_api_key))
400 .children(
401 INSTRUCTIONS.map(|instruction| Label::new(instruction)),
402 )
403 .child(
404 h_flex()
405 .w_full()
406 .my_2()
407 .px_2()
408 .py_1()
409 .bg(cx.theme().colors().editor_background)
410 .rounded_md()
411 .child(self.render_api_key_editor(cx)),
412 )
413 .child(
414 Label::new(
415 "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
416 )
417 .size(LabelSize::Small),
418 )
419 .into_any()
420 }
421 }
422}