1use anyhow::{anyhow, Result};
2use collections::BTreeMap;
3use editor::{Editor, EditorElement, EditorStyle};
4use futures::{future::BoxFuture, FutureExt, StreamExt};
5use google_ai::stream_generate_content;
6use gpui::{
7 AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
8 View, WhiteSpace,
9};
10use http_client::HttpClient;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use settings::{Settings, SettingsStore};
14use std::{future, sync::Arc, time::Duration};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::{prelude::*, Icon, IconName, Tooltip};
18use util::ResultExt;
19
20use crate::LanguageModelCompletionEvent;
21use crate::{
22 settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
23 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
24 LanguageModelProviderState, LanguageModelRequest, RateLimiter,
25};
26
/// Identifier string this file wraps in `LanguageModelProviderId`.
const PROVIDER_ID: &str = "google";
/// Display string this file wraps in `LanguageModelProviderName`.
const PROVIDER_NAME: &str = "Google AI";
29
30#[derive(Default, Clone, Debug, PartialEq)]
31pub struct GoogleSettings {
32 pub api_url: String,
33 pub low_speed_timeout: Option<Duration>,
34 pub available_models: Vec<AvailableModel>,
35}
36
37#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
38pub struct AvailableModel {
39 name: String,
40 display_name: Option<String>,
41 max_tokens: usize,
42}
43
44pub struct GoogleLanguageModelProvider {
45 http_client: Arc<dyn HttpClient>,
46 state: gpui::Model<State>,
47}
48
49pub struct State {
50 api_key: Option<String>,
51 api_key_from_env: bool,
52 _subscription: Subscription,
53}
54
/// Name of the environment variable checked for an API key before falling
/// back to stored credentials.
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
56
57impl State {
58 fn is_authenticated(&self) -> bool {
59 self.api_key.is_some()
60 }
61
62 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
63 let delete_credentials =
64 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
65 cx.spawn(|this, mut cx| async move {
66 delete_credentials.await.ok();
67 this.update(&mut cx, |this, cx| {
68 this.api_key = None;
69 this.api_key_from_env = false;
70 cx.notify();
71 })
72 })
73 }
74
75 fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
76 let settings = &AllLanguageModelSettings::get_global(cx).google;
77 let write_credentials =
78 cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
79
80 cx.spawn(|this, mut cx| async move {
81 write_credentials.await?;
82 this.update(&mut cx, |this, cx| {
83 this.api_key = Some(api_key);
84 cx.notify();
85 })
86 })
87 }
88
89 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
90 if self.is_authenticated() {
91 Task::ready(Ok(()))
92 } else {
93 let api_url = AllLanguageModelSettings::get_global(cx)
94 .google
95 .api_url
96 .clone();
97
98 cx.spawn(|this, mut cx| async move {
99 let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR)
100 {
101 (api_key, true)
102 } else {
103 let (_, api_key) = cx
104 .update(|cx| cx.read_credentials(&api_url))?
105 .await?
106 .ok_or_else(|| anyhow!("credentials not found"))?;
107 (String::from_utf8(api_key)?, false)
108 };
109
110 this.update(&mut cx, |this, cx| {
111 this.api_key = Some(api_key);
112 this.api_key_from_env = from_env;
113 cx.notify();
114 })
115 })
116 }
117 }
118}
119
120impl GoogleLanguageModelProvider {
121 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
122 let state = cx.new_model(|cx| State {
123 api_key: None,
124 api_key_from_env: false,
125 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
126 cx.notify();
127 }),
128 });
129
130 Self { http_client, state }
131 }
132}
133
134impl LanguageModelProviderState for GoogleLanguageModelProvider {
135 type ObservableEntity = State;
136
137 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
138 Some(self.state.clone())
139 }
140}
141
142impl LanguageModelProvider for GoogleLanguageModelProvider {
143 fn id(&self) -> LanguageModelProviderId {
144 LanguageModelProviderId(PROVIDER_ID.into())
145 }
146
147 fn name(&self) -> LanguageModelProviderName {
148 LanguageModelProviderName(PROVIDER_NAME.into())
149 }
150
151 fn icon(&self) -> IconName {
152 IconName::AiGoogle
153 }
154
155 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
156 let mut models = BTreeMap::default();
157
158 // Add base models from google_ai::Model::iter()
159 for model in google_ai::Model::iter() {
160 if !matches!(model, google_ai::Model::Custom { .. }) {
161 models.insert(model.id().to_string(), model);
162 }
163 }
164
165 // Override with available models from settings
166 for model in &AllLanguageModelSettings::get_global(cx)
167 .google
168 .available_models
169 {
170 models.insert(
171 model.name.clone(),
172 google_ai::Model::Custom {
173 name: model.name.clone(),
174 display_name: model.display_name.clone(),
175 max_tokens: model.max_tokens,
176 },
177 );
178 }
179
180 models
181 .into_values()
182 .map(|model| {
183 Arc::new(GoogleLanguageModel {
184 id: LanguageModelId::from(model.id().to_string()),
185 model,
186 state: self.state.clone(),
187 http_client: self.http_client.clone(),
188 rate_limiter: RateLimiter::new(4),
189 }) as Arc<dyn LanguageModel>
190 })
191 .collect()
192 }
193
194 fn is_authenticated(&self, cx: &AppContext) -> bool {
195 self.state.read(cx).is_authenticated()
196 }
197
198 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
199 self.state.update(cx, |state, cx| state.authenticate(cx))
200 }
201
202 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
203 cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
204 .into()
205 }
206
207 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
208 let state = self.state.clone();
209 let delete_credentials =
210 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
211 cx.spawn(|mut cx| async move {
212 delete_credentials.await.log_err();
213 state.update(&mut cx, |this, cx| {
214 this.api_key = None;
215 cx.notify();
216 })
217 })
218 }
219}
220
221pub struct GoogleLanguageModel {
222 id: LanguageModelId,
223 model: google_ai::Model,
224 state: gpui::Model<State>,
225 http_client: Arc<dyn HttpClient>,
226 rate_limiter: RateLimiter,
227}
228
229impl LanguageModel for GoogleLanguageModel {
230 fn id(&self) -> LanguageModelId {
231 self.id.clone()
232 }
233
234 fn name(&self) -> LanguageModelName {
235 LanguageModelName::from(self.model.display_name().to_string())
236 }
237
238 fn provider_id(&self) -> LanguageModelProviderId {
239 LanguageModelProviderId(PROVIDER_ID.into())
240 }
241
242 fn provider_name(&self) -> LanguageModelProviderName {
243 LanguageModelProviderName(PROVIDER_NAME.into())
244 }
245
246 fn telemetry_id(&self) -> String {
247 format!("google/{}", self.model.id())
248 }
249
250 fn max_token_count(&self) -> usize {
251 self.model.max_token_count()
252 }
253
254 fn count_tokens(
255 &self,
256 request: LanguageModelRequest,
257 cx: &AppContext,
258 ) -> BoxFuture<'static, Result<usize>> {
259 let request = request.into_google(self.model.id().to_string());
260 let http_client = self.http_client.clone();
261 let api_key = self.state.read(cx).api_key.clone();
262
263 let settings = &AllLanguageModelSettings::get_global(cx).google;
264 let api_url = settings.api_url.clone();
265 let low_speed_timeout = settings.low_speed_timeout;
266
267 async move {
268 let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
269 let response = google_ai::count_tokens(
270 http_client.as_ref(),
271 &api_url,
272 &api_key,
273 google_ai::CountTokensRequest {
274 contents: request.contents,
275 },
276 low_speed_timeout,
277 )
278 .await?;
279 Ok(response.total_tokens)
280 }
281 .boxed()
282 }
283
284 fn stream_completion(
285 &self,
286 request: LanguageModelRequest,
287 cx: &AsyncAppContext,
288 ) -> BoxFuture<
289 'static,
290 Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
291 > {
292 let request = request.into_google(self.model.id().to_string());
293
294 let http_client = self.http_client.clone();
295 let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
296 let settings = &AllLanguageModelSettings::get_global(cx).google;
297 (
298 state.api_key.clone(),
299 settings.api_url.clone(),
300 settings.low_speed_timeout,
301 )
302 }) else {
303 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
304 };
305
306 let future = self.rate_limiter.stream(async move {
307 let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
308 let response = stream_generate_content(
309 http_client.as_ref(),
310 &api_url,
311 &api_key,
312 request,
313 low_speed_timeout,
314 );
315 let events = response.await?;
316 Ok(google_ai::extract_text_from_events(events).boxed())
317 });
318 async move {
319 Ok(future
320 .await?
321 .map(|result| result.map(LanguageModelCompletionEvent::Text))
322 .boxed())
323 }
324 .boxed()
325 }
326
327 fn use_any_tool(
328 &self,
329 _request: LanguageModelRequest,
330 _name: String,
331 _description: String,
332 _schema: serde_json::Value,
333 _cx: &AsyncAppContext,
334 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
335 future::ready(Err(anyhow!("not implemented"))).boxed()
336 }
337}
338
339struct ConfigurationView {
340 api_key_editor: View<Editor>,
341 state: gpui::Model<State>,
342 load_credentials_task: Option<Task<()>>,
343}
344
345impl ConfigurationView {
346 fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
347 cx.observe(&state, |_, _, cx| {
348 cx.notify();
349 })
350 .detach();
351
352 let load_credentials_task = Some(cx.spawn({
353 let state = state.clone();
354 |this, mut cx| async move {
355 if let Some(task) = state
356 .update(&mut cx, |state, cx| state.authenticate(cx))
357 .log_err()
358 {
359 // We don't log an error, because "not signed in" is also an error.
360 let _ = task.await;
361 }
362 this.update(&mut cx, |this, cx| {
363 this.load_credentials_task = None;
364 cx.notify();
365 })
366 .log_err();
367 }
368 }));
369
370 Self {
371 api_key_editor: cx.new_view(|cx| {
372 let mut editor = Editor::single_line(cx);
373 editor.set_placeholder_text("AIzaSy...", cx);
374 editor
375 }),
376 state,
377 load_credentials_task,
378 }
379 }
380
381 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
382 let api_key = self.api_key_editor.read(cx).text(cx);
383 if api_key.is_empty() {
384 return;
385 }
386
387 let state = self.state.clone();
388 cx.spawn(|_, mut cx| async move {
389 state
390 .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
391 .await
392 })
393 .detach_and_log_err(cx);
394
395 cx.notify();
396 }
397
398 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
399 self.api_key_editor
400 .update(cx, |editor, cx| editor.set_text("", cx));
401
402 let state = self.state.clone();
403 cx.spawn(|_, mut cx| async move {
404 state
405 .update(&mut cx, |state, cx| state.reset_api_key(cx))?
406 .await
407 })
408 .detach_and_log_err(cx);
409
410 cx.notify();
411 }
412
413 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
414 let settings = ThemeSettings::get_global(cx);
415 let text_style = TextStyle {
416 color: cx.theme().colors().text,
417 font_family: settings.ui_font.family.clone(),
418 font_features: settings.ui_font.features.clone(),
419 font_fallbacks: settings.ui_font.fallbacks.clone(),
420 font_size: rems(0.875).into(),
421 font_weight: settings.ui_font.weight,
422 font_style: FontStyle::Normal,
423 line_height: relative(1.3),
424 background_color: None,
425 underline: None,
426 strikethrough: None,
427 white_space: WhiteSpace::Normal,
428 truncate: None,
429 };
430 EditorElement::new(
431 &self.api_key_editor,
432 EditorStyle {
433 background: cx.theme().colors().editor_background,
434 local_player: cx.theme().players().local(),
435 text: text_style,
436 ..Default::default()
437 },
438 )
439 }
440
441 fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
442 !self.state.read(cx).is_authenticated()
443 }
444}
445
446impl Render for ConfigurationView {
447 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
448 const GOOGLE_CONSOLE_URL: &str = "https://aistudio.google.com/app/apikey";
449 const INSTRUCTIONS: [&str; 3] = [
450 "To use Zed's assistant with Google AI, you need to add an API key. Follow these steps:",
451 "- Create one by visiting:",
452 "- Paste your API key below and hit enter to use the assistant",
453 ];
454
455 let env_var_set = self.state.read(cx).api_key_from_env;
456
457 if self.load_credentials_task.is_some() {
458 div().child(Label::new("Loading credentials...")).into_any()
459 } else if self.should_render_editor(cx) {
460 v_flex()
461 .size_full()
462 .on_action(cx.listener(Self::save_api_key))
463 .child(Label::new(INSTRUCTIONS[0]))
464 .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
465 Button::new("google_console", GOOGLE_CONSOLE_URL)
466 .style(ButtonStyle::Subtle)
467 .icon(IconName::ExternalLink)
468 .icon_size(IconSize::XSmall)
469 .icon_color(Color::Muted)
470 .on_click(move |_, cx| cx.open_url(GOOGLE_CONSOLE_URL))
471 )
472 )
473 .child(Label::new(INSTRUCTIONS[2]))
474 .child(
475 h_flex()
476 .w_full()
477 .my_2()
478 .px_2()
479 .py_1()
480 .bg(cx.theme().colors().editor_background)
481 .rounded_md()
482 .child(self.render_api_key_editor(cx)),
483 )
484 .child(
485 Label::new(
486 format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."),
487 )
488 .size(LabelSize::Small),
489 )
490 .into_any()
491 } else {
492 h_flex()
493 .size_full()
494 .justify_between()
495 .child(
496 h_flex()
497 .gap_1()
498 .child(Icon::new(IconName::Check).color(Color::Success))
499 .child(Label::new(if env_var_set {
500 format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.")
501 } else {
502 "API key configured.".to_string()
503 })),
504 )
505 .child(
506 Button::new("reset-key", "Reset key")
507 .icon(Some(IconName::Trash))
508 .icon_size(IconSize::Small)
509 .icon_position(IconPosition::Start)
510 .disabled(env_var_set)
511 .when(env_var_set, |this| {
512 this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable."), cx))
513 })
514 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
515 )
516 .into_any()
517 }
518 }
519}