1use anyhow::{anyhow, Result};
2use collections::BTreeMap;
3use editor::{Editor, EditorElement, EditorStyle};
4use futures::{future::BoxFuture, FutureExt, StreamExt};
5use google_ai::stream_generate_content;
6use gpui::{
7 AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
8 View, WhiteSpace,
9};
10use http_client::HttpClient;
11use language_model::LanguageModelCompletionEvent;
12use language_model::{
13 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
14 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
15 LanguageModelRequest, RateLimiter,
16};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use settings::{Settings, SettingsStore};
20use std::{future, sync::Arc};
21use strum::IntoEnumIterator;
22use theme::ThemeSettings;
23use ui::{prelude::*, Icon, IconName, Tooltip};
24use util::ResultExt;
25
26use crate::AllLanguageModelSettings;
27
/// Identifier reported by this provider's `id()` and each model's `provider_id()`.
const PROVIDER_ID: &str = "google";
/// Display name reported by this provider's `name()` and each model's `provider_name()`.
const PROVIDER_NAME: &str = "Google AI";
30
/// Google AI provider settings, read via `AllLanguageModelSettings::get_global(cx).google`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base URL handed to the Google AI API calls. It is also the key under which the
    /// API key is written to / read from / deleted from the credential store.
    pub api_url: String,
    /// Extra models from user settings; in `provided_models` these are added to the
    /// built-in list and replace any built-in entry with the same id.
    pub available_models: Vec<AvailableModel>,
}

/// A user-configured model, turned into `google_ai::Model::Custom` by `provided_models`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model id sent to the API; also the map key used to override built-in models.
    name: String,
    /// Optional human-readable name for the model.
    display_name: Option<String>,
    /// Copied into `Custom::max_tokens`.
    max_tokens: usize,
}
43
/// Language-model provider backed by the Google AI API.
pub struct GoogleLanguageModelProvider {
    /// HTTP client shared with every model this provider hands out.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with models and configuration views.
    state: gpui::Model<State>,
}

