1use anyhow::{anyhow, Result};
2use collections::BTreeMap;
3use editor::{Editor, EditorElement, EditorStyle};
4use futures::{future::BoxFuture, FutureExt, StreamExt};
5use google_ai::stream_generate_content;
6use gpui::{
7 AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
8 View, WhiteSpace,
9};
10use http_client::HttpClient;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use settings::{Settings, SettingsStore};
14use std::{future, sync::Arc, time::Duration};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::{prelude::*, Icon, IconName, Tooltip};
18use util::ResultExt;
19
20use crate::{
21 settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
22 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
23 LanguageModelProviderState, LanguageModelRequest, RateLimiter,
24};
25
/// Stable identifier of this provider, also used as the telemetry namespace prefix value.
const PROVIDER_ID: &str = "google";
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "Google AI";
28
29#[derive(Default, Clone, Debug, PartialEq)]
30pub struct GoogleSettings {
31 pub api_url: String,
32 pub low_speed_timeout: Option<Duration>,
33 pub available_models: Vec<AvailableModel>,
34}
35
36#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
37pub struct AvailableModel {
38 name: String,
39 max_tokens: usize,
40}
41
42pub struct GoogleLanguageModelProvider {
43 http_client: Arc<dyn HttpClient>,
44 state: gpui::Model<State>,
45}
46
47pub struct State {
48 api_key: Option<String>,
49 api_key_from_env: bool,
50 _subscription: Subscription,
51}
52
/// Environment variable consulted (before the credential store) for the Google AI API key.
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
54
55impl State {
56 fn is_authenticated(&self) -> bool {
57 self.api_key.is_some()
58 }
59
60 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
61 let delete_credentials =
62 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
63 cx.spawn(|this, mut cx| async move {
64 delete_credentials.await.ok();
65 this.update(&mut cx, |this, cx| {
66 this.api_key = None;
67 this.api_key_from_env = false;
68 cx.notify();
69 })
70 })
71 }
72
73 fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
74 let settings = &AllLanguageModelSettings::get_global(cx).google;
75 let write_credentials =
76 cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
77
78 cx.spawn(|this, mut cx| async move {
79 write_credentials.await?;
80 this.update(&mut cx, |this, cx| {
81 this.api_key = Some(api_key);
82 cx.notify();
83 })
84 })
85 }
86
87 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
88 if self.is_authenticated() {
89 Task::ready(Ok(()))
90 } else {
91 let api_url = AllLanguageModelSettings::get_global(cx)
92 .google
93 .api_url
94 .clone();
95
96 cx.spawn(|this, mut cx| async move {
97 let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR)
98 {
99 (api_key, true)
100 } else {
101 let (_, api_key) = cx
102 .update(|cx| cx.read_credentials(&api_url))?
103 .await?
104 .ok_or_else(|| anyhow!("credentials not found"))?;
105 (String::from_utf8(api_key)?, false)
106 };
107
108 this.update(&mut cx, |this, cx| {
109 this.api_key = Some(api_key);
110 this.api_key_from_env = from_env;
111 cx.notify();
112 })
113 })
114 }
115 }
116}
117
118impl GoogleLanguageModelProvider {
119 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
120 let state = cx.new_model(|cx| State {
121 api_key: None,
122 api_key_from_env: false,
123 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
124 cx.notify();
125 }),
126 });
127
128 Self { http_client, state }
129 }
130}
131
132impl LanguageModelProviderState for GoogleLanguageModelProvider {
133 type ObservableEntity = State;
134
135 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
136 Some(self.state.clone())
137 }
138}
139
140impl LanguageModelProvider for GoogleLanguageModelProvider {
141 fn id(&self) -> LanguageModelProviderId {
142 LanguageModelProviderId(PROVIDER_ID.into())
143 }
144
145 fn name(&self) -> LanguageModelProviderName {
146 LanguageModelProviderName(PROVIDER_NAME.into())
147 }
148
149 fn icon(&self) -> IconName {
150 IconName::AiGoogle
151 }
152
153 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
154 let mut models = BTreeMap::default();
155
156 // Add base models from google_ai::Model::iter()
157 for model in google_ai::Model::iter() {
158 if !matches!(model, google_ai::Model::Custom { .. }) {
159 models.insert(model.id().to_string(), model);
160 }
161 }
162
163 // Override with available models from settings
164 for model in &AllLanguageModelSettings::get_global(cx)
165 .google
166 .available_models
167 {
168 models.insert(
169 model.name.clone(),
170 google_ai::Model::Custom {
171 name: model.name.clone(),
172 max_tokens: model.max_tokens,
173 },
174 );
175 }
176
177 models
178 .into_values()
179 .map(|model| {
180 Arc::new(GoogleLanguageModel {
181 id: LanguageModelId::from(model.id().to_string()),
182 model,
183 state: self.state.clone(),
184 http_client: self.http_client.clone(),
185 rate_limiter: RateLimiter::new(4),
186 }) as Arc<dyn LanguageModel>
187 })
188 .collect()
189 }
190
191 fn is_authenticated(&self, cx: &AppContext) -> bool {
192 self.state.read(cx).is_authenticated()
193 }
194
195 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
196 self.state.update(cx, |state, cx| state.authenticate(cx))
197 }
198
199 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
200 cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
201 .into()
202 }
203
204 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
205 let state = self.state.clone();
206 let delete_credentials =
207 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
208 cx.spawn(|mut cx| async move {
209 delete_credentials.await.log_err();
210 state.update(&mut cx, |this, cx| {
211 this.api_key = None;
212 cx.notify();
213 })
214 })
215 }
216}
217
218pub struct GoogleLanguageModel {
219 id: LanguageModelId,
220 model: google_ai::Model,
221 state: gpui::Model<State>,
222 http_client: Arc<dyn HttpClient>,
223 rate_limiter: RateLimiter,
224}
225
226impl LanguageModel for GoogleLanguageModel {
227 fn id(&self) -> LanguageModelId {
228 self.id.clone()
229 }
230
231 fn name(&self) -> LanguageModelName {
232 LanguageModelName::from(self.model.display_name().to_string())
233 }
234
235 fn provider_id(&self) -> LanguageModelProviderId {
236 LanguageModelProviderId(PROVIDER_ID.into())
237 }
238
239 fn provider_name(&self) -> LanguageModelProviderName {
240 LanguageModelProviderName(PROVIDER_NAME.into())
241 }
242
243 fn telemetry_id(&self) -> String {
244 format!("google/{}", self.model.id())
245 }
246
247 fn max_token_count(&self) -> usize {
248 self.model.max_token_count()
249 }
250
251 fn count_tokens(
252 &self,
253 request: LanguageModelRequest,
254 cx: &AppContext,
255 ) -> BoxFuture<'static, Result<usize>> {
256 let request = request.into_google(self.model.id().to_string());
257 let http_client = self.http_client.clone();
258 let api_key = self.state.read(cx).api_key.clone();
259 let api_url = AllLanguageModelSettings::get_global(cx)
260 .google
261 .api_url
262 .clone();
263
264 async move {
265 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
266 let response = google_ai::count_tokens(
267 http_client.as_ref(),
268 &api_url,
269 &api_key,
270 google_ai::CountTokensRequest {
271 contents: request.contents,
272 },
273 )
274 .await?;
275 Ok(response.total_tokens)
276 }
277 .boxed()
278 }
279
280 fn stream_completion(
281 &self,
282 request: LanguageModelRequest,
283 cx: &AsyncAppContext,
284 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
285 let request = request.into_google(self.model.id().to_string());
286
287 let http_client = self.http_client.clone();
288 let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
289 let settings = &AllLanguageModelSettings::get_global(cx).google;
290 (state.api_key.clone(), settings.api_url.clone())
291 }) else {
292 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
293 };
294
295 let future = self.rate_limiter.stream(async move {
296 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
297 let response =
298 stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
299 let events = response.await?;
300 Ok(google_ai::extract_text_from_events(events).boxed())
301 });
302 async move { Ok(future.await?.boxed()) }.boxed()
303 }
304
305 fn use_any_tool(
306 &self,
307 _request: LanguageModelRequest,
308 _name: String,
309 _description: String,
310 _schema: serde_json::Value,
311 _cx: &AsyncAppContext,
312 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
313 future::ready(Err(anyhow!("not implemented"))).boxed()
314 }
315}
316
317struct ConfigurationView {
318 api_key_editor: View<Editor>,
319 state: gpui::Model<State>,
320 load_credentials_task: Option<Task<()>>,
321}
322
323impl ConfigurationView {
324 fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
325 cx.observe(&state, |_, _, cx| {
326 cx.notify();
327 })
328 .detach();
329
330 let load_credentials_task = Some(cx.spawn({
331 let state = state.clone();
332 |this, mut cx| async move {
333 if let Some(task) = state
334 .update(&mut cx, |state, cx| state.authenticate(cx))
335 .log_err()
336 {
337 // We don't log an error, because "not signed in" is also an error.
338 let _ = task.await;
339 }
340 this.update(&mut cx, |this, cx| {
341 this.load_credentials_task = None;
342 cx.notify();
343 })
344 .log_err();
345 }
346 }));
347
348 Self {
349 api_key_editor: cx.new_view(|cx| {
350 let mut editor = Editor::single_line(cx);
351 editor.set_placeholder_text("AIzaSy...", cx);
352 editor
353 }),
354 state,
355 load_credentials_task,
356 }
357 }
358
359 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
360 let api_key = self.api_key_editor.read(cx).text(cx);
361 if api_key.is_empty() {
362 return;
363 }
364
365 let state = self.state.clone();
366 cx.spawn(|_, mut cx| async move {
367 state
368 .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
369 .await
370 })
371 .detach_and_log_err(cx);
372
373 cx.notify();
374 }
375
376 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
377 self.api_key_editor
378 .update(cx, |editor, cx| editor.set_text("", cx));
379
380 let state = self.state.clone();
381 cx.spawn(|_, mut cx| async move {
382 state
383 .update(&mut cx, |state, cx| state.reset_api_key(cx))?
384 .await
385 })
386 .detach_and_log_err(cx);
387
388 cx.notify();
389 }
390
391 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
392 let settings = ThemeSettings::get_global(cx);
393 let text_style = TextStyle {
394 color: cx.theme().colors().text,
395 font_family: settings.ui_font.family.clone(),
396 font_features: settings.ui_font.features.clone(),
397 font_fallbacks: settings.ui_font.fallbacks.clone(),
398 font_size: rems(0.875).into(),
399 font_weight: settings.ui_font.weight,
400 font_style: FontStyle::Normal,
401 line_height: relative(1.3),
402 background_color: None,
403 underline: None,
404 strikethrough: None,
405 white_space: WhiteSpace::Normal,
406 };
407 EditorElement::new(
408 &self.api_key_editor,
409 EditorStyle {
410 background: cx.theme().colors().editor_background,
411 local_player: cx.theme().players().local(),
412 text: text_style,
413 ..Default::default()
414 },
415 )
416 }
417
418 fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
419 !self.state.read(cx).is_authenticated()
420 }
421}
422
423impl Render for ConfigurationView {
424 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
425 const INSTRUCTIONS: [&str; 4] = [
426 "To use the Google AI assistant, you need to add your Google AI API key.",
427 "You can create an API key at: https://makersuite.google.com/app/apikey",
428 "",
429 "Paste your Google AI API key below and hit enter to use the assistant:",
430 ];
431
432 let env_var_set = self.state.read(cx).api_key_from_env;
433
434 if self.load_credentials_task.is_some() {
435 div().child(Label::new("Loading credentials...")).into_any()
436 } else if self.should_render_editor(cx) {
437 v_flex()
438 .size_full()
439 .on_action(cx.listener(Self::save_api_key))
440 .children(
441 INSTRUCTIONS.map(|instruction| Label::new(instruction)),
442 )
443 .child(
444 h_flex()
445 .w_full()
446 .my_2()
447 .px_2()
448 .py_1()
449 .bg(cx.theme().colors().editor_background)
450 .rounded_md()
451 .child(self.render_api_key_editor(cx)),
452 )
453 .child(
454 Label::new(
455 format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."),
456 )
457 .size(LabelSize::Small),
458 )
459 .into_any()
460 } else {
461 h_flex()
462 .size_full()
463 .justify_between()
464 .child(
465 h_flex()
466 .gap_1()
467 .child(Icon::new(IconName::Check).color(Color::Success))
468 .child(Label::new(if env_var_set {
469 format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.")
470 } else {
471 "API key configured.".to_string()
472 })),
473 )
474 .child(
475 Button::new("reset-key", "Reset key")
476 .icon(Some(IconName::Trash))
477 .icon_size(IconSize::Small)
478 .icon_position(IconPosition::Start)
479 .disabled(env_var_set)
480 .when(env_var_set, |this| {
481 this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable."), cx))
482 })
483 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
484 )
485 .into_any()
486 }
487 }
488}