1use anyhow::{anyhow, Result};
2use collections::BTreeMap;
3use editor::{Editor, EditorElement, EditorStyle};
4use futures::{future::BoxFuture, FutureExt, StreamExt};
5use google_ai::stream_generate_content;
6use gpui::{
7 AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext,
8 Subscription, Task, TextStyle, View, WhiteSpace,
9};
10use http_client::HttpClient;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use settings::{Settings, SettingsStore};
14use std::{future, sync::Arc, time::Duration};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::{prelude::*, Indicator};
18use util::ResultExt;
19
20use crate::{
21 settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
22 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
23 LanguageModelProviderState, LanguageModelRequest, RateLimiter,
24};
25
/// Identifier returned by `id()` / `provider_id()` for this provider.
const PROVIDER_ID: &str = "google";
/// Name returned by `name()` / `provider_name()` for this provider.
const PROVIDER_NAME: &str = "Google AI";
28
29#[derive(Default, Clone, Debug, PartialEq)]
30pub struct GoogleSettings {
31 pub api_url: String,
32 pub low_speed_timeout: Option<Duration>,
33 pub available_models: Vec<AvailableModel>,
34}
35
36#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
37pub struct AvailableModel {
38 name: String,
39 max_tokens: usize,
40}
41
42pub struct GoogleLanguageModelProvider {
43 http_client: Arc<dyn HttpClient>,
44 state: gpui::Model<State>,
45}
46
47pub struct State {
48 api_key: Option<String>,
49 _subscription: Subscription,
50}
51
52impl State {
53 fn is_authenticated(&self) -> bool {
54 self.api_key.is_some()
55 }
56
57 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
58 let delete_credentials =
59 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
60 cx.spawn(|this, mut cx| async move {
61 delete_credentials.await.ok();
62 this.update(&mut cx, |this, cx| {
63 this.api_key = None;
64 cx.notify();
65 })
66 })
67 }
68}
69
70impl GoogleLanguageModelProvider {
71 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
72 let state = cx.new_model(|cx| State {
73 api_key: None,
74 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
75 cx.notify();
76 }),
77 });
78
79 Self { http_client, state }
80 }
81}
82
83impl LanguageModelProviderState for GoogleLanguageModelProvider {
84 type ObservableEntity = State;
85
86 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
87 Some(self.state.clone())
88 }
89}
90
91impl LanguageModelProvider for GoogleLanguageModelProvider {
92 fn id(&self) -> LanguageModelProviderId {
93 LanguageModelProviderId(PROVIDER_ID.into())
94 }
95
96 fn name(&self) -> LanguageModelProviderName {
97 LanguageModelProviderName(PROVIDER_NAME.into())
98 }
99
100 fn icon(&self) -> IconName {
101 IconName::AiGoogle
102 }
103
104 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
105 let mut models = BTreeMap::default();
106
107 // Add base models from google_ai::Model::iter()
108 for model in google_ai::Model::iter() {
109 if !matches!(model, google_ai::Model::Custom { .. }) {
110 models.insert(model.id().to_string(), model);
111 }
112 }
113
114 // Override with available models from settings
115 for model in &AllLanguageModelSettings::get_global(cx)
116 .google
117 .available_models
118 {
119 models.insert(
120 model.name.clone(),
121 google_ai::Model::Custom {
122 name: model.name.clone(),
123 max_tokens: model.max_tokens,
124 },
125 );
126 }
127
128 models
129 .into_values()
130 .map(|model| {
131 Arc::new(GoogleLanguageModel {
132 id: LanguageModelId::from(model.id().to_string()),
133 model,
134 state: self.state.clone(),
135 http_client: self.http_client.clone(),
136 rate_limiter: RateLimiter::new(4),
137 }) as Arc<dyn LanguageModel>
138 })
139 .collect()
140 }
141
142 fn is_authenticated(&self, cx: &AppContext) -> bool {
143 self.state.read(cx).is_authenticated()
144 }
145
146 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
147 if self.is_authenticated(cx) {
148 Task::ready(Ok(()))
149 } else {
150 let api_url = AllLanguageModelSettings::get_global(cx)
151 .google
152 .api_url
153 .clone();
154 let state = self.state.clone();
155 cx.spawn(|mut cx| async move {
156 let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
157 api_key
158 } else {
159 let (_, api_key) = cx
160 .update(|cx| cx.read_credentials(&api_url))?
161 .await?
162 .ok_or_else(|| anyhow!("credentials not found"))?;
163 String::from_utf8(api_key)?
164 };
165
166 state.update(&mut cx, |this, cx| {
167 this.api_key = Some(api_key);
168 cx.notify();
169 })
170 })
171 }
172 }
173
174 fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
175 let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx));
176
177 let focus_handle = view.focus_handle(cx);
178 (view.into(), Some(focus_handle))
179 }
180
181 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
182 let state = self.state.clone();
183 let delete_credentials =
184 cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
185 cx.spawn(|mut cx| async move {
186 delete_credentials.await.log_err();
187 state.update(&mut cx, |this, cx| {
188 this.api_key = None;
189 cx.notify();
190 })
191 })
192 }
193}
194
195pub struct GoogleLanguageModel {
196 id: LanguageModelId,
197 model: google_ai::Model,
198 state: gpui::Model<State>,
199 http_client: Arc<dyn HttpClient>,
200 rate_limiter: RateLimiter,
201}
202
203impl LanguageModel for GoogleLanguageModel {
204 fn id(&self) -> LanguageModelId {
205 self.id.clone()
206 }
207
208 fn name(&self) -> LanguageModelName {
209 LanguageModelName::from(self.model.display_name().to_string())
210 }
211
212 fn provider_id(&self) -> LanguageModelProviderId {
213 LanguageModelProviderId(PROVIDER_ID.into())
214 }
215
216 fn provider_name(&self) -> LanguageModelProviderName {
217 LanguageModelProviderName(PROVIDER_NAME.into())
218 }
219
220 fn telemetry_id(&self) -> String {
221 format!("google/{}", self.model.id())
222 }
223
224 fn max_token_count(&self) -> usize {
225 self.model.max_token_count()
226 }
227
228 fn count_tokens(
229 &self,
230 request: LanguageModelRequest,
231 cx: &AppContext,
232 ) -> BoxFuture<'static, Result<usize>> {
233 let request = request.into_google(self.model.id().to_string());
234 let http_client = self.http_client.clone();
235 let api_key = self.state.read(cx).api_key.clone();
236 let api_url = AllLanguageModelSettings::get_global(cx)
237 .google
238 .api_url
239 .clone();
240
241 async move {
242 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
243 let response = google_ai::count_tokens(
244 http_client.as_ref(),
245 &api_url,
246 &api_key,
247 google_ai::CountTokensRequest {
248 contents: request.contents,
249 },
250 )
251 .await?;
252 Ok(response.total_tokens)
253 }
254 .boxed()
255 }
256
257 fn stream_completion(
258 &self,
259 request: LanguageModelRequest,
260 cx: &AsyncAppContext,
261 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
262 let request = request.into_google(self.model.id().to_string());
263
264 let http_client = self.http_client.clone();
265 let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
266 let settings = &AllLanguageModelSettings::get_global(cx).google;
267 (state.api_key.clone(), settings.api_url.clone())
268 }) else {
269 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
270 };
271
272 let future = self.rate_limiter.stream(async move {
273 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
274 let response =
275 stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
276 let events = response.await?;
277 Ok(google_ai::extract_text_from_events(events).boxed())
278 });
279 async move { Ok(future.await?.boxed()) }.boxed()
280 }
281
282 fn use_any_tool(
283 &self,
284 _request: LanguageModelRequest,
285 _name: String,
286 _description: String,
287 _schema: serde_json::Value,
288 _cx: &AsyncAppContext,
289 ) -> BoxFuture<'static, Result<serde_json::Value>> {
290 future::ready(Err(anyhow!("not implemented"))).boxed()
291 }
292}
293
294struct ConfigurationView {
295 focus_handle: FocusHandle,
296 api_key_editor: View<Editor>,
297 state: gpui::Model<State>,
298}
299
300impl ConfigurationView {
301 fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
302 let focus_handle = cx.focus_handle();
303
304 cx.on_focus(&focus_handle, |this, cx| {
305 if this.should_render_editor(cx) {
306 this.api_key_editor.read(cx).focus_handle(cx).focus(cx)
307 }
308 })
309 .detach();
310
311 Self {
312 api_key_editor: cx.new_view(|cx| {
313 let mut editor = Editor::single_line(cx);
314 editor.set_placeholder_text("AIzaSy...", cx);
315 editor
316 }),
317 state,
318 focus_handle,
319 }
320 }
321
322 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
323 let api_key = self.api_key_editor.read(cx).text(cx);
324 if api_key.is_empty() {
325 return;
326 }
327
328 let settings = &AllLanguageModelSettings::get_global(cx).google;
329 let write_credentials =
330 cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
331 let state = self.state.clone();
332 cx.spawn(|_, mut cx| async move {
333 write_credentials.await?;
334 state.update(&mut cx, |this, cx| {
335 this.api_key = Some(api_key);
336 cx.notify();
337 })
338 })
339 .detach_and_log_err(cx);
340 }
341
342 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
343 self.api_key_editor
344 .update(cx, |editor, cx| editor.set_text("", cx));
345 self.state
346 .update(cx, |state, cx| state.reset_api_key(cx))
347 .detach_and_log_err(cx);
348 }
349
350 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
351 let settings = ThemeSettings::get_global(cx);
352 let text_style = TextStyle {
353 color: cx.theme().colors().text,
354 font_family: settings.ui_font.family.clone(),
355 font_features: settings.ui_font.features.clone(),
356 font_fallbacks: settings.ui_font.fallbacks.clone(),
357 font_size: rems(0.875).into(),
358 font_weight: settings.ui_font.weight,
359 font_style: FontStyle::Normal,
360 line_height: relative(1.3),
361 background_color: None,
362 underline: None,
363 strikethrough: None,
364 white_space: WhiteSpace::Normal,
365 };
366 EditorElement::new(
367 &self.api_key_editor,
368 EditorStyle {
369 background: cx.theme().colors().editor_background,
370 local_player: cx.theme().players().local(),
371 text: text_style,
372 ..Default::default()
373 },
374 )
375 }
376
377 fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
378 !self.state.read(cx).is_authenticated()
379 }
380}
381
382impl FocusableView for ConfigurationView {
383 fn focus_handle(&self, _cx: &AppContext) -> FocusHandle {
384 self.focus_handle.clone()
385 }
386}
387
388impl Render for ConfigurationView {
389 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
390 const INSTRUCTIONS: [&str; 4] = [
391 "To use the Google AI assistant, you need to add your Google AI API key.",
392 "You can create an API key at: https://makersuite.google.com/app/apikey",
393 "",
394 "Paste your Google AI API key below and hit enter to use the assistant:",
395 ];
396
397 if self.should_render_editor(cx) {
398 v_flex()
399 .id("google-ai-configuration-view")
400 .track_focus(&self.focus_handle)
401 .size_full()
402 .on_action(cx.listener(Self::save_api_key))
403 .children(
404 INSTRUCTIONS.map(|instruction| Label::new(instruction)),
405 )
406 .child(
407 h_flex()
408 .w_full()
409 .my_2()
410 .px_2()
411 .py_1()
412 .bg(cx.theme().colors().editor_background)
413 .rounded_md()
414 .child(self.render_api_key_editor(cx)),
415 )
416 .child(
417 Label::new(
418 "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
419 )
420 .size(LabelSize::Small),
421 )
422 .into_any()
423 } else {
424 h_flex()
425 .id("google-ai-configuration-view")
426 .track_focus(&self.focus_handle)
427 .size_full()
428 .justify_between()
429 .child(
430 h_flex()
431 .gap_2()
432 .child(Indicator::dot().color(Color::Success))
433 .child(Label::new("API Key configured").size(LabelSize::Small)),
434 )
435 .child(
436 Button::new("reset-key", "Reset key")
437 .icon(Some(IconName::Trash))
438 .icon_size(IconSize::Small)
439 .icon_position(IconPosition::Start)
440 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
441 )
442 .into_any()
443 }
444 }
445}