// google.rs — Google AI (Gemini) language model provider.

   1use anyhow::{Context as _, Result};
   2use collections::BTreeMap;
   3use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
   4use google_ai::{
   5    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
   6    ThinkingConfig, UsageMetadata,
   7};
   8use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
  12    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
  13    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
  14};
  15use language_model::{
  16    IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  17    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  18    LanguageModelRequest, RateLimiter, Role,
  19};
  20use schemars::JsonSchema;
  21use serde::{Deserialize, Serialize};
  22pub use settings::GoogleAvailableModel as AvailableModel;
  23use settings::{Settings, SettingsStore};
  24use std::pin::Pin;
  25use std::sync::{
  26    Arc, LazyLock,
  27    atomic::{self, AtomicU64},
  28};
  29use strum::IntoEnumIterator;
  30use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  31use ui_input::InputField;
  32use util::ResultExt;
  33
  34use language_model::ApiKeyState;
  35
/// Provider identity constants, shared with the `language_model` crate so the
/// id/name stay consistent across the codebase.
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
  38
/// User-configurable settings for the Google provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Custom API endpoint; when empty, `google_ai::API_URL` is used instead.
    pub api_url: String,
    /// Extra models declared in settings; merged over the built-in model list.
    pub available_models: Vec<AvailableModel>,
}
  44
/// Whether a model runs normally or with extended "thinking" (reasoning).
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    /// Reasoning-enabled mode with an optional token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  55
/// Top-level provider: owns the HTTP client and the shared auth state entity.
pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}
  60
/// Observable authentication state: tracks the API key for the current URL.
pub struct State {
    api_key_state: ApiKeyState,
}
  64
// Environment variables consulted for the API key, in priority order.
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

/// Resolved at first use: prefers `GEMINI_API_KEY`, falling back to the
/// legacy `GOOGLE_AI_API_KEY`.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
  72
  73impl State {
  74    fn is_authenticated(&self) -> bool {
  75        self.api_key_state.has_key()
  76    }
  77
  78    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  79        let api_url = GoogleLanguageModelProvider::api_url(cx);
  80        self.api_key_state
  81            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
  82    }
  83
  84    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  85        let api_url = GoogleLanguageModelProvider::api_url(cx);
  86        self.api_key_state
  87            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
  88    }
  89}
  90
  91impl GoogleLanguageModelProvider {
  92    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
  93        let state = cx.new(|cx| {
  94            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
  95                let api_url = Self::api_url(cx);
  96                this.api_key_state
  97                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
  98                cx.notify();
  99            })
 100            .detach();
 101            State {
 102                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 103            }
 104        });
 105
 106        Self { http_client, state }
 107    }
 108
 109    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
 110        Arc::new(GoogleLanguageModel {
 111            id: LanguageModelId::from(model.id().to_string()),
 112            model,
 113            state: self.state.clone(),
 114            http_client: self.http_client.clone(),
 115            request_limiter: RateLimiter::new(4),
 116        })
 117    }
 118
 119    fn settings(cx: &App) -> &GoogleSettings {
 120        &crate::AllLanguageModelSettings::get_global(cx).google
 121    }
 122
 123    fn api_url(cx: &App) -> SharedString {
 124        let api_url = &Self::settings(cx).api_url;
 125        if api_url.is_empty() {
 126            google_ai::API_URL.into()
 127        } else {
 128            SharedString::new(api_url.as_str())
 129        }
 130    }
 131}
 132
