1use anyhow::{Context as _, Result};
2use collections::BTreeMap;
3use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
4use google_ai::{
5 FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
6 ThinkingConfig, UsageMetadata,
7};
8use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
9use http_client::HttpClient;
10use language_model::{
11 AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
12 LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
13 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
14};
15use language_model::{
16 IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
17 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
18 LanguageModelRequest, RateLimiter, Role,
19};
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22pub use settings::GoogleAvailableModel as AvailableModel;
23use settings::{Settings, SettingsStore};
24use std::pin::Pin;
25use std::sync::{
26 Arc, LazyLock,
27 atomic::{self, AtomicU64},
28};
29use strum::IntoEnumIterator;
30use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
31use ui_input::InputField;
32use util::ResultExt;
33
34use language_model::ApiKeyState;
35
/// Provider identity constants, shared with the rest of the `language_model` crate.
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
38
/// Settings for the Google provider, read from the global language-model settings.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base API URL; an empty string means "use `google_ai::API_URL`"
    /// (see `GoogleLanguageModelProvider::api_url`).
    pub api_url: String,
    /// User-declared models that override or extend the built-in model list
    /// (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
44
/// Whether a model runs in plain completion mode or with extended "thinking".
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion mode with no reasoning budget.
    #[default]
    Default,
    /// Thinking/reasoning mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
55
/// The Google AI (Gemini) language-model provider.
pub struct GoogleLanguageModelProvider {
    /// HTTP client used for all API requests.
    http_client: Arc<dyn HttpClient>,
    /// Shared, observable provider state (API key management).
    state: Entity<State>,
}
60
/// Observable provider state: tracks the API key for the configured API URL.
pub struct State {
    // Handles keychain storage, env-var fallback, and URL-change invalidation.
    api_key_state: ApiKeyState,
}
64
/// Primary environment variable checked for the API key.
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
/// Alternate environment variable, used as a fallback.
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
72
impl State {
    /// True once an API key is available (stored credential or env var).
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores the key for the current API URL, or clears it when `api_key`
    /// is `None`. Persistence is delegated to `ApiKeyState`.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    /// Loads credentials for the current API URL if not already loaded.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
    }
}
90
impl GoogleLanguageModelProvider {
    /// Creates the provider and its observable `State`. Settings changes are
    /// observed so that an API-URL change re-resolves which key to use.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state
                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
            }
        });

        Self { http_client, state }
    }

    /// Wraps a `google_ai::Model` in the `LanguageModel` trait object used by
    /// the rest of the app, sharing this provider's state and HTTP client.
    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            // Allow up to 4 concurrent requests per model.
            request_limiter: RateLimiter::new(4),
        })
    }

    /// The Google section of the global language-model settings.
    fn settings(cx: &App) -> &GoogleSettings {
        &crate::AllLanguageModelSettings::get_global(cx).google
    }

    /// Effective API URL: the configured one, or the google_ai default when
    /// the setting is empty.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            google_ai::API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}