/// Authentication state for the Google AI provider.
pub struct State {
    /// The API key in use; `None` means not authenticated (see `is_authenticated`).
    api_key: Option<String>,
    /// Whether `api_key` was read from the `GOOGLE_AI_API_KEY` environment variable
    /// (as opposed to the credential store).
    api_key_from_env: bool,
    /// Held only to keep the `SettingsStore` observer registered in
    /// `GoogleLanguageModelProvider::new` alive; never read.
    _subscription: Subscription,
}
54
/// Environment variable checked by `State::authenticate` before falling back to the
/// credential store.
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
56
57impl State {
58 fn is_authenticated(&self) -> bool {
59 self.api_key.is_some()
60 }
61
62 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
63 let delete_credentials =
64 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
65 cx.spawn(|this, mut cx| async move {
66 delete_credentials.await.ok();
67 this.update(&mut cx, |this, cx| {
68 this.api_key = None;
69 this.api_key_from_env = false;
70 cx.notify();
71 })
72 })
73 }
74
75 fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
76 let settings = &AllLanguageModelSettings::get_global(cx).google;
77 let write_credentials =
78 cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
79
80 cx.spawn(|this, mut cx| async move {
81 write_credentials.await?;
82 this.update(&mut cx, |this, cx| {
83 this.api_key = Some(api_key);
84 cx.notify();
85 })
86 })
87 }
88
89 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
90 if self.is_authenticated() {
91 Task::ready(Ok(()))
92 } else {
93 let api_url = AllLanguageModelSettings::get_global(cx)
94 .google
95 .api_url
96 .clone();
97
98 cx.spawn(|this, mut cx| async move {
99 let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR)
100 {
101 (api_key, true)
102 } else {
103 let (_, api_key) = cx
104 .update(|cx| cx.read_credentials(&api_url))?
105 .await?
106 .ok_or_else(|| anyhow!("credentials not found"))?;
107 (String::from_utf8(api_key)?, false)
108 };
109
110 this.update(&mut cx, |this, cx| {
111 this.api_key = Some(api_key);
112 this.api_key_from_env = from_env;
113 cx.notify();
114 })
115 })
116 }
117 }
118}
119
120impl GoogleLanguageModelProvider {
121 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
122 let state = cx.new_model(|cx| State {
123 api_key: None,
124 api_key_from_env: false,
125 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
126 cx.notify();
127 }),
128 });
129
130 Self { http_client, state }
131 }
132}
133
134impl LanguageModelProviderState for GoogleLanguageModelProvider {
135 type ObservableEntity = State;
136
137 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
138 Some(self.state.clone())
139 }
140}
141
142impl LanguageModelProvider for GoogleLanguageModelProvider {
143 fn id(&self) -> LanguageModelProviderId {
144 LanguageModelProviderId(PROVIDER_ID.into())
145 }
146
147 fn name(&self) -> LanguageModelProviderName {
148 LanguageModelProviderName(PROVIDER_NAME.into())
149 }
150
151 fn icon(&self) -> IconName {
152 IconName::AiGoogle
153 }
154
155 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
156 let mut models = BTreeMap::default();
157
158 // Add base models from google_ai::Model::iter()
159 for model in google_ai::Model::iter() {
160 if !matches!(model, google_ai::Model::Custom { .. }) {
161 models.insert(model.id().to_string(), model);
162 }
163 }
164
165 // Override with available models from settings
166 for model in &AllLanguageModelSettings::get_global(cx)
167 .google
168 .available_models
169 {
170 models.insert(
171 model.name.clone(),
172 google_ai::Model::Custom {
173 name: model.name.clone(),
174 display_name: model.display_name.clone(),
175 max_tokens: model.max_tokens,
176 },
177 );
178 }
179
180 models
181 .into_values()
182 .map(|model| {
183 Arc::new(GoogleLanguageModel {
184 id: LanguageModelId::from(model.id().to_string()),
185 model,
186 state: self.state.clone(),
187 http_client: self.http_client.clone(),
188 rate_limiter: RateLimiter::new(4),
189 }) as Arc<dyn LanguageModel>
190 })
191 .collect()
192 }
193
194 fn is_authenticated(&self, cx: &AppContext) -> bool {
195 self.state.read(cx).is_authenticated()
196 }
197
198 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
199 self.state.update(cx, |state, cx| state.authenticate(cx))
200 }
201
202 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
203 cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
204 .into()
205 }
206
207 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
208 let state = self.state.clone();
209 let delete_credentials =
210 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
211 cx.spawn(|mut cx| async move {
212 delete_credentials.await.log_err();
213 state.update(&mut cx, |this, cx| {
214 this.api_key = None;
215 cx.notify();
216 })
217 })
218 }
219}
220
/// A single Google AI model exposed through the `LanguageModel` trait.
pub struct GoogleLanguageModel {
    /// Built from `model.id()` in `provided_models`.
    id: LanguageModelId,
    model: google_ai::Model,
    /// Shared provider state; the API key is read from here at request time.
    state: gpui::Model<State>,
    http_client: Arc<dyn HttpClient>,
    /// Wraps `stream_completion` requests; created with `RateLimiter::new(4)`.
    rate_limiter: RateLimiter,
}
228
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    /// Human-readable name, taken from the model's display name.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Telemetry identifier of the form `google/<model id>`.
    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Asks the Google AI API how many tokens `request` uses.
    ///
    /// The API key and API URL are captured synchronously when this is called;
    /// the returned future fails with "Missing Google API key" if no key was set.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        let request = request.into_google(self.model.id().to_string());
        let http_client = self.http_client.clone();
        let api_key = self.state.read(cx).api_key.clone();

        let settings = &AllLanguageModelSettings::get_global(cx).google;
        let api_url = settings.api_url.clone();

        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    contents: request.contents,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Streams a completion for `request`, yielding each text chunk as
    /// `LanguageModelCompletionEvent::Text`.
    ///
    /// Returns an already-failed future ("App state dropped") if the state model
    /// can no longer be read. Otherwise the request runs inside
    /// `self.rate_limiter.stream(...)`; the missing-key check also happens inside
    /// that future, not up front.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
    > {
        let request = request.into_google(self.model.id().to_string());

        let http_client = self.http_client.clone();
        // Snapshot the key and URL now so the returned future is 'static.
        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).google;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.rate_limiter.stream(async move {
            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
            let response =
                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
            let events = response.await?;
            Ok(google_ai::extract_text_from_events(events).boxed())
        });
        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

    /// Tool use is not supported by this provider; always returns an error.
    fn use_any_tool(
        &self,
        _request: LanguageModelRequest,
        _name: String,
        _description: String,
        _schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
        future::ready(Err(anyhow!("not implemented"))).boxed()
    }
}
327
/// Settings UI for entering, viewing, and resetting the Google AI API key.
struct ConfigurationView {
    /// Single-line editor the user types/pastes the API key into.
    api_key_editor: View<Editor>,
    /// Shared provider state (same model the provider and its models use).
    state: gpui::Model<State>,
    /// `Some` while the initial credential load started in `new` is running;
    /// `render` shows a loading message during that time.
    load_credentials_task: Option<Task<()>>,
}
333
334impl ConfigurationView {
335 fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
336 cx.observe(&state, |_, _, cx| {
337 cx.notify();
338 })
339 .detach();
340
341 let load_credentials_task = Some(cx.spawn({
342 let state = state.clone();
343 |this, mut cx| async move {
344 if let Some(task) = state
345 .update(&mut cx, |state, cx| state.authenticate(cx))
346 .log_err()
347 {
348 // We don't log an error, because "not signed in" is also an error.
349 let _ = task.await;
350 }
351 this.update(&mut cx, |this, cx| {
352 this.load_credentials_task = None;
353 cx.notify();
354 })
355 .log_err();
356 }
357 }));
358
359 Self {
360 api_key_editor: cx.new_view(|cx| {
361 let mut editor = Editor::single_line(cx);
362 editor.set_placeholder_text("AIzaSy...", cx);
363 editor
364 }),
365 state,
366 load_credentials_task,
367 }
368 }
369
370 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
371 let api_key = self.api_key_editor.read(cx).text(cx);
372 if api_key.is_empty() {
373 return;
374 }
375
376 let state = self.state.clone();
377 cx.spawn(|_, mut cx| async move {
378 state
379 .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
380 .await
381 })
382 .detach_and_log_err(cx);
383
384 cx.notify();
385 }
386
387 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
388 self.api_key_editor
389 .update(cx, |editor, cx| editor.set_text("", cx));
390
391 let state = self.state.clone();
392 cx.spawn(|_, mut cx| async move {
393 state
394 .update(&mut cx, |state, cx| state.reset_api_key(cx))?
395 .await
396 })
397 .detach_and_log_err(cx);
398
399 cx.notify();
400 }
401
402 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
403 let settings = ThemeSettings::get_global(cx);
404 let text_style = TextStyle {
405 color: cx.theme().colors().text,
406 font_family: settings.ui_font.family.clone(),
407 font_features: settings.ui_font.features.clone(),
408 font_fallbacks: settings.ui_font.fallbacks.clone(),
409 font_size: rems(0.875).into(),
410 font_weight: settings.ui_font.weight,
411 font_style: FontStyle::Normal,
412 line_height: relative(1.3),
413 background_color: None,
414 underline: None,
415 strikethrough: None,
416 white_space: WhiteSpace::Normal,
417 truncate: None,
418 };
419 EditorElement::new(
420 &self.api_key_editor,
421 EditorStyle {
422 background: cx.theme().colors().editor_background,
423 local_player: cx.theme().players().local(),
424 text: text_style,
425 ..Default::default()
426 },
427 )
428 }
429
430 fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
431 !self.state.read(cx).is_authenticated()
432 }
433}
434
impl Render for ConfigurationView {
    /// Renders one of three states:
    /// - a "Loading credentials..." label while `load_credentials_task` is `Some`;
    /// - setup instructions plus the key editor when no key is configured;
    /// - otherwise a confirmation row with a "Reset key" button, which is disabled
    ///   (with an explanatory tooltip) when the key came from the env var.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const GOOGLE_CONSOLE_URL: &str = "https://aistudio.google.com/app/apikey";
        const INSTRUCTIONS: [&str; 3] = [
            "To use Zed's assistant with Google AI, you need to add an API key. Follow these steps:",
            "- Create one by visiting:",
            "- Paste your API key below and hit enter to use the assistant",
        ];

        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                // `menu::Confirm` (hitting enter) is routed to `save_api_key`.
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(INSTRUCTIONS[0]))
                .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
                    Button::new("google_console", GOOGLE_CONSOLE_URL)
                        .style(ButtonStyle::Subtle)
                        .icon(IconName::ExternalLink)
                        .icon_size(IconSize::XSmall)
                        .icon_color(Color::Muted)
                        .on_click(move |_, cx| cx.open_url(GOOGLE_CONSOLE_URL))
                    )
                )
                .child(Label::new(INSTRUCTIONS[2]))
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .rounded_md()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small),
                )
                .into_any()
        } else {
            h_flex()
                .size_full()
                .justify_between()
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    // Resetting here cannot remove an env-var key, so the button is
                    // disabled in that case and a tooltip explains what to do.
                    Button::new("reset-key", "Reset key")
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable."), cx))
                        })
                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
                )
                .into_any()
        }
    }
}