google.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use collections::BTreeMap;
   3use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
   4use google_ai::{
   5    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
   6    ThinkingConfig, UsageMetadata,
   7};
   8use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
  12    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
  13    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
  14};
  15use language_model::{
  16    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  17    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  18    LanguageModelRequest, RateLimiter, Role,
  19};
  20use schemars::JsonSchema;
  21use serde::{Deserialize, Serialize};
  22pub use settings::GoogleAvailableModel as AvailableModel;
  23use settings::{Settings, SettingsStore};
  24use std::pin::Pin;
  25use std::sync::{
  26    Arc, LazyLock,
  27    atomic::{self, AtomicU64},
  28};
  29use strum::IntoEnumIterator;
  30use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  31use ui_input::InputField;
  32use util::ResultExt;
  33
  34use language_model::ApiKeyState;
  35
/// Stable identifier and display name for this provider, shared app-wide.
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;

/// User-configurable settings for the Google provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Override for the Google AI API base URL; an empty string means the
    /// built-in default (`google_ai::API_URL`) is used.
    pub api_url: String,
    /// Additional models declared in settings; these override built-in models
    /// that share the same id.
    pub available_models: Vec<AvailableModel>,
}
  44
/// Reasoning mode for a model, as configured in settings.
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard generation with no explicit thinking configuration.
    #[default]
    Default,
    /// Generation with model "thinking" enabled.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  55
/// Language-model provider backed by the Google AI (Gemini) API.
pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared, observable authentication state (API key) for all models
    /// created by this provider.
    state: Entity<State>,
}

/// Provider-wide state: the stored/loaded API key tied to the current API URL.
pub struct State {
    api_key_state: ApiKeyState,
}

const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
  72
impl State {
    /// True when an API key is currently available.
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores the given API key (or clears it, when `None`) for the currently
    /// configured API URL.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    /// Loads the API key for the current API URL if it hasn't been loaded yet.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
    }
}
  90
impl GoogleLanguageModelProvider {
    /// Creates the provider and wires its key state to settings changes:
    /// when the configured API URL changes, the loaded key is re-resolved.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state
                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
            }
        });

        Self { http_client, state }
    }

    /// Single construction path for a `GoogleLanguageModel`, sharing this
    /// provider's state and HTTP client.
    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    /// Returns the Google section of the global language-model settings.
    fn settings(cx: &App) -> &GoogleSettings {
        &crate::AllLanguageModelSettings::get_global(cx).google
    }

    /// Resolves the effective API URL: the settings override when non-empty,
    /// otherwise the built-in `google_ai::API_URL`.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            google_ai::API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}
 132
impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so the UI can observe auth changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 140
 141impl LanguageModelProvider for GoogleLanguageModelProvider {
 142    fn id(&self) -> LanguageModelProviderId {
 143        PROVIDER_ID
 144    }
 145
 146    fn name(&self) -> LanguageModelProviderName {
 147        PROVIDER_NAME
 148    }
 149
 150    fn icon(&self) -> IconName {
 151        IconName::AiGoogle
 152    }
 153
 154    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 155        Some(self.create_language_model(google_ai::Model::default()))
 156    }
 157
 158    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 159        Some(self.create_language_model(google_ai::Model::default_fast()))
 160    }
 161
 162    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 163        let mut models = BTreeMap::default();
 164
 165        // Add base models from google_ai::Model::iter()
 166        for model in google_ai::Model::iter() {
 167            if !matches!(model, google_ai::Model::Custom { .. }) {
 168                models.insert(model.id().to_string(), model);
 169            }
 170        }
 171
 172        // Override with available models from settings
 173        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
 174            models.insert(
 175                model.name.clone(),
 176                google_ai::Model::Custom {
 177                    name: model.name.clone(),
 178                    display_name: model.display_name.clone(),
 179                    max_tokens: model.max_tokens,
 180                    mode: model.mode.unwrap_or_default(),
 181                },
 182            );
 183        }
 184
 185        models
 186            .into_values()
 187            .map(|model| {
 188                Arc::new(GoogleLanguageModel {
 189                    id: LanguageModelId::from(model.id().to_string()),
 190                    model,
 191                    state: self.state.clone(),
 192                    http_client: self.http_client.clone(),
 193                    request_limiter: RateLimiter::new(4),
 194                }) as Arc<dyn LanguageModel>
 195            })
 196            .collect()
 197    }
 198
 199    fn is_authenticated(&self, cx: &App) -> bool {
 200        self.state.read(cx).is_authenticated()
 201    }
 202
 203    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 204        self.state.update(cx, |state, cx| state.authenticate(cx))
 205    }
 206
 207    fn configuration_view(
 208        &self,
 209        target_agent: language_model::ConfigurationViewTargetAgent,
 210        window: &mut Window,
 211        cx: &mut App,
 212    ) -> AnyView {
 213        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 214            .into()
 215    }
 216
 217    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 218        self.state
 219            .update(cx, |state, cx| state.set_api_key(None, cx))
 220    }
 221}
 222