132
impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider state so the app can observe authentication changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
140
141impl LanguageModelProvider for GoogleLanguageModelProvider {
142 fn id(&self) -> LanguageModelProviderId {
143 PROVIDER_ID
144 }
145
146 fn name(&self) -> LanguageModelProviderName {
147 PROVIDER_NAME
148 }
149
150 fn icon(&self) -> IconOrSvg {
151 IconOrSvg::Icon(IconName::AiGoogle)
152 }
153
154 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
155 Some(self.create_language_model(google_ai::Model::default()))
156 }
157
158 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
159 Some(self.create_language_model(google_ai::Model::default_fast()))
160 }
161
162 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
163 let mut models = BTreeMap::default();
164
165 // Add base models from google_ai::Model::iter()
166 for model in google_ai::Model::iter() {
167 if !matches!(model, google_ai::Model::Custom { .. }) {
168 models.insert(model.id().to_string(), model);
169 }
170 }
171
172 // Override with available models from settings
173 for model in &GoogleLanguageModelProvider::settings(cx).available_models {
174 models.insert(
175 model.name.clone(),
176 google_ai::Model::Custom {
177 name: model.name.clone(),
178 display_name: model.display_name.clone(),
179 max_tokens: model.max_tokens,
180 mode: model.mode.unwrap_or_default(),
181 },
182 );
183 }
184
185 models
186 .into_values()
187 .map(|model| {
188 Arc::new(GoogleLanguageModel {
189 id: LanguageModelId::from(model.id().to_string()),
190 model,
191 state: self.state.clone(),
192 http_client: self.http_client.clone(),
193 request_limiter: RateLimiter::new(4),
194 }) as Arc<dyn LanguageModel>
195 })
196 .collect()
197 }
198
199 fn is_authenticated(&self, cx: &App) -> bool {
200 self.state.read(cx).is_authenticated()
201 }
202
203 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
204 self.state.update(cx, |state, cx| state.authenticate(cx))
205 }
206
207 fn configuration_view(
208 &self,
209 target_agent: language_model::ConfigurationViewTargetAgent,
210 window: &mut Window,
211 cx: &mut App,
212 ) -> AnyView {
213 cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
214 .into()
215 }
216
217 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
218 self.state
219 .update(cx, |state, cx| state.set_api_key(None, cx))
220 }
221}
222
/// A single Google model exposed through the `LanguageModel` trait.
pub struct GoogleLanguageModel {
    /// Stable id derived from the underlying model's id.
    id: LanguageModelId,
    /// The concrete Google model (built-in variant or custom from settings).
    model: google_ai::Model,
    /// Shared provider state (API key).
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
230
impl GoogleLanguageModel {
    /// Issues a streaming `generateContent` request, resolving the API key and
    /// URL from provider state before entering the async body.
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        // NOTE(review): assumes `read_with` yields the closure's tuple directly
        // here — confirm against the gpui `AsyncApp` API.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        async move {
            // Fail fast if no key is configured for this URL.
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
260
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    /// Thinking is supported only when the model is configured in thinking mode.
    fn supports_thinking(&self) -> bool {
        matches!(self.model.mode(), GoogleModelMode::Thinking { .. })
    }

