google_ai.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};

use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zed_extension_api::{
    self as zed, http_client::HttpMethod, http_client::HttpRequest,
    llm_get_env_var, llm_get_provider_settings,
    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmCustomModelConfig,
    LlmMessageContent, LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo,
    LlmStopReason, LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
};

/// Monotonic counter used to build unique tool-use IDs for function calls.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

pub const DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";

fn get_api_url() -> String {
    llm_get_provider_settings(PROVIDER_ID)
        .and_then(|s| s.api_url)
        .unwrap_or_else(|| DEFAULT_API_URL.to_string())
}

fn get_custom_models() -> Vec<LlmCustomModelConfig> {
    llm_get_provider_settings(PROVIDER_ID)
        .map(|s| s.available_models)
        .unwrap_or_default()
}

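/// Starts a streaming `generateContent` call against the Gemini SSE endpoint, registers a
/// new `StreamState` for it, and returns the stream ID used by the `llm_stream_completion_*`
/// callbacks.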
fn stream_generate_content(
    model_id: &str,
    request: &LlmCompletionRequest,
    streams: &mut HashMap<String, StreamState>,
    next_stream_id: &mut u64,
) -> Result<String, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&generate_content_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response_stream = http_request.fetch_stream()?;

    let stream_id = format!("stream-{}", *next_stream_id);
    *next_stream_id += 1;

    streams.insert(
        stream_id.clone(),
        StreamState {
            response_stream,
            buffer: String::new(),
            usage: None,
            pending_events: Vec::new(),
            wants_to_use_tool: false,
        },
    );

    Ok(stream_id)
}

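/// Calls the `countTokens` endpoint with the same payload that would be sent to
/// `generateContent` and returns the total token count reported by the API.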
fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;
    let count_request = CountTokensRequest {
        generate_content_request,
    };

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:countTokens?key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&count_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response = http_request.fetch()?;
    let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
        .map_err(|e| format!("Failed to parse response: {}", e))?;

    Ok(response_body.total_tokens)
}

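/// Basic sanity checks applied before a request is sent: a model must be set, at least one
/// content item must be present, and the first user content item found must have parts.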
fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
    if request.model.is_empty() {
        return Err("Model must be specified".to_string());
    }

    if request.contents.is_empty() {
        return Err("Request must contain at least one content item".to_string());
    }

    if let Some(user_content) = request
        .contents
        .iter()
        .find(|content| content.role == Role::User)
    {
        if user_content.parts.is_empty() {
            return Err("User content must contain at least one part".to_string());
        }
    }

    Ok(())
}

// Extension implementation

const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";

struct GoogleAiExtension {
    streams: HashMap<String, StreamState>,
    next_stream_id: u64,
}

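/// Book-keeping for one in-flight SSE completion: the raw response stream, a buffer of
/// not-yet-parsed bytes, the most recent usage metadata, events queued for delivery, and
/// whether the model has requested a tool call (used to map Gemini's `STOP` finish reason
/// to a tool-use stop).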
struct StreamState {
    response_stream: zed::http_client::HttpResponseStream,
    buffer: String,
    usage: Option<UsageMetadata>,
    pending_events: Vec<LlmCompletionEvent>,
    wants_to_use_tool: bool,
}

impl zed::Extension for GoogleAiExtension {
    fn new() -> Self {
        Self {
            streams: HashMap::new(),
            next_stream_id: 0,
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: PROVIDER_ID.to_string(),
            name: PROVIDER_NAME.to_string(),
            icon: Some("icons/google-ai.svg".to_string()),
        }]
    }

    fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(get_models())
    }

    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(
            r#"## Google AI Setup

To use Google AI models in Zed, you need a Gemini API key.

1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
2. Create or select a project
3. Generate an API key
4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable

You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
"#
            .to_string(),
        )
    }

    fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
        if provider_id != PROVIDER_ID {
            return false;
        }
        get_api_key().is_some()
    }

    fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(())
    }

    fn llm_count_tokens(
        &self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<u64, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        count_tokens(model_id, request)
    }

    fn llm_stream_completion_start(
        &mut self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        stream_generate_content_next(stream_id, &mut self.streams)
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.remove(stream_id);
    }

    fn llm_cache_configuration(
        &self,
        provider_id: &str,
        _model_id: &str,
    ) -> Option<LlmCacheConfiguration> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(LlmCacheConfiguration {
            max_cache_anchors: 1,
            should_cache_tool_definitions: false,
            min_total_token_count: 32768,
        })
    }
}

zed::register_extension!(GoogleAiExtension);

// Helper functions

fn get_api_key() -> Option<String> {
    llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
}

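/// The built-in Gemini models exposed by this provider, with their context windows, output
/// limits, and capability flags.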
fn get_default_models() -> Vec<LlmModelInfo> {
    vec![
        LlmModelInfo {
            id: "gemini-2.5-flash-lite".to_string(),
            name: "Gemini 2.5 Flash-Lite".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: true,
        },
        LlmModelInfo {
            id: "gemini-2.5-flash".to_string(),
            name: "Gemini 2.5 Flash".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: true,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-2.5-pro".to_string(),
            name: "Gemini 2.5 Pro".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-3-pro-preview".to_string(),
            name: "Gemini 3 Pro".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-3-flash-preview".to_string(),
            name: "Gemini 3 Flash".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: false,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        },
    ]
}

/// Model aliases for backward compatibility with old model names.
/// Maps old names to canonical model IDs.
fn get_model_aliases() -> Vec<(&'static str, &'static str)> {
    vec![
        // Gemini 2.5 Flash-Lite aliases
        ("gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-lite"),
        ("gemini-2.0-flash-lite-preview", "gemini-2.5-flash-lite"),
        // Gemini 2.5 Flash aliases
        ("gemini-2.0-flash-thinking-exp", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-04-17", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-05-20", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-latest", "gemini-2.5-flash"),
        ("gemini-2.0-flash", "gemini-2.5-flash"),
        // Gemini 2.5 Pro aliases
        ("gemini-2.0-pro-exp", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-latest", "gemini-2.5-pro"),
        ("gemini-2.5-pro-exp-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-06-05", "gemini-2.5-pro"),
    ]
}

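/// Builds the full model list: built-in defaults, alias entries that reuse a canonical
/// model's capabilities, and custom models from the provider settings, which override
/// earlier entries with the same ID.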
fn get_models() -> Vec<LlmModelInfo> {
    let mut models: HashMap<String, LlmModelInfo> = HashMap::new();

    // Add default models
    for model in get_default_models() {
        models.insert(model.id.clone(), model);
    }

    // Add aliases as separate model entries (pointing to the same underlying model)
    for (alias, canonical_id) in get_model_aliases() {
        if let Some(canonical_model) = models.get(canonical_id) {
            let mut alias_model = canonical_model.clone();
            alias_model.id = alias.to_string();
            alias_model.is_default = false;
            alias_model.is_default_fast = false;
            models.insert(alias.to_string(), alias_model);
        }
    }

    // Add/override with custom models from settings
    for custom_model in get_custom_models() {
        let model = LlmModelInfo {
            id: custom_model.name.clone(),
            name: custom_model.display_name.unwrap_or(custom_model.name.clone()),
            max_token_count: custom_model.max_tokens,
            max_output_tokens: custom_model.max_output_tokens,
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: custom_model.thinking_budget.is_some(),
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        };
        models.insert(custom_model.name, model);
    }

    models.into_values().collect()
}

/// Get the thinking budget for a specific model from custom settings.
fn get_model_thinking_budget(model_id: &str) -> Option<u32> {
    get_custom_models()
        .into_iter()
        .find(|m| m.name == model_id)
        .and_then(|m| m.thinking_budget)
}

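/// Returns the next completion event for a stream: queued events are drained first, then
/// complete SSE `data:` lines are parsed out of the buffer, and a new chunk is read from the
/// HTTP response only when no complete line is buffered. `Ok(None)` signals the end of the
/// stream.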
fn stream_generate_content_next(
    stream_id: &str,
    streams: &mut HashMap<String, StreamState>,
) -> Result<Option<LlmCompletionEvent>, String> {
    let state = streams
        .get_mut(stream_id)
        .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

    loop {
        // Return any pending events first, in the order they were queued
        if !state.pending_events.is_empty() {
            return Ok(Some(state.pending_events.remove(0)));
        }

        if let Some(newline_pos) = state.buffer.find('\n') {
            let line = state.buffer[..newline_pos].to_string();
            state.buffer = state.buffer[newline_pos + 1..].to_string();

            if let Some(data) = line.strip_prefix("data: ") {
                if data.trim().is_empty() {
                    continue;
                }

                let response: GenerateContentResponse = match serde_json::from_str(data) {
                    Ok(response) => response,
                    Err(parse_error) => {
                        // Try to parse as an API error response
                        if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(data) {
                            let error_msg = api_error
                                .error
                                .message
                                .unwrap_or_else(|| "Unknown API error".to_string());
                            let status = api_error.error.status.unwrap_or_default();
                            let code = api_error.error.code.unwrap_or(0);
                            return Err(format!(
                                "Google AI API error ({}): {} [status: {}]",
                                code, error_msg, status
                            ));
                        }
                        // If it's not an error response, return the parse error
                        return Err(format!(
                            "Failed to parse SSE data: {} - {}",
                            parse_error, data
                        ));
                    }
                };

                // Handle prompt feedback (blocked prompts). Every block reason
                // ("SAFETY", "OTHER", "BLOCKLIST", "PROHIBITED_CONTENT",
                // "IMAGE_SAFETY", ...) is surfaced as a refusal.
                if let Some(ref prompt_feedback) = response.prompt_feedback {
                    if prompt_feedback.block_reason.is_some() {
                        return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::Refusal)));
                    }
                }

                // Queue usage updates immediately when received
                if let Some(ref usage) = response.usage_metadata {
                    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
                    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
                    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
                    state.pending_events.push(LlmCompletionEvent::Usage(LlmTokenUsage {
                        input_tokens,
                        output_tokens: usage.candidates_token_count.unwrap_or(0),
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: Some(cached_tokens).filter(|&c| c > 0),
                    }));
                    state.usage = Some(usage.clone());
                }

                // Queue (rather than immediately return) an event for every part, so that a
                // chunk containing several parts, or a part plus a finish reason, does not
                // drop anything after the first event.
                if let Some(candidates) = response.candidates {
                    for candidate in candidates {
                        for part in candidate.content.parts {
                            match part {
                                Part::TextPart(text_part) => {
                                    state
                                        .pending_events
                                        .push(LlmCompletionEvent::Text(text_part.text));
                                }
                                Part::ThoughtPart(thought_part) => {
                                    state.pending_events.push(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: "(Encrypted thought)".to_string(),
                                            signature: Some(thought_part.thought_signature),
                                        },
                                    ));
                                }
                                Part::FunctionCallPart(fc_part) => {
                                    state.wants_to_use_tool = true;
                                    // Normalize empty string signatures to None
                                    let thought_signature =
                                        fc_part.thought_signature.filter(|s| !s.is_empty());
                                    // Generate a unique tool use ID
                                    let next_tool_id =
                                        TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                    let tool_use_id =
                                        format!("{}-{}", fc_part.function_call.name, next_tool_id);
                                    state.pending_events.push(LlmCompletionEvent::ToolUse(
                                        LlmToolUse {
                                            id: tool_use_id,
                                            name: fc_part.function_call.name,
                                            input: serde_json::to_string(
                                                &fc_part.function_call.args,
                                            )
                                            .unwrap_or_default(),
                                            is_input_complete: true,
                                            thought_signature,
                                        },
                                    ));
                                }
                                _ => {}
                            }
                        }

                        if let Some(finish_reason) = candidate.finish_reason {
                            // Even when Gemini wants to use a tool, the API
                            // responds with `finish_reason: STOP`, so we check
                            // wants_to_use_tool to override
                            let stop_reason = if state.wants_to_use_tool {
                                LlmStopReason::ToolUse
                            } else {
                                match finish_reason.as_str() {
                                    "STOP" => LlmStopReason::EndTurn,
                                    "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                    "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
                                    "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
                                    _ => LlmStopReason::EndTurn,
                                }
                            };

                            state
                                .pending_events
                                .push(LlmCompletionEvent::Stop(stop_reason));
                        }
                    }
                }
            }

            continue;
        }

        // Check if the buffer contains a non-SSE error response (no "data: " prefix)
        // This can happen when Google returns an immediate error without streaming
        if !state.buffer.is_empty()
            && !state.buffer.contains("data: ")
            && state.buffer.contains("\"error\"")
        {
            // Try to parse the entire buffer as an error response
            if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&state.buffer) {
                let error_msg = api_error
                    .error
                    .message
                    .unwrap_or_else(|| "Unknown API error".to_string());
                let status = api_error.error.status.unwrap_or_default();
                let code = api_error.error.code.unwrap_or(0);
                streams.remove(stream_id);
                return Err(format!(
                    "Google AI API error ({}): {} [status: {}]",
                    code, error_msg, status
                ));
            }
        }

        match state.response_stream.next_chunk() {
            Ok(Some(chunk)) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                state.buffer.push_str(&chunk_str);
            }
            Ok(None) => {
                streams.remove(stream_id);
                return Ok(None);
            }
            Err(e) => {
                streams.remove(stream_id);
                return Err(e);
            }
        }
    }
}

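/// Converts a Zed completion request into a Gemini `generateContent` payload: system
/// messages become the system instruction, user and assistant messages become contents, and
/// tools, tool choice, and sampling settings are mapped to their API equivalents.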
fn build_generate_content_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<GenerateContentRequest, String> {
    let mut contents: Vec<Content> = Vec::new();
    let mut system_instruction: Option<SystemInstruction> = None;

    for message in &request.messages {
        match message.role {
            LlmMessageRole::System => {
                let parts = convert_content_to_parts(&message.content)?;
                system_instruction = Some(SystemInstruction { parts });
            }
            LlmMessageRole::User | LlmMessageRole::Assistant => {
                let role = match message.role {
                    LlmMessageRole::User => Role::User,
                    LlmMessageRole::Assistant => Role::Model,
                    _ => continue,
                };
                let parts = convert_content_to_parts(&message.content)?;
                contents.push(Content { parts, role });
            }
        }
    }

    let tools = if !request.tools.is_empty() {
        Some(vec![Tool {
            function_declarations: request
                .tools
                .iter()
                .map(|t| FunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
                })
                .collect(),
        }])
    } else {
        None
    };

    let tool_config = request.tool_choice.as_ref().map(|choice| {
        let mode = match choice {
            zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
            zed::LlmToolChoice::Any => FunctionCallingMode::Any,
            zed::LlmToolChoice::None => FunctionCallingMode::None,
        };
        ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode,
                allowed_function_names: None,
            },
        }
    });

    let generation_config = Some(GenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: request.max_tokens.map(|t| t as usize),
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        top_p: None,
        top_k: None,
        thinking_config: if request.thinking_allowed {
            // Check if this model has a custom thinking budget configured
            get_model_thinking_budget(model_id).map(|thinking_budget| ThinkingConfig {
                thinking_budget,
            })
        } else {
            None
        },
    });

    Ok(GenerateContentRequest {
        model: ModelName {
            model_id: model_id.to_string(),
        },
        contents,
        system_instruction,
        generation_config,
        safety_settings: None,
        tools,
        tool_config,
    })
}

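/// Maps Zed message content items (text, images, tool uses, tool results, and thinking
/// blocks) onto Gemini request `Part`s.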
fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
    let mut parts = Vec::new();

    for item in content {
        match item {
            LlmMessageContent::Text(text) => {
                parts.push(Part::TextPart(TextPart { text: text.clone() }));
            }
            LlmMessageContent::Image(image) => {
                parts.push(Part::InlineDataPart(InlineDataPart {
                    inline_data: GenerativeContentBlob {
                        mime_type: "image/png".to_string(),
                        data: image.source.clone(),
                    },
                }));
            }
            LlmMessageContent::ToolUse(tool_use) => {
                // Normalize empty string signatures to None
                let thought_signature = tool_use
                    .thought_signature
                    .clone()
                    .filter(|s| !s.is_empty());
                parts.push(Part::FunctionCallPart(FunctionCallPart {
                    function_call: FunctionCall {
                        name: tool_use.name.clone(),
                        args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
                    },
                    thought_signature,
                }));
            }
            LlmMessageContent::ToolResult(tool_result) => {
                match &tool_result.content {
                    zed::LlmToolResultContent::Text(text) => {
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": text }),
                            },
                        }));
                    }
                    zed::LlmToolResultContent::Image(image) => {
                        // Send both the function response and the image inline
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": "Tool responded with an image" }),
                            },
                        }));
                        parts.push(Part::InlineDataPart(InlineDataPart {
                            inline_data: GenerativeContentBlob {
                                mime_type: "image/png".to_string(),
                                data: image.source.clone(),
                            },
                        }));
                    }
                }
            }
            LlmMessageContent::Thinking(thinking) => {
                if let Some(signature) = &thinking.signature {
                    parts.push(Part::ThoughtPart(ThoughtPart {
                        thought: true,
                        thought_signature: signature.clone(),
                    }));
                }
            }
            LlmMessageContent::RedactedThinking(_) => {}
        }
    }

    Ok(parts)
}

// Data structures for Google AI API

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}

 843#[serde(rename_all = "camelCase")]
 844pub struct TextPart {
 845    pub text: String,
 846}
 847
 848#[derive(Debug, Serialize, Deserialize)]
 849#[serde(rename_all = "camelCase")]
 850pub struct InlineDataPart {
 851    pub inline_data: GenerativeContentBlob,
 852}
 853
 854#[derive(Debug, Serialize, Deserialize)]
 855#[serde(rename_all = "camelCase")]
 856pub struct GenerativeContentBlob {
 857    pub mime_type: String,
 858    pub data: String,
 859}
 860
 861#[derive(Debug, Serialize, Deserialize)]
 862#[serde(rename_all = "camelCase")]
 863pub struct FunctionCallPart {
 864    pub function_call: FunctionCall,
 865    /// Thought signature returned by the model for function calls.
 866    /// Only present on the first function call in parallel call scenarios.
 867    #[serde(skip_serializing_if = "Option::is_none")]
 868    pub thought_signature: Option<String>,
 869}
 870
 871#[derive(Debug, Serialize, Deserialize)]
 872#[serde(rename_all = "camelCase")]
 873pub struct FunctionResponsePart {
 874    pub function_response: FunctionResponse,
 875}
 876
 877#[derive(Debug, Serialize, Deserialize)]
 878#[serde(rename_all = "camelCase")]
 879pub struct ThoughtPart {
 880    pub thought: bool,
 881    pub thought_signature: String,
 882}
 883
 884#[derive(Debug, Serialize, Deserialize)]
 885#[serde(rename_all = "camelCase")]
 886pub struct CitationSource {
 887    #[serde(skip_serializing_if = "Option::is_none")]
 888    pub start_index: Option<usize>,
 889    #[serde(skip_serializing_if = "Option::is_none")]
 890    pub end_index: Option<usize>,
 891    #[serde(skip_serializing_if = "Option::is_none")]
 892    pub uri: Option<String>,
 893    #[serde(skip_serializing_if = "Option::is_none")]
 894    pub license: Option<String>,
 895}
 896
 897#[derive(Debug, Serialize, Deserialize)]
 898#[serde(rename_all = "camelCase")]
 899pub struct CitationMetadata {
 900    pub citation_sources: Vec<CitationSource>,
 901}
 902
 903#[derive(Debug, Serialize, Deserialize)]
 904#[serde(rename_all = "camelCase")]
 905pub struct PromptFeedback {
 906    #[serde(skip_serializing_if = "Option::is_none")]
 907    pub block_reason: Option<String>,
 908    pub safety_ratings: Option<Vec<SafetyRating>>,
 909    #[serde(skip_serializing_if = "Option::is_none")]
 910    pub block_reason_message: Option<String>,
 911}
 912
 913#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 914#[serde(rename_all = "camelCase")]
 915pub struct UsageMetadata {
 916    #[serde(skip_serializing_if = "Option::is_none")]
 917    pub prompt_token_count: Option<u64>,
 918    #[serde(skip_serializing_if = "Option::is_none")]
 919    pub cached_content_token_count: Option<u64>,
 920    #[serde(skip_serializing_if = "Option::is_none")]
 921    pub candidates_token_count: Option<u64>,
 922    #[serde(skip_serializing_if = "Option::is_none")]
 923    pub tool_use_prompt_token_count: Option<u64>,
 924    #[serde(skip_serializing_if = "Option::is_none")]
 925    pub thoughts_token_count: Option<u64>,
 926    #[serde(skip_serializing_if = "Option::is_none")]
 927    pub total_token_count: Option<u64>,
 928}
 929
 930#[derive(Debug, Serialize, Deserialize)]
 931#[serde(rename_all = "camelCase")]
 932pub struct ThinkingConfig {
 933    pub thinking_budget: u32,
 934}
 935
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

1062#[serde(rename_all = "lowercase")]
1063pub enum FunctionCallingMode {
1064    Auto,
1065    Any,
1066    None,
1067}
1068
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

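/// Model identifier that serializes to and from the API's `models/<id>` resource-name form.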
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}

const MODEL_NAME_PREFIX: &str = "models/";

/// Google API error response structure
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Deserialize)]
pub struct ApiError {
    pub code: Option<u16>,
    pub message: Option<String>,
    pub status: Option<String>,
}

impl Serialize for ModelName {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
    }
}

impl<'de> Deserialize<'de> for ModelName {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let string = String::deserialize(deserializer)?;
        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
            Ok(Self {
                model_id: id.to_string(),
            })
        } else {
            Err(serde::de::Error::custom(format!(
                "Expected model name to begin with {}, got: {}",
                MODEL_NAME_PREFIX, string
            )))
        }
    }
}
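
// A minimal test sketch (not part of the original extension) covering the pure helpers in
// this file. It assumes only the types defined above plus `serde_json`; the host-backed
// functions (provider settings, env vars, HTTP) are deliberately not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trips_with_models_prefix() {
        let name = ModelName {
            model_id: "gemini-2.5-flash".to_string(),
        };
        // Serialization adds the `models/` resource prefix...
        let json = serde_json::to_string(&name).unwrap();
        assert_eq!(json, "\"models/gemini-2.5-flash\"");
        // ...and deserialization strips it again.
        let parsed: ModelName = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_id, "gemini-2.5-flash");
    }

    #[test]
    fn validate_rejects_empty_contents() {
        let request = GenerateContentRequest {
            model: ModelName {
                model_id: "gemini-2.5-flash".to_string(),
            },
            contents: Vec::new(),
            system_instruction: None,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        };
        assert!(validate_generate_content_request(&request).is_err());
    }

    #[test]
    fn aliases_point_at_known_default_models() {
        let defaults: Vec<String> = get_default_models().into_iter().map(|m| m.id).collect();
        for (_alias, canonical) in get_model_aliases() {
            assert!(defaults.contains(&canonical.to_string()));
        }
    }
}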