google_ai.rs

use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

/// Monotonic counter used to mint unique tool-call IDs across streams.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

struct GoogleAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

struct StreamState {
    response_stream: Option<HttpResponseStream>,
    /// Bytes received from the wire but not yet split into complete lines.
    buffer: String,
    /// Events already parsed but not yet delivered. A single chunk can carry
    /// several parts (e.g. parallel function calls), while the host pulls one
    /// event per `llm_stream_completion_next` call.
    pending_events: VecDeque<LlmCompletionEvent>,
    started: bool,
    stop_reason: Option<LlmStopReason>,
    wants_tool_use: bool,
}

struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_thinking: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

fn get_model_supports_thinking(display_name: &str) -> bool {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.supports_thinking)
        .unwrap_or(false)
}

/// Adapts a JSON schema in place to the subset Google's API accepts.
/// Google only supports a specific subset of JSON Schema fields.
/// See: https://ai.google.dev/api/caching#Schema
fn adapt_schema_for_google(json: &mut serde_json::Value) {
    if let serde_json::Value::Object(obj) = json {
        // Google's Schema only supports these fields:
        // type, format, title, description, nullable, enum, maxItems, minItems,
        // properties, required, minProperties, maxProperties, minLength, maxLength,
        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
        const ALLOWED_KEYS: &[&str] = &[
            "type",
            "format",
            "title",
            "description",
            "nullable",
            "enum",
            "maxItems",
            "minItems",
            "properties",
            "required",
            "minProperties",
            "maxProperties",
            "minLength",
            "maxLength",
            "pattern",
            "example",
            "anyOf",
            "propertyOrdering",
            "default",
            "items",
            "minimum",
            "maximum",
        ];

        // Convert oneOf (unsupported) to anyOf before filtering keys.
        if let Some(one_of) = obj.remove("oneOf") {
            obj.insert("anyOf".to_string(), one_of);
        }

        // If type is an array (e.g., ["string", "null"]), keep just the first type.
        if let Some(type_field) = obj.get_mut("type") {
            if let serde_json::Value::Array(types) = type_field {
                if let Some(first_type) = types.first().cloned() {
                    *type_field = first_type;
                }
            }
        }

        obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));

        // Recurse into nested schemas. "properties" is a map of property
        // name -> schema, "items" is a schema, and "anyOf" is an array of
        // schemas. The property-name keys themselves are never filtered,
        // since only schema objects are recursed into.
        for (key, value) in obj.iter_mut() {
            if key == "properties" {
                if let serde_json::Value::Object(props) = value {
                    for (_, prop_schema) in props.iter_mut() {
                        adapt_schema_for_google(prop_schema);
                    }
                }
            } else if key == "items" {
                adapt_schema_for_google(value);
            } else if key == "anyOf" {
                if let serde_json::Value::Array(arr) = value {
                    for item in arr.iter_mut() {
                        adapt_schema_for_google(item);
                    }
                }
            }
        }
    } else if let serde_json::Value::Array(arr) = json {
        for item in arr.iter_mut() {
            adapt_schema_for_google(item);
        }
    }
}
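
// A minimal sketch of unit tests for the pure schema-adaptation helper,
// assuming the standard `cargo test` setup; the rest of the extension needs
// the Zed host and is not exercised here.
#[cfg(test)]
mod schema_tests {
    use super::adapt_schema_for_google;

    #[test]
    fn rewrites_one_of_and_type_arrays_and_drops_unknown_keys() {
        let mut schema = serde_json::json!({
            "type": ["string", "null"],
            "oneOf": [{ "type": "string" }],
            "additionalProperties": false,
            "description": "kept"
        });
        adapt_schema_for_google(&mut schema);
        assert_eq!(schema["type"], "string");
        assert!(schema.get("anyOf").is_some());
        assert!(schema.get("oneOf").is_none());
        // `additionalProperties` is not in Google's allowed subset.
        assert!(schema.get("additionalProperties").is_none());
        assert_eq!(schema["description"], "kept");
    }

    #[test]
    fn recurses_into_properties_without_filtering_property_names() {
        let mut schema = serde_json::json!({
            "type": "object",
            "properties": {
                // A property that happens to share a reserved name must survive.
                "pattern": { "type": ["integer", "null"], "examples": [1] }
            }
        });
        adapt_schema_for_google(&mut schema);
        assert_eq!(schema["properties"]["pattern"]["type"], "integer");
        // `examples` (plural) is not allowed; only `example` is.
        assert!(schema["properties"]["pattern"].get("examples").is_none());
    }
}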

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    #[serde(default)]
    content: Option<GoogleContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}

fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<(GoogleRequest, String), String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let supports_thinking = get_model_supports_thinking(model_id);

    let mut contents: Vec<GoogleContent> = Vec::new();
    let mut system_parts: Vec<GooglePart> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text.is_empty() {
                            system_parts
                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                        }
                    }
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
                                inline_data: GoogleBlob {
                                    mime_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            }));
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let response_value = match &result.content {
                                LlmToolResultContent::Text(t) => {
                                    serde_json::json!({ "output": t })
                                }
                                LlmToolResultContent::Image(_) => {
                                    serde_json::json!({ "output": "Tool responded with an image" })
                                }
                            };
                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
                                function_response: GoogleFunctionResponse {
                                    name: result.tool_name.clone(),
                                    response: response_value,
                                },
                            }));
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("user".to_string()),
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            let thought_signature =
                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());

                            let args: serde_json::Value =
                                serde_json::from_str(&tool_use.input).unwrap_or_default();

                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
                                function_call: GoogleFunctionCall {
                                    name: tool_use.name.clone(),
                                    args,
                                },
                                thought_signature,
                            }));
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            if let Some(ref signature) = thinking.signature {
                                if !signature.is_empty() {
                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
                                        thought: true,
                                        thought_signature: signature.clone(),
                                    }));
                                }
                            }
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("model".to_string()),
                    });
                }
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GoogleSystemInstruction {
            parts: system_parts,
        })
    };

    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
        None
    } else {
        let declarations: Vec<GoogleFunctionDeclaration> = request
            .tools
            .iter()
            .map(|t| {
                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default()));
                adapt_schema_for_google(&mut parameters);
                GoogleFunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters,
                }
            })
            .collect();
        Some(vec![GoogleTool {
            function_declarations: declarations,
        }])
    };

    let tool_config = request.tool_choice.as_ref().map(|tc| {
        let mode = match tc {
            LlmToolChoice::Auto => "AUTO",
            LlmToolChoice::Any => "ANY",
            LlmToolChoice::None => "NONE",
        };
        GoogleToolConfig {
            function_calling_config: GoogleFunctionCallingConfig {
                mode: mode.to_string(),
                allowed_function_names: None,
            },
        }
    });

    let thinking_config = if supports_thinking && request.thinking_allowed {
        Some(GoogleThinkingConfig {
            thinking_budget: 8192,
        })
    } else {
        None
    };

    let generation_config = Some(GoogleGenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: None,
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        thinking_config,
    });

    Ok((
        GoogleRequest {
            contents,
            system_instruction,
            generation_config,
            tools,
            tool_config,
        },
        real_model_id.to_string(),
    ))
}

/// Parses one line of the response stream into a `GoogleStreamResponse`.
/// With `alt=sse` the API emits `data: {...}` lines; bare `[`, `]`, and `,`
/// framing lines (seen when the response arrives as a streamed JSON array)
/// are skipped.
fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
        return None;
    }

    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
    let json_str = json_str.trim_start_matches(',').trim();

    if json_str.is_empty() {
        return None;
    }

    serde_json::from_str(json_str).ok()
}
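
// A minimal sketch of tests for the stream-line parser, assuming `cargo test`;
// it covers the `data: ` framing used with `alt=sse`, the JSON-array framing
// lines, and deserialization through the untagged `GooglePart` enum.
#[cfg(test)]
mod stream_line_tests {
    use super::*;

    #[test]
    fn parses_sse_data_lines_and_skips_punctuation() {
        assert!(parse_stream_line("[").is_none());
        assert!(parse_stream_line(",").is_none());
        assert!(parse_stream_line("]").is_none());
        assert!(parse_stream_line("").is_none());

        let line = r#"data: {"candidates":[{"content":{"parts":[{"text":"hi"}],"role":"model"},"finishReason":"STOP"}]}"#;
        let response = parse_stream_line(line).expect("valid SSE line");
        let candidate = &response.candidates[0];
        assert_eq!(candidate.finish_reason.as_deref(), Some("STOP"));
        let parts = &candidate.content.as_ref().unwrap().parts;
        assert!(matches!(&parts[0], GooglePart::Text(t) if t.text == "hi"));
    }

    #[test]
    fn deserializes_function_calls_via_the_untagged_enum() {
        let line = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"run","args":{"cmd":"ls"}}}]}}]}"#;
        let response = parse_stream_line(line).expect("valid JSON line");
        let parts = &response.candidates[0].content.as_ref().unwrap().parts;
        assert!(matches!(&parts[0], GooglePart::FunctionCall(_)));
    }
}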

impl zed::Extension for GoogleAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "google-ai".into(),
            name: "Google AI".into(),
            icon: Some("icons/google-ai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("google-ai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Google AI Setup

Welcome to **Google AI**! This extension provides access to Google Gemini models.

## Configuration

Enter your Google AI API key below. You can get your API key at [aistudio.google.com/apikey](https://aistudio.google.com/apikey).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| Gemini 2.5 Flash-Lite | gemini-2.5-flash-lite | 1M | 65K |
| Gemini 2.5 Flash | gemini-2.5-flash | 1M | 65K |
| Gemini 2.5 Pro | gemini-2.5-pro | 1M | 65K |
| Gemini 3 Pro | gemini-3-pro-preview | 1M | 65K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling with thought signatures
- ✅ Vision (image inputs)
- ✅ Extended thinking support
- ✅ Gemini 2.5 and Gemini 3 models

## Pricing

Uses your Google AI API credits. See [Google AI pricing](https://ai.google.dev/pricing) for details.
"#
            .to_string(),
        )
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        let provided = llm_request_credential(
            "google-ai",
            LlmCredentialType::ApiKey,
            "Google AI API Key",
            "AIza...",
        )?;
        if provided {
            Ok(())
        } else {
            Err("Authentication cancelled".to_string())
        }
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("google-ai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
            "No API key configured. Please add your Google AI API key in settings.".to_string()
        })?;

        let (google_request, real_model_id) = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&google_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        // `alt=sse` makes the API frame the stream as `data: {...}` lines.
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            real_model_id, api_key
        );

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url,
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("google-ai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                pending_events: VecDeque::new(),
                started: false,
                stop_reason: None,
                wants_tool_use: false,
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain events parsed on a previous call before reading more data.
        if let Some(event) = state.pending_events.pop_front() {
            return Ok(Some(event));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if let Some(response) = parse_stream_line(&line) {
                    for candidate in response.candidates {
                        // Queue every part of the chunk; a single chunk may
                        // carry several parts (e.g. parallel function calls),
                        // and the host drains them one event at a time.
                        if let Some(content) = candidate.content {
                            for part in content.parts {
                                match part {
                                    GooglePart::Text(text_part) => {
                                        if !text_part.text.is_empty() {
                                            state
                                                .pending_events
                                                .push_back(LlmCompletionEvent::Text(text_part.text));
                                        }
                                    }
                                    GooglePart::FunctionCall(fc_part) => {
                                        state.wants_tool_use = true;
                                        let next_tool_id =
                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                        let id = format!(
                                            "{}-{}",
                                            fc_part.function_call.name, next_tool_id
                                        );

                                        let thought_signature =
                                            fc_part.thought_signature.filter(|s| !s.is_empty());

                                        state.pending_events.push_back(
                                            LlmCompletionEvent::ToolUse(LlmToolUse {
                                                id,
                                                name: fc_part.function_call.name,
                                                input: fc_part.function_call.args.to_string(),
                                                thought_signature,
                                            }),
                                        );
                                    }
                                    GooglePart::Thought(thought_part) => {
                                        state.pending_events.push_back(
                                            LlmCompletionEvent::Thinking(LlmThinkingContent {
                                                text: "(Encrypted thought)".to_string(),
                                                signature: Some(thought_part.thought_signature),
                                            }),
                                        );
                                    }
                                    _ => {}
                                }
                            }
                        }

                        // Resolve the finish reason after the parts, so a
                        // function call arriving in the same chunk as
                        // `finishReason: "STOP"` is reported as `ToolUse`.
                        if let Some(finish_reason) = &candidate.finish_reason {
                            state.stop_reason = Some(match finish_reason.as_str() {
                                "STOP" => {
                                    if state.wants_tool_use {
                                        LlmStopReason::ToolUse
                                    } else {
                                        LlmStopReason::EndTurn
                                    }
                                }
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "SAFETY" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            });
                        }
                    }

                    if let Some(usage) = response.usage_metadata {
                        state
                            .pending_events
                            .push_back(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_token_count,
                                output_tokens: usage.candidates_token_count,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            }));
                    }

                    if let Some(event) = state.pending_events.pop_front() {
                        return Ok(Some(event));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended - emit the stop reason if we saw one.
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }

                    // No stop reason - this is unexpected. Surface whatever is
                    // left in the buffer (truncated on a char boundary, since
                    // slicing at a fixed byte offset could panic), as it often
                    // contains an API error payload.
                    let mut error_msg = String::from("Stream ended unexpectedly.");
                    if !state.buffer.is_empty() {
                        let preview: String = state.buffer.chars().take(1000).collect();
                        error_msg.push_str(&format!("\nRemaining buffer: {}", preview));
                    }

                    return Err(error_msg);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(GoogleAiProvider);