// Exposes the auth `State` entity so the UI can observe credential changes.
impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 140
 141impl LanguageModelProvider for GoogleLanguageModelProvider {
 142    fn id(&self) -> LanguageModelProviderId {
 143        PROVIDER_ID
 144    }
 145
 146    fn name(&self) -> LanguageModelProviderName {
 147        PROVIDER_NAME
 148    }
 149
 150    fn icon(&self) -> IconOrSvg {
 151        IconOrSvg::Icon(IconName::AiGoogle)
 152    }
 153
 154    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 155        Some(self.create_language_model(google_ai::Model::default()))
 156    }
 157
 158    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 159        Some(self.create_language_model(google_ai::Model::default_fast()))
 160    }
 161
 162    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 163        let mut models = BTreeMap::default();
 164
 165        // Add base models from google_ai::Model::iter()
 166        for model in google_ai::Model::iter() {
 167            if !matches!(model, google_ai::Model::Custom { .. }) {
 168                models.insert(model.id().to_string(), model);
 169            }
 170        }
 171
 172        // Override with available models from settings
 173        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
 174            models.insert(
 175                model.name.clone(),
 176                google_ai::Model::Custom {
 177                    name: model.name.clone(),
 178                    display_name: model.display_name.clone(),
 179                    max_tokens: model.max_tokens,
 180                    mode: model.mode.unwrap_or_default(),
 181                },
 182            );
 183        }
 184
 185        models
 186            .into_values()
 187            .map(|model| {
 188                Arc::new(GoogleLanguageModel {
 189                    id: LanguageModelId::from(model.id().to_string()),
 190                    model,
 191                    state: self.state.clone(),
 192                    http_client: self.http_client.clone(),
 193                    request_limiter: RateLimiter::new(4),
 194                }) as Arc<dyn LanguageModel>
 195            })
 196            .collect()
 197    }
 198
 199    fn is_authenticated(&self, cx: &App) -> bool {
 200        self.state.read(cx).is_authenticated()
 201    }
 202
 203    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 204        self.state.update(cx, |state, cx| state.authenticate(cx))
 205    }
 206
 207    fn configuration_view(
 208        &self,
 209        target_agent: language_model::ConfigurationViewTargetAgent,
 210        window: &mut Window,
 211        cx: &mut App,
 212    ) -> AnyView {
 213        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 214            .into()
 215    }
 216
 217    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 218        self.state
 219            .update(cx, |state, cx| state.set_api_key(None, cx))
 220    }
 221}
 222
/// A single Google model instance, holding shared auth state and an HTTP
/// client, with requests throttled by a per-model rate limiter.
pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}
 230
impl GoogleLanguageModel {
    /// Kick off a streaming `generateContent` call.
    ///
    /// The API key and URL are read from `State` up front; the returned
    /// future performs the HTTP request and yields the raw response stream.
    /// Fails with "Missing Google API key" when no key is configured.
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
 260
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    /// Thinking is available only for models configured in `Thinking` mode.
    fn supports_thinking(&self) -> bool {
        matches!(self.model.mode(), GoogleModelMode::Thinking { .. })
    }