/// A single Google (Gemini) model instance exposed to the app.
pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    /// Shared provider state; consulted for the API key on each request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent in-flight requests (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
 230
impl GoogleLanguageModel {
    /// Starts a streaming `generateContent` request against the configured
    /// endpoint, resolving the API key from provider state first.
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        // Read key and URL synchronously; if the app state is gone, surface
        // that as an already-resolved error future.
        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
 262
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    /// All tool-choice modes are supported (mapped to Google's
    /// function-calling modes in `into_google`).
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Counts tokens by calling Google's `countTokens` endpoint with the
    /// fully-converted request; requires an API key.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Streams a completion: converts the request, issues it through the
    /// rate limiter, and maps raw API responses into completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 371
/// Converts an app-level `LanguageModelRequest` into a Google
/// `GenerateContentRequest` for the given model id and mode.
///
/// A leading `System` message becomes the `system_instruction`; the remaining
/// messages are mapped to `contents`, and messages whose content maps to no
/// parts are dropped entirely.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    // Maps one message's content items to Google `Part`s. Empty text/signature
    // items and thinking without a signature produce no parts.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            // Image tool results become a placeholder function
                            // response plus a separate inline-data part.
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Pull a leading System message out as the dedicated system instruction.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default temperature of 1.0 when the request doesn't specify one.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Thinking is only configured when the request allows it, the
            // model mode is Thinking, and a budget was provided.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
 540
/// Accumulates state while mapping a stream of Google responses into
/// `LanguageModelCompletionEvent`s.
pub struct GoogleEventMapper {
    // Running usage totals, merged across streamed chunks.
    usage: UsageMetadata,
    // Most recent stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
 545
impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Maps the raw response stream into completion events, appending a final
    /// `Stop` event (with the last observed stop reason) when the stream ends.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // `None` sentinel marks end-of-stream so we can emit the Stop event.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Maps one streamed response chunk into zero or more completion events:
    /// usage updates, prompt-feedback refusals, text/thought/tool-use parts,
    /// and stop-reason bookkeeping.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Process-wide counter used to mint unique tool-use ids, since the
        // Google API does not supply ids for function calls.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A block reason means the prompt was refused: emit Stop(Refusal)
        // immediately and ignore any candidates.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
 670
/// Approximates the token count for a request using a local tokenizer rather
/// than the Google API. Note: uses the GPT-4 tokenizer, so counts are only an
/// approximation for Gemini models.
pub fn count_google_tokens(
    request: LanguageModelRequest,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
    // So we have to use tokenizer from tiktoken_rs to count tokens.
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        // Tiktoken doesn't yet support these models, so we manually use the
        // same tokenizer as GPT-4.
        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}
 699
 700fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
 701    if let Some(prompt_token_count) = new.prompt_token_count {
 702        usage.prompt_token_count = Some(prompt_token_count);
 703    }
 704    if let Some(cached_content_token_count) = new.cached_content_token_count {
 705        usage.cached_content_token_count = Some(cached_content_token_count);
 706    }
 707    if let Some(candidates_token_count) = new.candidates_token_count {
 708        usage.candidates_token_count = Some(candidates_token_count);
 709    }
 710    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
 711        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
 712    }
 713    if let Some(thoughts_token_count) = new.thoughts_token_count {
 714        usage.thoughts_token_count = Some(thoughts_token_count);
 715    }
 716    if let Some(total_token_count) = new.total_token_count {
 717        usage.total_token_count = Some(total_token_count);
 718    }
 719}
 720
 721fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
 722    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
 723    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
 724    let input_tokens = prompt_tokens - cached_tokens;
 725    let output_tokens = usage.candidates_token_count.unwrap_or(0);
 726
 727    language_model::TokenUsage {
 728        input_tokens,
 729        output_tokens,
 730        cache_read_input_tokens: cached_tokens,
 731        cache_creation_input_tokens: 0,
 732    }
 733}
 734
