google_ai.rs
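
//! A Zed extension that registers Google AI as an LLM provider, translating
//! Zed completion requests into Gemini `streamGenerateContent` calls and
//! streaming the responses back as completion events.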

use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

/// Monotonic counter used to mint unique tool-call ids, since Google's API
/// does not assign them.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

struct GoogleAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    stop_reason: Option<LlmStopReason>,
    wants_tool_use: bool,
    /// Events parsed from a stream line but not yet delivered; a single
    /// chunk can carry several parts plus usage metadata.
    pending_events: VecDeque<LlmCompletionEvent>,
}

struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_thinking: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

fn get_model_supports_thinking(display_name: &str) -> bool {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.supports_thinking)
        .unwrap_or(false)
}

/// Adapts a JSON schema to be compatible with Google's API subset.
/// Google only supports a specific subset of JSON Schema fields.
/// See: https://ai.google.dev/api/caching#Schema
fn adapt_schema_for_google(json: &mut serde_json::Value) {
    adapt_schema_for_google_impl(json, true);
}

fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
    if let serde_json::Value::Object(obj) = json {
        // Google's Schema only supports these fields:
        // type, format, title, description, nullable, enum, maxItems, minItems,
        // properties, required, minProperties, maxProperties, minLength, maxLength,
        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
        const ALLOWED_KEYS: &[&str] = &[
            "type",
            "format",
            "title",
            "description",
            "nullable",
            "enum",
            "maxItems",
            "minItems",
            "properties",
            "required",
            "minProperties",
            "maxProperties",
            "minLength",
            "maxLength",
            "pattern",
            "example",
            "anyOf",
            "propertyOrdering",
            "default",
            "items",
            "minimum",
            "maximum",
        ];

        // Convert oneOf to anyOf before filtering keys
        if let Some(one_of) = obj.remove("oneOf") {
            obj.insert("anyOf".to_string(), one_of);
        }

        // If type is an array (e.g., ["string", "null"]), take just the first type
        if let Some(type_field) = obj.get_mut("type") {
            if let serde_json::Value::Array(types) = type_field {
                if let Some(first_type) = types.first().cloned() {
                    *type_field = first_type;
                }
            }
        }

        // Only filter keys if this is a schema object, not a properties map
        if is_schema {
            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
        }

        // Recursively process nested values
        // "properties" contains a map of property names -> schemas
        // "items" and "anyOf" contain schemas directly
        for (key, value) in obj.iter_mut() {
            if key == "properties" {
                // properties is a map of property_name -> schema
                if let serde_json::Value::Object(props) = value {
                    for (_, prop_schema) in props.iter_mut() {
                        adapt_schema_for_google_impl(prop_schema, true);
                    }
                }
            } else if key == "items" {
                // items is a schema
                adapt_schema_for_google_impl(value, true);
            } else if key == "anyOf" {
                // anyOf is an array of schemas
                if let serde_json::Value::Array(arr) = value {
                    for item in arr.iter_mut() {
                        adapt_schema_for_google_impl(item, true);
                    }
                }
            }
        }
    } else if let serde_json::Value::Array(arr) = json {
        for item in arr.iter_mut() {
            adapt_schema_for_google_impl(item, true);
        }
    }
}

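// Example: the adaptation turns a schema such as
//   {"type": ["string", "null"], "oneOf": [...], "additionalProperties": false}
// into
//   {"type": "string", "anyOf": [...]}
// dropping unsupported keys, renaming oneOf, and collapsing array types.
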
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

/// A single content part. Untagged, so serde matches each variant by its
/// distinguishing field (`text`, `inlineData`, `functionCall`, ...).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

/// An encrypted reasoning part; only the opaque signature round-trips.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    #[serde(default)]
    content: Option<GoogleContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}

fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<(GoogleRequest, String), String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let supports_thinking = get_model_supports_thinking(model_id);

    let mut contents: Vec<GoogleContent> = Vec::new();
    let mut system_parts: Vec<GooglePart> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text.is_empty() {
                            system_parts
                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                        }
                    }
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
                                inline_data: GoogleBlob {
                                    mime_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            }));
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let response_value = match &result.content {
                                LlmToolResultContent::Text(t) => {
                                    serde_json::json!({ "output": t })
                                }
                                LlmToolResultContent::Image(_) => {
                                    serde_json::json!({ "output": "Tool responded with an image" })
                                }
                            };
                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
                                function_response: GoogleFunctionResponse {
                                    name: result.tool_name.clone(),
                                    response: response_value,
                                },
                            }));
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("user".to_string()),
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            let thought_signature =
                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());

                            // Fall back to an empty object rather than JSON null
                            // if the recorded tool input fails to parse.
                            let args: serde_json::Value = serde_json::from_str(&tool_use.input)
                                .unwrap_or_else(|_| serde_json::Value::Object(Default::default()));

                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
                                function_call: GoogleFunctionCall {
                                    name: tool_use.name.clone(),
                                    args,
                                },
                                thought_signature,
                            }));
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            // Only the opaque signature is echoed back to the API;
                            // the thinking text itself is never round-tripped.
                            if let Some(ref signature) = thinking.signature {
                                if !signature.is_empty() {
                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
                                        thought: true,
                                        thought_signature: signature.clone(),
                                    }));
                                }
                            }
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("model".to_string()),
                    });
                }
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GoogleSystemInstruction {
            parts: system_parts,
        })
    };

    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
        None
    } else {
        let declarations: Vec<GoogleFunctionDeclaration> = request
            .tools
            .iter()
            .map(|t| {
                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default()));
                adapt_schema_for_google(&mut parameters);
                GoogleFunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters,
                }
            })
            .collect();
        Some(vec![GoogleTool {
            function_declarations: declarations,
        }])
    };

    let tool_config = request.tool_choice.as_ref().map(|tc| {
        let mode = match tc {
            LlmToolChoice::Auto => "AUTO",
            LlmToolChoice::Any => "ANY",
            LlmToolChoice::None => "NONE",
        };
        GoogleToolConfig {
            function_calling_config: GoogleFunctionCallingConfig {
                mode: mode.to_string(),
                allowed_function_names: None,
            },
        }
    });

    // Fixed reasoning budget (in tokens) whenever the model supports
    // thinking and the request allows it.
    let thinking_config = if supports_thinking && request.thinking_allowed {
        Some(GoogleThinkingConfig {
            thinking_budget: 8192,
        })
    } else {
        None
    };

    let generation_config = Some(GoogleGenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: None,
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        thinking_config,
    });

    Ok((
        GoogleRequest {
            contents,
            system_instruction,
            generation_config,
            tools,
            tool_config,
        },
        real_model_id.to_string(),
    ))
}

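// Sketch of the wire format convert_request produces (field names become
// camelCase via the serde attributes above); the values are illustrative:
//
//   {
//     "contents": [{"role": "user", "parts": [{"text": "Hi"}]}],
//     "systemInstruction": {"parts": [{"text": "Be brief."}]},
//     "generationConfig": {
//       "candidateCount": 1,
//       "temperature": 1.0,
//       "thinkingConfig": {"thinkingBudget": 8192}
//     },
//     "tools": [{"functionDeclarations": [{"name": "...", ...}]}]
//   }
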
fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
        return None;
    }

    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
    let json_str = json_str.trim_start_matches(',').trim();

    if json_str.is_empty() {
        return None;
    }

    serde_json::from_str(json_str).ok()
}

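// Example inputs: with `alt=sse` the API emits lines like
//   data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}
// while non-SSE responses stream a JSON array, which is why the lone
// "[", "]", and "," framing lines are skipped above.
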
impl zed::Extension for GoogleAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "google-ai".into(),
            name: "Google AI".into(),
            icon: Some("icons/google-ai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("google-ai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Google AI Setup

Welcome to **Google AI**! This extension provides access to Google Gemini models.

## Configuration

Enter your Google AI API key below. You can get your API key at [aistudio.google.com/apikey](https://aistudio.google.com/apikey).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| Gemini 2.5 Flash-Lite | gemini-2.5-flash-lite | 1M | 65K |
| Gemini 2.5 Flash | gemini-2.5-flash | 1M | 65K |
| Gemini 2.5 Pro | gemini-2.5-pro | 1M | 65K |
| Gemini 3 Pro | gemini-3-pro-preview | 1M | 65K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling with thought signatures
- ✅ Vision (image inputs)
- ✅ Extended thinking support
- ✅ All Gemini models

## Pricing

Uses your Google AI API credits. See [Google AI pricing](https://ai.google.dev/pricing) for details.
"#
            .to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("google-ai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
            "No API key configured. Please add your Google AI API key in settings.".to_string()
        })?;

        let (google_request, real_model_id) = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&google_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            real_model_id, api_key
        );

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url,
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("google-ai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                stop_reason: None,
                wants_tool_use: false,
                pending_events: VecDeque::new(),
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Deliver any events left over from a previously parsed chunk before
        // reading more from the network.
        if let Some(event) = state.pending_events.pop_front() {
            return Ok(Some(event));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if let Some(response) = parse_stream_line(&line) {
                    for candidate in response.candidates {
                        if let Some(finish_reason) = &candidate.finish_reason {
                            state.stop_reason = Some(match finish_reason.as_str() {
                                "STOP" => {
                                    if state.wants_tool_use {
                                        LlmStopReason::ToolUse
                                    } else {
                                        LlmStopReason::EndTurn
                                    }
                                }
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "SAFETY" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            });
                        }

                        if let Some(content) = candidate.content {
                            // Queue every part rather than returning on the first
                            // one, so parts that share a chunk are not dropped.
                            for part in content.parts {
                                match part {
                                    GooglePart::Text(text_part) => {
                                        if !text_part.text.is_empty() {
                                            state.pending_events.push_back(
                                                LlmCompletionEvent::Text(text_part.text),
                                            );
                                        }
                                    }
                                    GooglePart::FunctionCall(fc_part) => {
                                        state.wants_tool_use = true;
                                        let next_tool_id =
                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                        let id = format!(
                                            "{}-{}",
                                            fc_part.function_call.name, next_tool_id
                                        );

                                        let thought_signature =
                                            fc_part.thought_signature.filter(|s| !s.is_empty());

                                        state.pending_events.push_back(
                                            LlmCompletionEvent::ToolUse(LlmToolUse {
                                                id,
                                                name: fc_part.function_call.name,
                                                input: fc_part.function_call.args.to_string(),
                                                thought_signature,
                                            }),
                                        );
                                    }
                                    GooglePart::Thought(thought_part) => {
                                        state.pending_events.push_back(
                                            LlmCompletionEvent::Thinking(LlmThinkingContent {
                                                text: "(Encrypted thought)".to_string(),
                                                signature: Some(thought_part.thought_signature),
                                            }),
                                        );
                                    }
                                    _ => {}
                                }
                            }
                        }
                    }

                    if let Some(usage) = response.usage_metadata {
                        state
                            .pending_events
                            .push_back(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_token_count,
                                output_tokens: usage.candidates_token_count,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            }));
                    }

                    if let Some(event) = state.pending_events.pop_front() {
                        return Ok(Some(event));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended - check if we have a stop reason
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }

                    // No stop reason - this is unexpected. Surface whatever is
                    // left in the buffer, which often contains an error body.
                    let mut error_msg = String::from("Stream ended unexpectedly.");

                    if !state.buffer.is_empty() {
                        // Truncate by characters, not bytes, to avoid panicking
                        // on a UTF-8 boundary.
                        let preview: String = state.buffer.chars().take(1000).collect();
                        error_msg.push_str(&format!("\nRemaining buffer: {}", preview));
                    }

                    return Err(error_msg);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(GoogleAiProvider);
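
// A minimal sketch of unit tests for the pure helpers above. They exercise
// the schema adaptation and stream-line parsing without a network or a Zed
// host; the JSON payloads are illustrative, not captured API output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn adapt_schema_drops_unsupported_keys() {
        let mut schema = serde_json::json!({
            "type": ["string", "null"],
            "oneOf": [{ "type": "string" }],
            "additionalProperties": false,
            "description": "kept"
        });
        adapt_schema_for_google(&mut schema);

        // Array-valued `type` collapses to its first entry.
        assert_eq!(schema["type"], "string");
        // `oneOf` is renamed to `anyOf`; unsupported keys are removed.
        assert!(schema.get("anyOf").is_some());
        assert!(schema.get("oneOf").is_none());
        assert!(schema.get("additionalProperties").is_none());
        assert_eq!(schema["description"], "kept");
    }

    #[test]
    fn parse_stream_line_handles_sse_and_framing() {
        // JSON-array framing lines are skipped.
        assert!(parse_stream_line("[").is_none());
        assert!(parse_stream_line(",").is_none());

        let line = r#"data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}]}"#;
        let response = parse_stream_line(line).expect("SSE data line should parse");
        assert_eq!(response.candidates.len(), 1);
    }
}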