    /// All three tool-choice variants map onto Google's function-calling
    /// config (see `into_google`), so every choice is supported.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        // Google accepts only a subset of JSON Schema for tool parameters.
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Count tokens via Google's server-side `countTokens` endpoint (requires
    /// a configured API key; fails with `NoApiKey` otherwise).
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Stream a completion: convert the request, issue it through the rate
    /// limiter, and map raw Google responses into completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 373
/// Convert a provider-agnostic `LanguageModelRequest` into Google's
/// `GenerateContentRequest` for the given model id and mode.
///
/// A leading `System` message (if any) becomes the request's
/// `system_instruction`; remaining messages become `contents`. Messages whose
/// content lowers to zero parts are dropped entirely.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    // Lower generic message content into Google `Part`s. Empty text,
    // unsigned/redacted thinking, and empty signatures all yield no parts.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Only the signature is sent back for thinking content; the
                // thought text itself is discarded here.
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    // NOTE(review): mime type is hard-coded to image/png —
                    // confirm image sources are always PNG-encoded.
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            // Image results: a placeholder function response
                            // plus the image itself as inline data.
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Pull a leading System message out into Google's dedicated field.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default to temperature 1.0 when the request doesn't set one.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Send a thinking config only when the request allows thinking,
            // the model is in thinking mode, and a budget was configured.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
 542
/// Stateful translator from raw Google responses to completion events,
/// accumulating usage totals and tracking the latest stop reason.
pub struct GoogleEventMapper {
    usage: UsageMetadata,
    stop_reason: StopReason,
}
 547
impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Adapt the raw response stream into completion events, appending a
    /// final `Stop` event (carrying the last observed stop reason) once the
    /// underlying stream ends.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // A trailing `None` marks end-of-stream so the final Stop event
            // can be emitted.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Translate one `GenerateContentResponse` chunk into zero or more
    /// completion events, updating accumulated usage and the stop reason.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Google doesn't assign tool-call ids, so unique ones are synthesized
        // from the function name plus a process-wide counter.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A prompt-level block ends the turn immediately as a refusal.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    // Unknown reasons are logged but still treated as refusal.
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        // Inline data and function responses in model output
                        // are ignored.
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
 672
 673pub fn count_google_tokens(
 674    request: LanguageModelRequest,
 675    cx: &App,
 676) -> BoxFuture<'static, Result<u64>> {
 677    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
 678    // So we have to use tokenizer from tiktoken_rs to count tokens.
 679    cx.background_spawn(async move {
 680        let messages = request
 681            .messages
 682            .into_iter()
 683            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 684                role: match message.role {
 685                    Role::User => "user".into(),
 686                    Role::Assistant => "assistant".into(),
 687                    Role::System => "system".into(),
 688                },
 689                content: Some(message.string_contents()),
 690                name: None,
 691                function_call: None,
 692            })
 693            .collect::<Vec<_>>();
 694
 695        // Tiktoken doesn't yet support these models, so we manually use the
 696        // same tokenizer as GPT-4.
 697        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 698    })
 699    .boxed()
 700}
 701
 702fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
 703    if let Some(prompt_token_count) = new.prompt_token_count {
 704        usage.prompt_token_count = Some(prompt_token_count);
 705    }
 706    if let Some(cached_content_token_count) = new.cached_content_token_count {
 707        usage.cached_content_token_count = Some(cached_content_token_count);
 708    }
 709    if let Some(candidates_token_count) = new.candidates_token_count {
 710        usage.candidates_token_count = Some(candidates_token_count);
 711    }
 712    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
 713        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
 714    }
 715    if let Some(thoughts_token_count) = new.thoughts_token_count {
 716        usage.thoughts_token_count = Some(thoughts_token_count);
 717    }
 718    if let Some(total_token_count) = new.total_token_count {
 719        usage.total_token_count = Some(total_token_count);
 720    }
 721}
 722
 723fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
 724    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
 725    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
 726    let input_tokens = prompt_tokens - cached_tokens;
 727    let output_tokens = usage.candidates_token_count.unwrap_or(0);
 728
 729    language_model::TokenUsage {
 730        input_tokens,
 731        output_tokens,
 732        cache_read_input_tokens: cached_tokens,
 733        cache_creation_input_tokens: 0,
 734    }
 735}
 736
