google_ai.rs
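
//! Google AI (Gemini) LLM provider extension for Zed: lists available Gemini
//! models, counts request tokens, and streams `generateContent` completions
//! over SSE.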

use std::collections::HashMap;

use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zed_extension_api::{
    self as zed, http_client::HttpMethod, http_client::HttpRequest,
    llm_get_env_var, llm_get_provider_settings,
    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmCustomModelConfig,
    LlmMessageContent, LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo,
    LlmStopReason, LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
};

pub const DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";

fn get_api_url() -> String {
    llm_get_provider_settings(PROVIDER_ID)
        .and_then(|s| s.api_url)
        .unwrap_or_else(|| DEFAULT_API_URL.to_string())
}

fn get_custom_models() -> Vec<LlmCustomModelConfig> {
    llm_get_provider_settings(PROVIDER_ID)
        .map(|s| s.available_models)
        .unwrap_or_default()
}

/// Starts a streaming `generateContent` request against the Gemini API and
/// registers the resulting SSE stream in `streams`, returning its stream ID.
fn stream_generate_content(
    model_id: &str,
    request: &LlmCompletionRequest,
    streams: &mut HashMap<String, StreamState>,
    next_stream_id: &mut u64,
) -> Result<String, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&generate_content_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response_stream = http_request.fetch_stream()?;

    let stream_id = format!("stream-{}", *next_stream_id);
    *next_stream_id += 1;

    streams.insert(
        stream_id.clone(),
        StreamState {
            response_stream,
            buffer: String::new(),
            usage: None,
            pending_events: Vec::new(),
            wants_to_use_tool: false,
        },
    );

    Ok(stream_id)
}

/// Counts the tokens a request would consume by calling the `countTokens`
/// endpoint with the same payload that `generateContent` would receive.
fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;
    let count_request = CountTokensRequest {
        generate_content_request,
    };

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:countTokens?key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&count_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response = http_request.fetch()?;
    let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
        .map_err(|e| format!("Failed to parse response: {}", e))?;

    Ok(response_body.total_tokens)
}

fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
    if request.model.is_empty() {
        return Err("Model must be specified".to_string());
    }

    if request.contents.is_empty() {
        return Err("Request must contain at least one content item".to_string());
    }

    if let Some(user_content) = request
        .contents
        .iter()
        .find(|content| content.role == Role::User)
    {
        if user_content.parts.is_empty() {
            return Err("User content must contain at least one part".to_string());
        }
    }

    Ok(())
}

// Extension implementation

const PROVIDER_ID: &str = "google-ai";
const PROVIDER_NAME: &str = "Google AI";

struct GoogleAiExtension {
    streams: HashMap<String, StreamState>,
    next_stream_id: u64,
}

/// Per-stream state: the open HTTP response stream, the SSE parse buffer,
/// the most recent usage metadata, and any events waiting to be delivered.
struct StreamState {
    response_stream: zed::http_client::HttpResponseStream,
    buffer: String,
    usage: Option<UsageMetadata>,
    pending_events: Vec<LlmCompletionEvent>,
    wants_to_use_tool: bool,
}

impl zed::Extension for GoogleAiExtension {
    fn new() -> Self {
        Self {
            streams: HashMap::new(),
            next_stream_id: 0,
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: PROVIDER_ID.to_string(),
            name: PROVIDER_NAME.to_string(),
            icon: Some("icons/google-ai.svg".to_string()),
        }]
    }

    fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(get_models())
    }

    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(
            r#"## Google AI Setup

To use Google AI models in Zed, you need a Gemini API key.

1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
2. Create or select a project
3. Generate an API key
4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable

You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
"#
            .to_string(),
        )
    }

    fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
        if provider_id != PROVIDER_ID {
            return false;
        }
        get_api_key().is_some()
    }

    fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        // API keys are read from environment variables, so there is nothing to reset.
        Ok(())
    }

    fn llm_count_tokens(
        &self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<u64, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        count_tokens(model_id, request)
    }

    fn llm_stream_completion_start(
        &mut self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        stream_generate_content_next(stream_id, &mut self.streams)
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.remove(stream_id);
    }

    fn llm_cache_configuration(
        &self,
        provider_id: &str,
        _model_id: &str,
    ) -> Option<LlmCacheConfiguration> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(LlmCacheConfiguration {
            max_cache_anchors: 1,
            should_cache_tool_definitions: false,
            min_total_token_count: 32768,
        })
    }
}

zed::register_extension!(GoogleAiExtension);

// Helper functions

fn get_api_key() -> Option<String> {
    llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
}

fn get_default_models() -> Vec<LlmModelInfo> {
    vec![
        LlmModelInfo {
            id: "gemini-2.5-flash-lite".to_string(),
            name: "Gemini 2.5 Flash-Lite".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: false,
            is_default_fast: true,
        },
        LlmModelInfo {
            id: "gemini-2.5-flash".to_string(),
            name: "Gemini 2.5 Flash".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: true,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-2.5-pro".to_string(),
            name: "Gemini 2.5 Pro".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: false,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-3-pro-preview".to_string(),
            name: "Gemini 3 Pro".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: false,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-3-flash-preview".to_string(),
            name: "Gemini 3 Flash".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: false,
            is_default_fast: false,
        },
    ]
}

/// Model aliases for backward compatibility with old model names.
/// Maps old names to canonical model IDs.
fn get_model_aliases() -> Vec<(&'static str, &'static str)> {
    vec![
        // Gemini 2.5 Flash-Lite aliases
        ("gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-lite"),
        ("gemini-2.0-flash-lite-preview", "gemini-2.5-flash-lite"),
        // Gemini 2.5 Flash aliases
        ("gemini-2.0-flash-thinking-exp", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-04-17", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-05-20", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-latest", "gemini-2.5-flash"),
        ("gemini-2.0-flash", "gemini-2.5-flash"),
        // Gemini 2.5 Pro aliases
        ("gemini-2.0-pro-exp", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-latest", "gemini-2.5-pro"),
        ("gemini-2.5-pro-exp-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-06-05", "gemini-2.5-pro"),
    ]
}

/// Builds the full model list: built-in defaults, backward-compatibility
/// aliases, and any custom models configured in the provider settings.
fn get_models() -> Vec<LlmModelInfo> {
    let mut models: HashMap<String, LlmModelInfo> = HashMap::new();

    // Add default models
    for model in get_default_models() {
        models.insert(model.id.clone(), model);
    }

    // Add aliases as separate model entries (pointing to the same underlying model)
    for (alias, canonical_id) in get_model_aliases() {
        if let Some(canonical_model) = models.get(canonical_id) {
            let mut alias_model = canonical_model.clone();
            alias_model.id = alias.to_string();
            alias_model.is_default = false;
            alias_model.is_default_fast = false;
            models.insert(alias.to_string(), alias_model);
        }
    }

    // Add/override with custom models from settings
    for custom_model in get_custom_models() {
        let model = LlmModelInfo {
            id: custom_model.name.clone(),
            name: custom_model.display_name.unwrap_or(custom_model.name.clone()),
            max_token_count: custom_model.max_tokens,
            max_output_tokens: custom_model.max_output_tokens,
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: custom_model.thinking_budget.is_some(),
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: false,
            is_default_fast: false,
        };
        models.insert(custom_model.name, model);
    }

    models.into_values().collect()
}

/// Get the thinking budget for a specific model from custom settings.
fn get_model_thinking_budget(model_id: &str) -> Option<u32> {
    get_custom_models()
        .into_iter()
        .find(|m| m.name == model_id)
        .and_then(|m| m.thinking_budget)
}

/// Pulls the next completion event from an open stream, reading and parsing
/// SSE chunks from the HTTP response as needed.
fn stream_generate_content_next(
    stream_id: &str,
    streams: &mut HashMap<String, StreamState>,
) -> Result<Option<LlmCompletionEvent>, String> {
    let state = streams
        .get_mut(stream_id)
        .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

    loop {
        // Return any pending events first
        if let Some(event) = state.pending_events.pop() {
            return Ok(Some(event));
        }

        if let Some(newline_pos) = state.buffer.find('\n') {
            let line = state.buffer[..newline_pos].to_string();
            state.buffer = state.buffer[newline_pos + 1..].to_string();

            if let Some(data) = line.strip_prefix("data: ") {
                if data.trim().is_empty() {
                    continue;
                }

                let response: GenerateContentResponse = match serde_json::from_str(data) {
                    Ok(response) => response,
                    Err(parse_error) => {
                        // Try to parse as an API error response
                        if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(data) {
                            let error_msg = api_error
                                .error
                                .message
                                .unwrap_or_else(|| "Unknown API error".to_string());
                            let status = api_error.error.status.unwrap_or_default();
                            let code = api_error.error.code.unwrap_or(0);
                            return Err(format!(
                                "Google AI API error ({}): {} [status: {}]",
                                code, error_msg, status
                            ));
                        }
                        // If it's not an error response, return the parse error
                        return Err(format!(
                            "Failed to parse SSE data: {} - {}",
                            parse_error, data
                        ));
                    }
                };

                // Handle prompt feedback: any block reason (SAFETY, OTHER,
                // BLOCKLIST, PROHIBITED_CONTENT, IMAGE_SAFETY, ...) is a refusal.
                if let Some(ref prompt_feedback) = response.prompt_feedback {
                    if prompt_feedback.block_reason.is_some() {
                        return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::Refusal)));
                    }
                }

                // Send usage updates immediately when received
                if let Some(ref usage) = response.usage_metadata {
                    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
                    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
                    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
                    state.pending_events.push(LlmCompletionEvent::Usage(LlmTokenUsage {
                        input_tokens,
                        output_tokens: usage.candidates_token_count.unwrap_or(0),
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: Some(cached_tokens).filter(|&c| c > 0),
                    }));
                    state.usage = Some(usage.clone());
                }

                if let Some(candidates) = response.candidates {
                    for candidate in candidates {
                        for part in candidate.content.parts {
                            match part {
                                Part::TextPart(text_part) => {
                                    return Ok(Some(LlmCompletionEvent::Text(text_part.text)));
                                }
                                Part::ThoughtPart(thought_part) => {
                                    return Ok(Some(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: "(Encrypted thought)".to_string(),
                                            signature: Some(thought_part.thought_signature),
                                        },
                                    )));
                                }
                                Part::FunctionCallPart(fc_part) => {
                                    state.wants_to_use_tool = true;
                                    // Normalize empty string signatures to None
                                    let thought_signature =
                                        fc_part.thought_signature.filter(|s| !s.is_empty());
                                    // Gemini function calls carry no call ID, so the
                                    // function name doubles as the tool-use ID.
                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                        id: fc_part.function_call.name.clone(),
                                        name: fc_part.function_call.name,
                                        input: serde_json::to_string(&fc_part.function_call.args)
                                            .unwrap_or_default(),
                                        is_input_complete: true,
                                        thought_signature,
                                    })));
                                }
                                _ => {}
                            }
                        }

                        if let Some(finish_reason) = candidate.finish_reason {
                            // Even when Gemini wants to use a tool, the API
                            // responds with `finish_reason: STOP`, so we check
                            // wants_to_use_tool to override.
                            let stop_reason = if state.wants_to_use_tool {
                                LlmStopReason::ToolUse
                            } else {
                                match finish_reason.as_str() {
                                    "STOP" => LlmStopReason::EndTurn,
                                    "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                    "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
                                    "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
                                    _ => LlmStopReason::EndTurn,
                                }
                            };

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }
                }
            }

            continue;
        }

        // Check if the buffer contains a non-SSE error response (no "data: " prefix).
        // This can happen when Google returns an immediate error without streaming.
        if !state.buffer.is_empty()
            && !state.buffer.contains("data: ")
            && state.buffer.contains("\"error\"")
        {
            // Try to parse the entire buffer as an error response
            if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&state.buffer) {
                let error_msg = api_error
                    .error
                    .message
                    .unwrap_or_else(|| "Unknown API error".to_string());
                let status = api_error.error.status.unwrap_or_default();
                let code = api_error.error.code.unwrap_or(0);
                streams.remove(stream_id);
                return Err(format!(
                    "Google AI API error ({}): {} [status: {}]",
                    code, error_msg, status
                ));
            }
        }

        match state.response_stream.next_chunk() {
            Ok(Some(chunk)) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                state.buffer.push_str(&chunk_str);
            }
            Ok(None) => {
                streams.remove(stream_id);
                return Ok(None);
            }
            Err(e) => {
                streams.remove(stream_id);
                return Err(e);
            }
        }
    }
}

/// Converts a Zed completion request into a Gemini `generateContent` request,
/// mapping system messages to the system instruction and user/assistant
/// messages to `contents`.
fn build_generate_content_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<GenerateContentRequest, String> {
    let mut contents: Vec<Content> = Vec::new();
    let mut system_instruction: Option<SystemInstruction> = None;

    for message in &request.messages {
        match message.role {
            LlmMessageRole::System => {
                // Accumulate parts so multiple system messages are all preserved
                // rather than each overwriting the last.
                let parts = convert_content_to_parts(&message.content)?;
                match system_instruction.as_mut() {
                    Some(instruction) => instruction.parts.extend(parts),
                    None => system_instruction = Some(SystemInstruction { parts }),
                }
            }
            LlmMessageRole::User | LlmMessageRole::Assistant => {
                let role = match message.role {
                    LlmMessageRole::User => Role::User,
                    LlmMessageRole::Assistant => Role::Model,
                    _ => continue,
                };
                let parts = convert_content_to_parts(&message.content)?;
                contents.push(Content { parts, role });
            }
        }
    }

    let tools = if !request.tools.is_empty() {
        Some(vec![Tool {
            function_declarations: request
                .tools
                .iter()
                .map(|t| FunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
                })
                .collect(),
        }])
    } else {
        None
    };

    let tool_config = request.tool_choice.as_ref().map(|choice| {
        let mode = match choice {
            zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
            zed::LlmToolChoice::Any => FunctionCallingMode::Any,
            zed::LlmToolChoice::None => FunctionCallingMode::None,
        };
        ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode,
                allowed_function_names: None,
            },
        }
    });

    let generation_config = Some(GenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: request.max_tokens.map(|t| t as usize),
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        top_p: None,
        top_k: None,
        thinking_config: if request.thinking_allowed {
            // Check if this model has a custom thinking budget configured
            get_model_thinking_budget(model_id).map(|thinking_budget| ThinkingConfig {
                thinking_budget,
            })
        } else {
            None
        },
    });

    Ok(GenerateContentRequest {
        model: ModelName {
            model_id: model_id.to_string(),
        },
        contents,
        system_instruction,
        generation_config,
        safety_settings: None,
        tools,
        tool_config,
    })
}

/// Converts Zed message content items into Gemini `Part`s.
fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
    let mut parts = Vec::new();

    for item in content {
        match item {
            LlmMessageContent::Text(text) => {
                parts.push(Part::TextPart(TextPart { text: text.clone() }));
            }
            LlmMessageContent::Image(image) => {
                parts.push(Part::InlineDataPart(InlineDataPart {
                    inline_data: GenerativeContentBlob {
                        mime_type: "image/png".to_string(),
                        data: image.source.clone(),
                    },
                }));
            }
            LlmMessageContent::ToolUse(tool_use) => {
                // Normalize empty string signatures to None
                let thought_signature = tool_use
                    .thought_signature
                    .clone()
                    .filter(|s| !s.is_empty());
                parts.push(Part::FunctionCallPart(FunctionCallPart {
                    function_call: FunctionCall {
                        name: tool_use.name.clone(),
                        args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
                    },
                    thought_signature,
                }));
            }
            LlmMessageContent::ToolResult(tool_result) => {
                match &tool_result.content {
                    zed::LlmToolResultContent::Text(text) => {
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": text }),
                            },
                        }));
                    }
                    zed::LlmToolResultContent::Image(image) => {
                        // Send both the function response and the image inline
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": "Tool responded with an image" }),
                            },
                        }));
                        parts.push(Part::InlineDataPart(InlineDataPart {
                            inline_data: GenerativeContentBlob {
                                mime_type: "image/png".to_string(),
                                data: image.source.clone(),
                            },
                        }));
                    }
                }
            }
            LlmMessageContent::Thinking(thinking) => {
                if let Some(signature) = &thinking.signature {
                    parts.push(Part::ThoughtPart(ThoughtPart {
                        thought: true,
                        thought_signature: signature.clone(),
                    }));
                }
            }
            LlmMessageContent::RedactedThinking(_) => {}
        }
    }

    Ok(parts)
}

// Data structures for Google AI API

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    pub data: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
    /// Thought signature returned by the model for function calls.
    /// Only present on the first function call in parallel call scenarios.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason_message: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

#[derive(Debug, Serialize, Deserialize)]
// The API expects the mode as an uppercase string ("AUTO", "ANY", "NONE").
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}

const MODEL_NAME_PREFIX: &str = "models/";

/// Google API error response structure
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Deserialize)]
pub struct ApiError {
    pub code: Option<u16>,
    pub message: Option<String>,
    pub status: Option<String>,
}

impl Serialize for ModelName {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
    }
}

impl<'de> Deserialize<'de> for ModelName {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let string = String::deserialize(deserializer)?;
        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
            Ok(Self {
                model_id: id.to_string(),
            })
        } else {
            Err(serde::de::Error::custom(format!(
                "Expected model name to begin with {}, got: {}",
                MODEL_NAME_PREFIX, string
            )))
        }
    }
}
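
// A minimal test sketch of behavior that is fully determined above: `ModelName`
// round-trips through its "models/" prefix, and every alias returned by
// `get_model_aliases` points at a model that `get_default_models` provides.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trips_with_models_prefix() {
        let name = ModelName {
            model_id: "gemini-2.5-flash".to_string(),
        };
        let json = serde_json::to_string(&name).unwrap();
        assert_eq!(json, "\"models/gemini-2.5-flash\"");

        let parsed: ModelName = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_id, "gemini-2.5-flash");
    }

    #[test]
    fn aliases_resolve_to_default_models() {
        let default_ids: Vec<String> = get_default_models().into_iter().map(|m| m.id).collect();
        for (alias, canonical_id) in get_model_aliases() {
            assert!(
                default_ids.contains(&canonical_id.to_string()),
                "alias {alias} maps to unknown model {canonical_id}"
            );
        }
    }
}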