/// Settings-UI view for entering/resetting the Google API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    target_agent: language_model::ConfigurationViewTargetAgent,
    /// `Some` while credentials are being loaded; cleared when the initial
    /// authenticate attempt finishes.
    load_credentials_task: Option<Task<()>>,
}
 741
impl ConfigurationView {
    /// Builds the view, observing provider state for re-render and kicking off
    /// an initial (best-effort) credential load.
    fn new(
        state: Entity<State>,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
            target_agent,
            state,
            load_credentials_task,
        }
    }

    /// Saves the key typed into the editor; no-op when the field is empty
    /// (after trimming).
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// Clears the editor and deletes the stored key.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// The key editor is shown only while unauthenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
 816
 817impl Render for ConfigurationView {
 818    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 819        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 820        let configured_card_label = if env_var_set {
 821            format!(
 822                "API key set in {} environment variable",
 823                API_KEY_ENV_VAR.name
 824            )
 825        } else {
 826            let api_url = GoogleLanguageModelProvider::api_url(cx);
 827            if api_url == google_ai::API_URL {
 828                "API key configured".to_string()
 829            } else {
 830                format!("API key configured for {}", api_url)
 831            }
 832        };
 833
 834        if self.load_credentials_task.is_some() {
 835            div()
 836                .child(Label::new("Loading credentials..."))
 837                .into_any_element()
 838        } else if self.should_render_editor(cx) {
 839            v_flex()
 840                .size_full()
 841                .on_action(cx.listener(Self::save_api_key))
 842                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
 843                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
 844                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
 845                })))
 846                .child(
 847                    List::new()
 848                        .child(
 849                            ListBulletItem::new("")
 850                                .child(Label::new("Create one by visiting"))
 851                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
 852                        )
 853                        .child(
 854                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
 855                        )
 856                )
 857                .child(self.api_key_editor.clone())
 858                .child(
 859                    Label::new(
 860                        format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
 861                    )
 862                    .size(LabelSize::Small).color(Color::Muted),
 863                )
 864                .into_any_element()
 865        } else {
 866            ConfiguredApiCard::new(configured_card_label)
 867                .disabled(env_var_set)
 868                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 869                .when(env_var_set, |this| {
 870                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
 871                })
 872                .into_any_element()
 873        }
 874    }
 875}
 876
 877#[cfg(test)]
 878mod tests {
 879    use super::*;
 880    use google_ai::{
 881        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
 882        Part, Role as GoogleRole, TextPart,
 883    };
 884    use language_model::{LanguageModelToolUseId, MessageContent, Role};
 885    use serde_json::json;
 886
 887    #[test]
 888    fn test_function_call_with_signature_creates_tool_use_with_signature() {
 889        let mut mapper = GoogleEventMapper::new();
 890
 891        let response = GenerateContentResponse {
 892            candidates: Some(vec![GenerateContentCandidate {
 893                index: Some(0),
 894                content: Content {
 895                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 896                        function_call: FunctionCall {
 897                            name: "test_function".to_string(),
 898                            args: json!({"arg": "value"}),
 899                        },
 900                        thought_signature: Some("test_signature_123".to_string()),
 901                    })],
 902                    role: GoogleRole::Model,
 903                },
 904                finish_reason: None,
 905                finish_message: None,
 906                safety_ratings: None,
 907                citation_metadata: None,
 908            }]),
 909            prompt_feedback: None,
 910            usage_metadata: None,
 911        };
 912
 913        let events = mapper.map_event(response);
 914
 915        assert_eq!(events.len(), 2); // ToolUse event + Stop event
 916
 917        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 918            assert_eq!(tool_use.name.as_ref(), "test_function");
 919            assert_eq!(
 920                tool_use.thought_signature.as_deref(),
 921                Some("test_signature_123")
 922            );
 923        } else {
 924            panic!("Expected ToolUse event");
 925        }
 926    }
 927
 928    #[test]
 929    fn test_function_call_without_signature_has_none() {
 930        let mut mapper = GoogleEventMapper::new();
 931
 932        let response = GenerateContentResponse {
 933            candidates: Some(vec![GenerateContentCandidate {
 934                index: Some(0),
 935                content: Content {
 936                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 937                        function_call: FunctionCall {
 938                            name: "test_function".to_string(),
 939                            args: json!({"arg": "value"}),
 940                        },
 941                        thought_signature: None,
 942                    })],
 943                    role: GoogleRole::Model,
 944                },
 945                finish_reason: None,
 946                finish_message: None,
 947                safety_ratings: None,
 948                citation_metadata: None,
 949            }]),
 950            prompt_feedback: None,
 951            usage_metadata: None,
 952        };
 953
 954        let events = mapper.map_event(response);
 955
 956        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 957            assert_eq!(tool_use.thought_signature, None);
 958        } else {
 959            panic!("Expected ToolUse event");
 960        }
 961    }
 962
 963    #[test]
 964    fn test_empty_string_signature_normalized_to_none() {
 965        let mut mapper = GoogleEventMapper::new();
 966
 967        let response = GenerateContentResponse {
 968            candidates: Some(vec![GenerateContentCandidate {
 969                index: Some(0),
 970                content: Content {
 971                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 972                        function_call: FunctionCall {
 973                            name: "test_function".to_string(),
 974                            args: json!({"arg": "value"}),
 975                        },
 976                        thought_signature: Some("".to_string()),
 977                    })],
 978                    role: GoogleRole::Model,
 979                },
 980                finish_reason: None,
 981                finish_message: None,
 982                safety_ratings: None,
 983                citation_metadata: None,
 984            }]),
 985            prompt_feedback: None,
 986            usage_metadata: None,
 987        };
 988
 989        let events = mapper.map_event(response);
 990
 991        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 992            assert_eq!(tool_use.thought_signature, None);
 993        } else {
 994            panic!("Expected ToolUse event");
 995        }
 996    }
 997
 998    #[test]
 999    fn test_parallel_function_calls_preserve_signatures() {
1000        let mut mapper = GoogleEventMapper::new();
1001
1002        let response = GenerateContentResponse {
1003            candidates: Some(vec![GenerateContentCandidate {
1004                index: Some(0),
1005                content: Content {
1006                    parts: vec![
1007                        Part::FunctionCallPart(FunctionCallPart {
1008                            function_call: FunctionCall {
1009                                name: "function_1".to_string(),
1010                                args: json!({"arg": "value1"}),
1011                            },
1012                            thought_signature: Some("signature_1".to_string()),
1013                        }),
1014                        Part::FunctionCallPart(FunctionCallPart {
1015                            function_call: FunctionCall {
1016                                name: "function_2".to_string(),
1017                                args: json!({"arg": "value2"}),
1018                            },
1019                            thought_signature: None,
1020                        }),
1021                    ],
1022                    role: GoogleRole::Model,
1023                },
1024                finish_reason: None,
1025                finish_message: None,
1026                safety_ratings: None,
1027                citation_metadata: None,
1028            }]),
1029            prompt_feedback: None,
1030            usage_metadata: None,
1031        };
1032
1033        let events = mapper.map_event(response);
1034
1035        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
1036
1037        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1038            assert_eq!(tool_use.name.as_ref(), "function_1");
1039            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
1040        } else {
1041            panic!("Expected ToolUse event for function_1");
1042        }
1043
1044        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1045            assert_eq!(tool_use.name.as_ref(), "function_2");
1046            assert_eq!(tool_use.thought_signature, None);
1047        } else {
1048            panic!("Expected ToolUse event for function_2");
1049        }
1050    }
1051
1052    #[test]
1053    fn test_tool_use_with_signature_converts_to_function_call_part() {
1054        let tool_use = language_model::LanguageModelToolUse {
1055            id: LanguageModelToolUseId::from("test_id"),
1056            name: "test_function".into(),
1057            raw_input: json!({"arg": "value"}).to_string(),
1058            input: json!({"arg": "value"}),
1059            is_input_complete: true,
1060            thought_signature: Some("test_signature_456".to_string()),
1061        };
1062
1063        let request = super::into_google(
1064            LanguageModelRequest {
1065                messages: vec![language_model::LanguageModelRequestMessage {
1066                    role: Role::Assistant,
1067                    content: vec![MessageContent::ToolUse(tool_use)],
1068                    cache: false,
1069                    reasoning_details: None,
1070                }],
1071                ..Default::default()
1072            },
1073            "gemini-2.5-flash".to_string(),
1074            GoogleModelMode::Default,
1075        );
1076
1077        assert_eq!(request.contents[0].parts.len(), 1);
1078        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1079            assert_eq!(fc_part.function_call.name, "test_function");
1080            assert_eq!(
1081                fc_part.thought_signature.as_deref(),
1082                Some("test_signature_456")
1083            );
1084        } else {
1085            panic!("Expected FunctionCallPart");
1086        }
1087    }
1088
1089    #[test]
1090    fn test_tool_use_without_signature_omits_field() {
1091        let tool_use = language_model::LanguageModelToolUse {
1092            id: LanguageModelToolUseId::from("test_id"),
1093            name: "test_function".into(),
1094            raw_input: json!({"arg": "value"}).to_string(),
1095            input: json!({"arg": "value"}),
1096            is_input_complete: true,
1097            thought_signature: None,
1098        };
1099
1100        let request = super::into_google(
1101            LanguageModelRequest {
1102                messages: vec![language_model::LanguageModelRequestMessage {
1103                    role: Role::Assistant,
1104                    content: vec![MessageContent::ToolUse(tool_use)],
1105                    cache: false,
1106                    reasoning_details: None,
1107                }],
1108                ..Default::default()
1109            },
1110            "gemini-2.5-flash".to_string(),
1111            GoogleModelMode::Default,
1112        );
1113
1114        assert_eq!(request.contents[0].parts.len(), 1);
1115        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1116            assert_eq!(fc_part.thought_signature, None);
1117        } else {
1118            panic!("Expected FunctionCallPart");
1119        }
1120    }
1121
1122    #[test]
1123    fn test_empty_signature_in_tool_use_normalized_to_none() {
1124        let tool_use = language_model::LanguageModelToolUse {
1125            id: LanguageModelToolUseId::from("test_id"),
1126            name: "test_function".into(),
1127            raw_input: json!({"arg": "value"}).to_string(),
1128            input: json!({"arg": "value"}),
1129            is_input_complete: true,
1130            thought_signature: Some("".to_string()),
1131        };
1132
1133        let request = super::into_google(
1134            LanguageModelRequest {
1135                messages: vec![language_model::LanguageModelRequestMessage {
1136                    role: Role::Assistant,
1137                    content: vec![MessageContent::ToolUse(tool_use)],
1138                    cache: false,
1139                    reasoning_details: None,
1140                }],
1141                ..Default::default()
1142            },
1143            "gemini-2.5-flash".to_string(),
1144            GoogleModelMode::Default,
1145        );
1146
1147        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1148            assert_eq!(fc_part.thought_signature, None);
1149        } else {
1150            panic!("Expected FunctionCallPart");
1151        }
1152    }
1153
1154    #[test]
1155    fn test_round_trip_preserves_signature() {
1156        let mut mapper = GoogleEventMapper::new();
1157
1158        // Simulate receiving a response from Google with a signature
1159        let response = GenerateContentResponse {
1160            candidates: Some(vec![GenerateContentCandidate {
1161                index: Some(0),
1162                content: Content {
1163                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1164                        function_call: FunctionCall {
1165                            name: "test_function".to_string(),
1166                            args: json!({"arg": "value"}),
1167                        },
1168                        thought_signature: Some("round_trip_sig".to_string()),
1169                    })],
1170                    role: GoogleRole::Model,
1171                },
1172                finish_reason: None,
1173                finish_message: None,
1174                safety_ratings: None,
1175                citation_metadata: None,
1176            }]),
1177            prompt_feedback: None,
1178            usage_metadata: None,
1179        };
1180
1181        let events = mapper.map_event(response);
1182
1183        let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1184            tool_use.clone()
1185        } else {
1186            panic!("Expected ToolUse event");
1187        };
1188
1189        // Convert back to Google format
1190        let request = super::into_google(
1191            LanguageModelRequest {
1192                messages: vec![language_model::LanguageModelRequestMessage {
1193                    role: Role::Assistant,
1194                    content: vec![MessageContent::ToolUse(tool_use)],
1195                    cache: false,
1196                    reasoning_details: None,
1197                }],
1198                ..Default::default()
1199            },
1200            "gemini-2.5-flash".to_string(),
1201            GoogleModelMode::Default,
1202        );
1203
1204        // Verify signature is preserved
1205        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1206            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
1207        } else {
1208            panic!("Expected FunctionCallPart");
1209        }
1210    }
1211
1212    #[test]
1213    fn test_mixed_text_and_function_call_with_signature() {
1214        let mut mapper = GoogleEventMapper::new();
1215
1216        let response = GenerateContentResponse {
1217            candidates: Some(vec![GenerateContentCandidate {
1218                index: Some(0),
1219                content: Content {
1220                    parts: vec![
1221                        Part::TextPart(TextPart {
1222                            text: "I'll help with that.".to_string(),
1223                        }),
1224                        Part::FunctionCallPart(FunctionCallPart {
1225                            function_call: FunctionCall {
1226                                name: "helper_function".to_string(),
1227                                args: json!({"query": "help"}),
1228                            },
1229                            thought_signature: Some("mixed_sig".to_string()),
1230                        }),
1231                    ],
1232                    role: GoogleRole::Model,
1233                },
1234                finish_reason: None,
1235                finish_message: None,
1236                safety_ratings: None,
1237                citation_metadata: None,
1238            }]),
1239            prompt_feedback: None,
1240            usage_metadata: None,
1241        };
1242
1243        let events = mapper.map_event(response);
1244
1245        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
1246
1247        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
1248            assert_eq!(text, "I'll help with that.");
1249        } else {
1250            panic!("Expected Text event");
1251        }
1252
1253        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1254            assert_eq!(tool_use.name.as_ref(), "helper_function");
1255            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
1256        } else {
1257            panic!("Expected ToolUse event");
1258        }
1259    }
1260
1261    #[test]
1262    fn test_special_characters_in_signature_preserved() {
1263        let mut mapper = GoogleEventMapper::new();
1264
1265        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
1266
1267        let response = GenerateContentResponse {
1268            candidates: Some(vec![GenerateContentCandidate {
1269                index: Some(0),
1270                content: Content {
1271                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1272                        function_call: FunctionCall {
1273                            name: "test_function".to_string(),
1274                            args: json!({"arg": "value"}),
1275                        },
1276                        thought_signature: Some(signature_with_special_chars.clone()),
1277                    })],
1278                    role: GoogleRole::Model,
1279                },
1280                finish_reason: None,
1281                finish_message: None,
1282                safety_ratings: None,
1283                citation_metadata: None,
1284            }]),
1285            prompt_feedback: None,
1286            usage_metadata: None,
1287        };
1288
1289        let events = mapper.map_event(response);
1290
1291        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1292            assert_eq!(
1293                tool_use.thought_signature.as_deref(),
1294                Some(signature_with_special_chars.as_str())
1295            );
1296        } else {
1297            panic!("Expected ToolUse event");
1298        }
1299    }
1300}