/// Settings-UI view for entering or resetting the Google API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    target_agent: language_model::ConfigurationViewTargetAgent,
    /// `Some` while stored credentials are still loading; cleared when done.
    load_credentials_task: Option<Task<()>>,
}
 743
 744impl ConfigurationView {
 745    fn new(
 746        state: Entity<State>,
 747        target_agent: language_model::ConfigurationViewTargetAgent,
 748        window: &mut Window,
 749        cx: &mut Context<Self>,
 750    ) -> Self {
 751        cx.observe(&state, |_, _, cx| {
 752            cx.notify();
 753        })
 754        .detach();
 755
 756        let load_credentials_task = Some(cx.spawn_in(window, {
 757            let state = state.clone();
 758            async move |this, cx| {
 759                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
 760                    // We don't log an error, because "not signed in" is also an error.
 761                    let _ = task.await;
 762                }
 763                this.update(cx, |this, cx| {
 764                    this.load_credentials_task = None;
 765                    cx.notify();
 766                })
 767                .log_err();
 768            }
 769        }));
 770
 771        Self {
 772            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
 773            target_agent,
 774            state,
 775            load_credentials_task,
 776        }
 777    }
 778
 779    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 780        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 781        if api_key.is_empty() {
 782            return;
 783        }
 784
 785        // url changes can cause the editor to be displayed again
 786        self.api_key_editor
 787            .update(cx, |editor, cx| editor.set_text("", window, cx));
 788
 789        let state = self.state.clone();
 790        cx.spawn_in(window, async move |_, cx| {
 791            state
 792                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 793                .await
 794        })
 795        .detach_and_log_err(cx);
 796    }
 797
 798    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 799        self.api_key_editor
 800            .update(cx, |editor, cx| editor.set_text("", window, cx));
 801
 802        let state = self.state.clone();
 803        cx.spawn_in(window, async move |_, cx| {
 804            state
 805                .update(cx, |state, cx| state.set_api_key(None, cx))
 806                .await
 807        })
 808        .detach_and_log_err(cx);
 809    }
 810
 811    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 812        !self.state.read(cx).is_authenticated()
 813    }
 814}
 815
 816impl Render for ConfigurationView {
 817    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 818        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 819        let configured_card_label = if env_var_set {
 820            format!(
 821                "API key set in {} environment variable",
 822                API_KEY_ENV_VAR.name
 823            )
 824        } else {
 825            let api_url = GoogleLanguageModelProvider::api_url(cx);
 826            if api_url == google_ai::API_URL {
 827                "API key configured".to_string()
 828            } else {
 829                format!("API key configured for {}", api_url)
 830            }
 831        };
 832
 833        if self.load_credentials_task.is_some() {
 834            div()
 835                .child(Label::new("Loading credentials..."))
 836                .into_any_element()
 837        } else if self.should_render_editor(cx) {
 838            v_flex()
 839                .size_full()
 840                .on_action(cx.listener(Self::save_api_key))
 841                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
 842                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
 843                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
 844                })))
 845                .child(
 846                    List::new()
 847                        .child(
 848                            ListBulletItem::new("")
 849                                .child(Label::new("Create one by visiting"))
 850                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
 851                        )
 852                        .child(
 853                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
 854                        )
 855                )
 856                .child(self.api_key_editor.clone())
 857                .child(
 858                    Label::new(
 859                        format!("You can also set the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
 860                    )
 861                    .size(LabelSize::Small).color(Color::Muted),
 862                )
 863                .into_any_element()
 864        } else {
 865            ConfiguredApiCard::new(configured_card_label)
 866                .disabled(env_var_set)
 867                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 868                .when(env_var_set, |this| {
 869                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
 870                })
 871                .into_any_element()
 872        }
 873    }
 874}
 875
 876#[cfg(test)]
 877mod tests {
 878    use super::*;
 879    use google_ai::{
 880        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
 881        Part, Role as GoogleRole, TextPart,
 882    };
 883    use language_model::{LanguageModelToolUseId, MessageContent, Role};
 884    use serde_json::json;
 885
 886    #[test]
 887    fn test_function_call_with_signature_creates_tool_use_with_signature() {
 888        let mut mapper = GoogleEventMapper::new();
 889
 890        let response = GenerateContentResponse {
 891            candidates: Some(vec![GenerateContentCandidate {
 892                index: Some(0),
 893                content: Content {
 894                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 895                        function_call: FunctionCall {
 896                            name: "test_function".to_string(),
 897                            args: json!({"arg": "value"}),
 898                        },
 899                        thought_signature: Some("test_signature_123".to_string()),
 900                    })],
 901                    role: GoogleRole::Model,
 902                },
 903                finish_reason: None,
 904                finish_message: None,
 905                safety_ratings: None,
 906                citation_metadata: None,
 907            }]),
 908            prompt_feedback: None,
 909            usage_metadata: None,
 910        };
 911
 912        let events = mapper.map_event(response);
 913
 914        assert_eq!(events.len(), 2); // ToolUse event + Stop event
 915
 916        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 917            assert_eq!(tool_use.name.as_ref(), "test_function");
 918            assert_eq!(
 919                tool_use.thought_signature.as_deref(),
 920                Some("test_signature_123")
 921            );
 922        } else {
 923            panic!("Expected ToolUse event");
 924        }
 925    }
 926
 927    #[test]
 928    fn test_function_call_without_signature_has_none() {
 929        let mut mapper = GoogleEventMapper::new();
 930
 931        let response = GenerateContentResponse {
 932            candidates: Some(vec![GenerateContentCandidate {
 933                index: Some(0),
 934                content: Content {
 935                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 936                        function_call: FunctionCall {
 937                            name: "test_function".to_string(),
 938                            args: json!({"arg": "value"}),
 939                        },
 940                        thought_signature: None,
 941                    })],
 942                    role: GoogleRole::Model,
 943                },
 944                finish_reason: None,
 945                finish_message: None,
 946                safety_ratings: None,
 947                citation_metadata: None,
 948            }]),
 949            prompt_feedback: None,
 950            usage_metadata: None,
 951        };
 952
 953        let events = mapper.map_event(response);
 954
 955        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 956            assert_eq!(tool_use.thought_signature, None);
 957        } else {
 958            panic!("Expected ToolUse event");
 959        }
 960    }
 961
 962    #[test]
 963    fn test_empty_string_signature_normalized_to_none() {
 964        let mut mapper = GoogleEventMapper::new();
 965
 966        let response = GenerateContentResponse {
 967            candidates: Some(vec![GenerateContentCandidate {
 968                index: Some(0),
 969                content: Content {
 970                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 971                        function_call: FunctionCall {
 972                            name: "test_function".to_string(),
 973                            args: json!({"arg": "value"}),
 974                        },
 975                        thought_signature: Some("".to_string()),
 976                    })],
 977                    role: GoogleRole::Model,
 978                },
 979                finish_reason: None,
 980                finish_message: None,
 981                safety_ratings: None,
 982                citation_metadata: None,
 983            }]),
 984            prompt_feedback: None,
 985            usage_metadata: None,
 986        };
 987
 988        let events = mapper.map_event(response);
 989
 990        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 991            assert_eq!(tool_use.thought_signature, None);
 992        } else {
 993            panic!("Expected ToolUse event");
 994        }
 995    }
 996
 997    #[test]
 998    fn test_parallel_function_calls_preserve_signatures() {
 999        let mut mapper = GoogleEventMapper::new();
1000
1001        let response = GenerateContentResponse {
1002            candidates: Some(vec![GenerateContentCandidate {
1003                index: Some(0),
1004                content: Content {
1005                    parts: vec![
1006                        Part::FunctionCallPart(FunctionCallPart {
1007                            function_call: FunctionCall {
1008                                name: "function_1".to_string(),
1009                                args: json!({"arg": "value1"}),
1010                            },
1011                            thought_signature: Some("signature_1".to_string()),
1012                        }),
1013                        Part::FunctionCallPart(FunctionCallPart {
1014                            function_call: FunctionCall {
1015                                name: "function_2".to_string(),
1016                                args: json!({"arg": "value2"}),
1017                            },
1018                            thought_signature: None,
1019                        }),
1020                    ],
1021                    role: GoogleRole::Model,
1022                },
1023                finish_reason: None,
1024                finish_message: None,
1025                safety_ratings: None,
1026                citation_metadata: None,
1027            }]),
1028            prompt_feedback: None,
1029            usage_metadata: None,
1030        };
1031
1032        let events = mapper.map_event(response);
1033
1034        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
1035
1036        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1037            assert_eq!(tool_use.name.as_ref(), "function_1");
1038            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
1039        } else {
1040            panic!("Expected ToolUse event for function_1");
1041        }
1042
1043        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1044            assert_eq!(tool_use.name.as_ref(), "function_2");
1045            assert_eq!(tool_use.thought_signature, None);
1046        } else {
1047            panic!("Expected ToolUse event for function_2");
1048        }
1049    }
1050
1051    #[test]
1052    fn test_tool_use_with_signature_converts_to_function_call_part() {
1053        let tool_use = language_model::LanguageModelToolUse {
1054            id: LanguageModelToolUseId::from("test_id"),
1055            name: "test_function".into(),
1056            raw_input: json!({"arg": "value"}).to_string(),
1057            input: json!({"arg": "value"}),
1058            is_input_complete: true,
1059            thought_signature: Some("test_signature_456".to_string()),
1060        };
1061
1062        let request = super::into_google(
1063            LanguageModelRequest {
1064                messages: vec![language_model::LanguageModelRequestMessage {
1065                    role: Role::Assistant,
1066                    content: vec![MessageContent::ToolUse(tool_use)],
1067                    cache: false,
1068                    reasoning_details: None,
1069                }],
1070                ..Default::default()
1071            },
1072            "gemini-2.5-flash".to_string(),
1073            GoogleModelMode::Default,
1074        );
1075
1076        assert_eq!(request.contents[0].parts.len(), 1);
1077        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1078            assert_eq!(fc_part.function_call.name, "test_function");
1079            assert_eq!(
1080                fc_part.thought_signature.as_deref(),
1081                Some("test_signature_456")
1082            );
1083        } else {
1084            panic!("Expected FunctionCallPart");
1085        }
1086    }
1087
1088    #[test]
1089    fn test_tool_use_without_signature_omits_field() {
1090        let tool_use = language_model::LanguageModelToolUse {
1091            id: LanguageModelToolUseId::from("test_id"),
1092            name: "test_function".into(),
1093            raw_input: json!({"arg": "value"}).to_string(),
1094            input: json!({"arg": "value"}),
1095            is_input_complete: true,
1096            thought_signature: None,
1097        };
1098
1099        let request = super::into_google(
1100            LanguageModelRequest {
1101                messages: vec![language_model::LanguageModelRequestMessage {
1102                    role: Role::Assistant,
1103                    content: vec![MessageContent::ToolUse(tool_use)],
1104                    cache: false,
1105                    reasoning_details: None,
1106                }],
1107                ..Default::default()
1108            },
1109            "gemini-2.5-flash".to_string(),
1110            GoogleModelMode::Default,
1111        );
1112
1113        assert_eq!(request.contents[0].parts.len(), 1);
1114        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1115            assert_eq!(fc_part.thought_signature, None);
1116        } else {
1117            panic!("Expected FunctionCallPart");
1118        }
1119    }
1120
1121    #[test]
1122    fn test_empty_signature_in_tool_use_normalized_to_none() {
1123        let tool_use = language_model::LanguageModelToolUse {
1124            id: LanguageModelToolUseId::from("test_id"),
1125            name: "test_function".into(),
1126            raw_input: json!({"arg": "value"}).to_string(),
1127            input: json!({"arg": "value"}),
1128            is_input_complete: true,
1129            thought_signature: Some("".to_string()),
1130        };
1131
1132        let request = super::into_google(
1133            LanguageModelRequest {
1134                messages: vec![language_model::LanguageModelRequestMessage {
1135                    role: Role::Assistant,
1136                    content: vec![MessageContent::ToolUse(tool_use)],
1137                    cache: false,
1138                    reasoning_details: None,
1139                }],
1140                ..Default::default()
1141            },
1142            "gemini-2.5-flash".to_string(),
1143            GoogleModelMode::Default,
1144        );
1145
1146        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1147            assert_eq!(fc_part.thought_signature, None);
1148        } else {
1149            panic!("Expected FunctionCallPart");
1150        }
1151    }
1152
1153    #[test]
1154    fn test_round_trip_preserves_signature() {
1155        let mut mapper = GoogleEventMapper::new();
1156
1157        // Simulate receiving a response from Google with a signature
1158        let response = GenerateContentResponse {
1159            candidates: Some(vec![GenerateContentCandidate {
1160                index: Some(0),
1161                content: Content {
1162                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1163                        function_call: FunctionCall {
1164                            name: "test_function".to_string(),
1165                            args: json!({"arg": "value"}),
1166                        },
1167                        thought_signature: Some("round_trip_sig".to_string()),
1168                    })],
1169                    role: GoogleRole::Model,
1170                },
1171                finish_reason: None,
1172                finish_message: None,
1173                safety_ratings: None,
1174                citation_metadata: None,
1175            }]),
1176            prompt_feedback: None,
1177            usage_metadata: None,
1178        };
1179
1180        let events = mapper.map_event(response);
1181
1182        let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1183            tool_use.clone()
1184        } else {
1185            panic!("Expected ToolUse event");
1186        };
1187
1188        // Convert back to Google format
1189        let request = super::into_google(
1190            LanguageModelRequest {
1191                messages: vec![language_model::LanguageModelRequestMessage {
1192                    role: Role::Assistant,
1193                    content: vec![MessageContent::ToolUse(tool_use)],
1194                    cache: false,
1195                    reasoning_details: None,
1196                }],
1197                ..Default::default()
1198            },
1199            "gemini-2.5-flash".to_string(),
1200            GoogleModelMode::Default,
1201        );
1202
1203        // Verify signature is preserved
1204        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1205            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
1206        } else {
1207            panic!("Expected FunctionCallPart");
1208        }
1209    }
1210
1211    #[test]
1212    fn test_mixed_text_and_function_call_with_signature() {
1213        let mut mapper = GoogleEventMapper::new();
1214
1215        let response = GenerateContentResponse {
1216            candidates: Some(vec![GenerateContentCandidate {
1217                index: Some(0),
1218                content: Content {
1219                    parts: vec![
1220                        Part::TextPart(TextPart {
1221                            text: "I'll help with that.".to_string(),
1222                        }),
1223                        Part::FunctionCallPart(FunctionCallPart {
1224                            function_call: FunctionCall {
1225                                name: "helper_function".to_string(),
1226                                args: json!({"query": "help"}),
1227                            },
1228                            thought_signature: Some("mixed_sig".to_string()),
1229                        }),
1230                    ],
1231                    role: GoogleRole::Model,
1232                },
1233                finish_reason: None,
1234                finish_message: None,
1235                safety_ratings: None,
1236                citation_metadata: None,
1237            }]),
1238            prompt_feedback: None,
1239            usage_metadata: None,
1240        };
1241
1242        let events = mapper.map_event(response);
1243
1244        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
1245
1246        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
1247            assert_eq!(text, "I'll help with that.");
1248        } else {
1249            panic!("Expected Text event");
1250        }
1251
1252        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1253            assert_eq!(tool_use.name.as_ref(), "helper_function");
1254            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
1255        } else {
1256            panic!("Expected ToolUse event");
1257        }
1258    }
1259
1260    #[test]
1261    fn test_special_characters_in_signature_preserved() {
1262        let mut mapper = GoogleEventMapper::new();
1263
1264        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
1265
1266        let response = GenerateContentResponse {
1267            candidates: Some(vec![GenerateContentCandidate {
1268                index: Some(0),
1269                content: Content {
1270                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1271                        function_call: FunctionCall {
1272                            name: "test_function".to_string(),
1273                            args: json!({"arg": "value"}),
1274                        },
1275                        thought_signature: Some(signature_with_special_chars.clone()),
1276                    })],
1277                    role: GoogleRole::Model,
1278                },
1279                finish_reason: None,
1280                finish_message: None,
1281                safety_ratings: None,
1282                citation_metadata: None,
1283            }]),
1284            prompt_feedback: None,
1285            usage_metadata: None,
1286        };
1287
1288        let events = mapper.map_event(response);
1289
1290        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1291            assert_eq!(
1292                tool_use.thought_signature.as_deref(),
1293                Some(signature_with_special_chars.as_str())
1294            );
1295        } else {
1296            panic!("Expected ToolUse event");
1297        }
1298    }
1299}