    /// All tool-choice modes map onto Google's function-calling config.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Counts tokens via Google's `countTokens` endpoint using the same
    /// converted request that would be sent for completion.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Streams a completion: converts the request, applies the rate limiter,
    /// and maps raw Google responses to completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        // Calls the inherent (non-trait) stream_completion defined above.
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
373
/// Converts Zed's provider-agnostic `LanguageModelRequest` into the Google
/// Generative Language API request shape.
///
/// A leading `System` message becomes `system_instruction`; remaining messages
/// map to `contents`, dropping any message whose mapped parts come out empty.
/// Tool declarations and tool-choice are translated when present.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    /// Maps message content items to Google `Part`s, skipping anything with no
    /// Google representation (empty text, redacted thinking, empty signatures).
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Thinking is round-tripped only via its signature; the
                // thought text itself is not sent back.
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        // Image results become a text placeholder response plus
                        // the image bytes as a separate inline-data part.
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Google takes the system prompt as a dedicated `system_instruction`
    // rather than as a chat message, so peel off a leading System message.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                // The API rejects content entries with no parts.
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default the temperature to 1.0 when the request doesn't set one.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Forward a thinking budget only when the request allows thinking
            // AND the model is configured in thinking mode.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
542
/// Stateful translator from raw Google stream responses to Zed completion
/// events; accumulates usage and tracks the final stop reason.
pub struct GoogleEventMapper {
    // Running usage totals, merged across stream chunks.
    usage: UsageMetadata,
    // Last stop reason observed; emitted when the stream ends.
    stop_reason: StopReason,
}
547
impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Adapts the raw response stream into completion events, appending a
    /// final `Stop` event (with the last observed stop reason) once the
    /// underlying stream ends.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // Sentinel `None` marks end-of-stream so we can emit the Stop event.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Translates one `GenerateContentResponse` chunk into zero or more
    /// completion events, updating accumulated usage and the stop reason.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Mints process-unique tool-use ids, since Google responses carry no
        // call ids of their own.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A prompt-level block ends the turn immediately with a refusal.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        // Inbound inline data and function responses have no
                        // completion-event equivalent; ignored.
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
672
673pub fn count_google_tokens(
674 request: LanguageModelRequest,
675 cx: &App,
676) -> BoxFuture<'static, Result<u64>> {
677 // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
678 // So we have to use tokenizer from tiktoken_rs to count tokens.
679 cx.background_spawn(async move {
680 let messages = request
681 .messages
682 .into_iter()
683 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
684 role: match message.role {
685 Role::User => "user".into(),
686 Role::Assistant => "assistant".into(),
687 Role::System => "system".into(),
688 },
689 content: Some(message.string_contents()),
690 name: None,
691 function_call: None,
692 })
693 .collect::<Vec<_>>();
694
695 // Tiktoken doesn't yet support these models, so we manually use the
696 // same tokenizer as GPT-4.
697 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
698 })
699 .boxed()
700}
701
702fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
703 if let Some(prompt_token_count) = new.prompt_token_count {
704 usage.prompt_token_count = Some(prompt_token_count);
705 }
706 if let Some(cached_content_token_count) = new.cached_content_token_count {
707 usage.cached_content_token_count = Some(cached_content_token_count);
708 }
709 if let Some(candidates_token_count) = new.candidates_token_count {
710 usage.candidates_token_count = Some(candidates_token_count);
711 }
712 if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
713 usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
714 }
715 if let Some(thoughts_token_count) = new.thoughts_token_count {
716 usage.thoughts_token_count = Some(thoughts_token_count);
717 }
718 if let Some(total_token_count) = new.total_token_count {
719 usage.total_token_count = Some(total_token_count);
720 }
721}
722
723fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
724 let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
725 let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
726 let input_tokens = prompt_tokens - cached_tokens;
727 let output_tokens = usage.candidates_token_count.unwrap_or(0);
728
729 language_model::TokenUsage {
730 input_tokens,
731 output_tokens,
732 cache_read_input_tokens: cached_tokens,
733 cache_creation_input_tokens: 0,
734 }
735}
736
/// Settings-panel view for entering, resetting, and displaying the API key.
struct ConfigurationView {
    /// Text field where the user pastes their API key.
    api_key_editor: Entity<InputField>,
    /// Shared provider state.
    state: Entity<State>,
    /// Which agent the configuration copy should address.
    target_agent: language_model::ConfigurationViewTargetAgent,
    // Some while credentials are still loading; render shows a placeholder.
    load_credentials_task: Option<Task<()>>,
}
743
744impl ConfigurationView {
745 fn new(
746 state: Entity<State>,
747 target_agent: language_model::ConfigurationViewTargetAgent,
748 window: &mut Window,
749 cx: &mut Context<Self>,
750 ) -> Self {
751 cx.observe(&state, |_, _, cx| {
752 cx.notify();
753 })
754 .detach();
755
756 let load_credentials_task = Some(cx.spawn_in(window, {
757 let state = state.clone();
758 async move |this, cx| {
759 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
760 // We don't log an error, because "not signed in" is also an error.
761 let _ = task.await;
762 }
763 this.update(cx, |this, cx| {
764 this.load_credentials_task = None;
765 cx.notify();
766 })
767 .log_err();
768 }
769 }));
770
771 Self {
772 api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
773 target_agent,
774 state,
775 load_credentials_task,
776 }
777 }
778
779 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
780 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
781 if api_key.is_empty() {
782 return;
783 }
784
785 // url changes can cause the editor to be displayed again
786 self.api_key_editor
787 .update(cx, |editor, cx| editor.set_text("", window, cx));
788
789 let state = self.state.clone();
790 cx.spawn_in(window, async move |_, cx| {
791 state
792 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
793 .await
794 })
795 .detach_and_log_err(cx);
796 }
797
798 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
799 self.api_key_editor
800 .update(cx, |editor, cx| editor.set_text("", window, cx));
801
802 let state = self.state.clone();
803 cx.spawn_in(window, async move |_, cx| {
804 state
805 .update(cx, |state, cx| state.set_api_key(None, cx))
806 .await
807 })
808 .detach_and_log_err(cx);
809 }
810
811 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
812 !self.state.read(cx).is_authenticated()
813 }
814}
815
impl Render for ConfigurationView {
    // Renders one of three states: a loading placeholder, the key-entry form,
    // or the "configured" card with a reset button.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        // Label for the configured card; mentions the env var or a custom URL
        // when either applies.
        let configured_card_label = if env_var_set {
            format!(
                "API key set in {} environment variable",
                API_KEY_ENV_VAR.name
            )
        } else {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            if api_url == google_ai::API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div()
                .child(Label::new("Loading credentials..."))
                .into_any_element()
        } else if self.should_render_editor(cx) {
            // Unauthenticated: instructions + key input, saved on Confirm.
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
                })))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
                        )
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also set the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Authenticated: reset is disabled while the key comes from an
            // environment variable (we can't unset that from the UI).
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
                })
                .into_any_element()
        }
    }
}
875
876#[cfg(test)]
877mod tests {
878 use super::*;
879 use google_ai::{
880 Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
881 Part, Role as GoogleRole, TextPart,
882 };
883 use language_model::{LanguageModelToolUseId, MessageContent, Role};
884 use serde_json::json;
885
886 #[test]
887 fn test_function_call_with_signature_creates_tool_use_with_signature() {
888 let mut mapper = GoogleEventMapper::new();
889
890 let response = GenerateContentResponse {
891 candidates: Some(vec![GenerateContentCandidate {
892 index: Some(0),
893 content: Content {
894 parts: vec![Part::FunctionCallPart(FunctionCallPart {
895 function_call: FunctionCall {
896 name: "test_function".to_string(),
897 args: json!({"arg": "value"}),
898 },
899 thought_signature: Some("test_signature_123".to_string()),
900 })],
901 role: GoogleRole::Model,
902 },
903 finish_reason: None,
904 finish_message: None,
905 safety_ratings: None,
906 citation_metadata: None,
907 }]),
908 prompt_feedback: None,
909 usage_metadata: None,
910 };
911
912 let events = mapper.map_event(response);
913
914 assert_eq!(events.len(), 2); // ToolUse event + Stop event
915
916 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
917 assert_eq!(tool_use.name.as_ref(), "test_function");
918 assert_eq!(
919 tool_use.thought_signature.as_deref(),
920 Some("test_signature_123")
921 );
922 } else {
923 panic!("Expected ToolUse event");
924 }
925 }
926
927 #[test]
928 fn test_function_call_without_signature_has_none() {
929 let mut mapper = GoogleEventMapper::new();
930
931 let response = GenerateContentResponse {
932 candidates: Some(vec![GenerateContentCandidate {
933 index: Some(0),
934 content: Content {
935 parts: vec![Part::FunctionCallPart(FunctionCallPart {
936 function_call: FunctionCall {
937 name: "test_function".to_string(),
938 args: json!({"arg": "value"}),
939 },
940 thought_signature: None,
941 })],
942 role: GoogleRole::Model,
943 },
944 finish_reason: None,
945 finish_message: None,
946 safety_ratings: None,
947 citation_metadata: None,
948 }]),
949 prompt_feedback: None,
950 usage_metadata: None,
951 };
952
953 let events = mapper.map_event(response);
954
955 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
956 assert_eq!(tool_use.thought_signature, None);
957 } else {
958 panic!("Expected ToolUse event");
959 }
960 }
961
962 #[test]
963 fn test_empty_string_signature_normalized_to_none() {
964 let mut mapper = GoogleEventMapper::new();
965
966 let response = GenerateContentResponse {
967 candidates: Some(vec![GenerateContentCandidate {
968 index: Some(0),
969 content: Content {
970 parts: vec![Part::FunctionCallPart(FunctionCallPart {
971 function_call: FunctionCall {
972 name: "test_function".to_string(),
973 args: json!({"arg": "value"}),
974 },
975 thought_signature: Some("".to_string()),
976 })],
977 role: GoogleRole::Model,
978 },
979 finish_reason: None,
980 finish_message: None,
981 safety_ratings: None,
982 citation_metadata: None,
983 }]),
984 prompt_feedback: None,
985 usage_metadata: None,
986 };
987
988 let events = mapper.map_event(response);
989
990 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
991 assert_eq!(tool_use.thought_signature, None);
992 } else {
993 panic!("Expected ToolUse event");
994 }
995 }
996
997 #[test]
998 fn test_parallel_function_calls_preserve_signatures() {
999 let mut mapper = GoogleEventMapper::new();
1000
1001 let response = GenerateContentResponse {
1002 candidates: Some(vec![GenerateContentCandidate {
1003 index: Some(0),
1004 content: Content {
1005 parts: vec![
1006 Part::FunctionCallPart(FunctionCallPart {
1007 function_call: FunctionCall {
1008 name: "function_1".to_string(),
1009 args: json!({"arg": "value1"}),
1010 },
1011 thought_signature: Some("signature_1".to_string()),
1012 }),
1013 Part::FunctionCallPart(FunctionCallPart {
1014 function_call: FunctionCall {
1015 name: "function_2".to_string(),
1016 args: json!({"arg": "value2"}),
1017 },
1018 thought_signature: None,
1019 }),
1020 ],
1021 role: GoogleRole::Model,
1022 },
1023 finish_reason: None,
1024 finish_message: None,
1025 safety_ratings: None,
1026 citation_metadata: None,
1027 }]),
1028 prompt_feedback: None,
1029 usage_metadata: None,
1030 };
1031
1032 let events = mapper.map_event(response);
1033
1034 assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
1035
1036 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1037 assert_eq!(tool_use.name.as_ref(), "function_1");
1038 assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
1039 } else {
1040 panic!("Expected ToolUse event for function_1");
1041 }
1042
1043 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1044 assert_eq!(tool_use.name.as_ref(), "function_2");
1045 assert_eq!(tool_use.thought_signature, None);
1046 } else {
1047 panic!("Expected ToolUse event for function_2");
1048 }
1049 }
1050
1051 #[test]
1052 fn test_tool_use_with_signature_converts_to_function_call_part() {
1053 let tool_use = language_model::LanguageModelToolUse {
1054 id: LanguageModelToolUseId::from("test_id"),
1055 name: "test_function".into(),
1056 raw_input: json!({"arg": "value"}).to_string(),
1057 input: json!({"arg": "value"}),
1058 is_input_complete: true,
1059 thought_signature: Some("test_signature_456".to_string()),
1060 };
1061
1062 let request = super::into_google(
1063 LanguageModelRequest {
1064 messages: vec![language_model::LanguageModelRequestMessage {
1065 role: Role::Assistant,
1066 content: vec![MessageContent::ToolUse(tool_use)],
1067 cache: false,
1068 reasoning_details: None,
1069 }],
1070 ..Default::default()
1071 },
1072 "gemini-2.5-flash".to_string(),
1073 GoogleModelMode::Default,
1074 );
1075
1076 assert_eq!(request.contents[0].parts.len(), 1);
1077 if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1078 assert_eq!(fc_part.function_call.name, "test_function");
1079 assert_eq!(
1080 fc_part.thought_signature.as_deref(),
1081 Some("test_signature_456")
1082 );
1083 } else {
1084 panic!("Expected FunctionCallPart");
1085 }
1086 }
1087
1088 #[test]
1089 fn test_tool_use_without_signature_omits_field() {
1090 let tool_use = language_model::LanguageModelToolUse {
1091 id: LanguageModelToolUseId::from("test_id"),
1092 name: "test_function".into(),
1093 raw_input: json!({"arg": "value"}).to_string(),
1094 input: json!({"arg": "value"}),
1095 is_input_complete: true,
1096 thought_signature: None,
1097 };
1098
1099 let request = super::into_google(
1100 LanguageModelRequest {
1101 messages: vec![language_model::LanguageModelRequestMessage {
1102 role: Role::Assistant,
1103 content: vec![MessageContent::ToolUse(tool_use)],
1104 cache: false,
1105 reasoning_details: None,
1106 }],
1107 ..Default::default()
1108 },
1109 "gemini-2.5-flash".to_string(),
1110 GoogleModelMode::Default,
1111 );
1112
1113 assert_eq!(request.contents[0].parts.len(), 1);
1114 if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1115 assert_eq!(fc_part.thought_signature, None);
1116 } else {
1117 panic!("Expected FunctionCallPart");
1118 }
1119 }
1120
1121 #[test]
1122 fn test_empty_signature_in_tool_use_normalized_to_none() {
1123 let tool_use = language_model::LanguageModelToolUse {
1124 id: LanguageModelToolUseId::from("test_id"),
1125 name: "test_function".into(),
1126 raw_input: json!({"arg": "value"}).to_string(),
1127 input: json!({"arg": "value"}),
1128 is_input_complete: true,
1129 thought_signature: Some("".to_string()),
1130 };
1131
1132 let request = super::into_google(
1133 LanguageModelRequest {
1134 messages: vec![language_model::LanguageModelRequestMessage {
1135 role: Role::Assistant,
1136 content: vec![MessageContent::ToolUse(tool_use)],
1137 cache: false,
1138 reasoning_details: None,
1139 }],
1140 ..Default::default()
1141 },
1142 "gemini-2.5-flash".to_string(),
1143 GoogleModelMode::Default,
1144 );
1145
1146 if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1147 assert_eq!(fc_part.thought_signature, None);
1148 } else {
1149 panic!("Expected FunctionCallPart");
1150 }
1151 }
1152
1153 #[test]
1154 fn test_round_trip_preserves_signature() {
1155 let mut mapper = GoogleEventMapper::new();
1156
1157 // Simulate receiving a response from Google with a signature
1158 let response = GenerateContentResponse {
1159 candidates: Some(vec![GenerateContentCandidate {
1160 index: Some(0),
1161 content: Content {
1162 parts: vec![Part::FunctionCallPart(FunctionCallPart {
1163 function_call: FunctionCall {
1164 name: "test_function".to_string(),
1165 args: json!({"arg": "value"}),
1166 },
1167 thought_signature: Some("round_trip_sig".to_string()),
1168 })],
1169 role: GoogleRole::Model,
1170 },
1171 finish_reason: None,
1172 finish_message: None,
1173 safety_ratings: None,
1174 citation_metadata: None,
1175 }]),
1176 prompt_feedback: None,
1177 usage_metadata: None,
1178 };
1179
1180 let events = mapper.map_event(response);
1181
1182 let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1183 tool_use.clone()
1184 } else {
1185 panic!("Expected ToolUse event");
1186 };
1187
1188 // Convert back to Google format
1189 let request = super::into_google(
1190 LanguageModelRequest {
1191 messages: vec![language_model::LanguageModelRequestMessage {
1192 role: Role::Assistant,
1193 content: vec![MessageContent::ToolUse(tool_use)],
1194 cache: false,
1195 reasoning_details: None,
1196 }],
1197 ..Default::default()
1198 },
1199 "gemini-2.5-flash".to_string(),
1200 GoogleModelMode::Default,
1201 );
1202
1203 // Verify signature is preserved
1204 if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1205 assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
1206 } else {
1207 panic!("Expected FunctionCallPart");
1208 }
1209 }
1210
1211 #[test]
1212 fn test_mixed_text_and_function_call_with_signature() {
1213 let mut mapper = GoogleEventMapper::new();
1214
1215 let response = GenerateContentResponse {
1216 candidates: Some(vec![GenerateContentCandidate {
1217 index: Some(0),
1218 content: Content {
1219 parts: vec![
1220 Part::TextPart(TextPart {
1221 text: "I'll help with that.".to_string(),
1222 }),
1223 Part::FunctionCallPart(FunctionCallPart {
1224 function_call: FunctionCall {
1225 name: "helper_function".to_string(),
1226 args: json!({"query": "help"}),
1227 },
1228 thought_signature: Some("mixed_sig".to_string()),
1229 }),
1230 ],
1231 role: GoogleRole::Model,
1232 },
1233 finish_reason: None,
1234 finish_message: None,
1235 safety_ratings: None,
1236 citation_metadata: None,
1237 }]),
1238 prompt_feedback: None,
1239 usage_metadata: None,
1240 };
1241
1242 let events = mapper.map_event(response);
1243
1244 assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
1245
1246 if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
1247 assert_eq!(text, "I'll help with that.");
1248 } else {
1249 panic!("Expected Text event");
1250 }
1251
1252 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1253 assert_eq!(tool_use.name.as_ref(), "helper_function");
1254 assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
1255 } else {
1256 panic!("Expected ToolUse event");
1257 }
1258 }
1259
1260 #[test]
1261 fn test_special_characters_in_signature_preserved() {
1262 let mut mapper = GoogleEventMapper::new();
1263
1264 let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
1265
1266 let response = GenerateContentResponse {
1267 candidates: Some(vec![GenerateContentCandidate {
1268 index: Some(0),
1269 content: Content {
1270 parts: vec![Part::FunctionCallPart(FunctionCallPart {
1271 function_call: FunctionCall {
1272 name: "test_function".to_string(),
1273 args: json!({"arg": "value"}),
1274 },
1275 thought_signature: Some(signature_with_special_chars.clone()),
1276 })],
1277 role: GoogleRole::Model,
1278 },
1279 finish_reason: None,
1280 finish_message: None,
1281 safety_ratings: None,
1282 citation_metadata: None,
1283 }]),
1284 prompt_feedback: None,
1285 usage_metadata: None,
1286 };
1287
1288 let events = mapper.map_event(response);
1289
1290 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1291 assert_eq!(
1292 tool_use.thought_signature.as_deref(),
1293 Some(signature_with_special_chars.as_str())
1294 );
1295 } else {
1296 panic!("Expected ToolUse event");
1297 }
1298 }
1299}