google_ai.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

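// Google's streaming API identifies function calls by name only, so the
// extension synthesizes unique tool-call IDs from this process-wide counter
// (see llm_stream_completion_next).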
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

struct GoogleAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

struct StreamState {
    response_stream: Option<HttpResponseStream>,
    /// Bytes received from the HTTP stream that have not yet been consumed.
    buffer: String,
    /// Whether the `Started` event has been emitted for this stream.
    started: bool,
    /// Finish reason reported by the API, emitted once the stream ends.
    stop_reason: Option<LlmStopReason>,
    /// Set once a function call is seen, so a `STOP` finish reason maps to
    /// `LlmStopReason::ToolUse` rather than `EndTurn`.
    wants_tool_use: bool,
}

struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_thinking: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];

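// NOTE: the display name doubles as the model ID that Zed sees (see
// llm_provider_models below), which is why these helpers look models up by
// display_name and translate to real_id only when building the request URL.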
fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

fn get_model_supports_thinking(display_name: &str) -> bool {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.supports_thinking)
        .unwrap_or(false)
}

/// Adapts a JSON schema to be compatible with Google's API subset.
/// Google only supports a specific subset of JSON Schema fields.
/// See: https://ai.google.dev/api/caching#Schema
fn adapt_schema_for_google(json: &mut serde_json::Value) {
    adapt_schema_for_google_impl(json, true);
}

fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
    if let serde_json::Value::Object(obj) = json {
        // Google's Schema only supports these fields:
        // type, format, title, description, nullable, enum, maxItems, minItems,
        // properties, required, minProperties, maxProperties, minLength, maxLength,
        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
        const ALLOWED_KEYS: &[&str] = &[
            "type",
            "format",
            "title",
            "description",
            "nullable",
            "enum",
            "maxItems",
            "minItems",
            "properties",
            "required",
            "minProperties",
            "maxProperties",
            "minLength",
            "maxLength",
            "pattern",
            "example",
            "anyOf",
            "propertyOrdering",
            "default",
            "items",
            "minimum",
            "maximum",
        ];

        // Convert oneOf to anyOf before filtering keys
        if let Some(one_of) = obj.remove("oneOf") {
            obj.insert("anyOf".to_string(), one_of);
        }

        // If type is an array (e.g., ["string", "null"]), take just the first type
        if let Some(type_field) = obj.get_mut("type") {
            if let serde_json::Value::Array(types) = type_field {
                if let Some(first_type) = types.first().cloned() {
                    *type_field = first_type;
                }
            }
        }

        // Only filter keys if this is a schema object, not a properties map
        if is_schema {
            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
        }

        // Recursively process nested values
        // "properties" contains a map of property names -> schemas
        // "items" and "anyOf" contain schemas directly
        for (key, value) in obj.iter_mut() {
            if key == "properties" {
                // properties is a map of property_name -> schema
                if let serde_json::Value::Object(props) = value {
                    for (_, prop_schema) in props.iter_mut() {
                        adapt_schema_for_google_impl(prop_schema, true);
                    }
                }
            } else if key == "items" {
                // items is a schema
                adapt_schema_for_google_impl(value, true);
            } else if key == "anyOf" {
                // anyOf is an array of schemas
                if let serde_json::Value::Array(arr) = value {
                    for item in arr.iter_mut() {
                        adapt_schema_for_google_impl(item, true);
                    }
                }
            }
        }
    } else if let serde_json::Value::Array(arr) = json {
        for item in arr.iter_mut() {
            adapt_schema_for_google_impl(item, true);
        }
    }
}
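
// A minimal sanity check of the schema adaptation above: `oneOf` becomes
// `anyOf`, an array-valued `type` collapses to its first entry, and fields
// outside Google's subset (e.g. `additionalProperties`) are dropped.
#[cfg(test)]
mod schema_adaptation_tests {
    use super::*;

    #[test]
    fn adapts_schema_to_google_subset() {
        let mut schema = serde_json::json!({
            "type": ["string", "null"],
            "oneOf": [{ "type": "integer" }],
            "additionalProperties": false
        });
        adapt_schema_for_google(&mut schema);

        assert_eq!(schema["type"], "string");
        assert!(schema.get("oneOf").is_none());
        assert!(schema.get("additionalProperties").is_none());
        assert!(schema["anyOf"].is_array());
    }
}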

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

/// With `#[serde(untagged)]`, serde tries the variants in declaration order
/// and picks the first whose fields match the incoming JSON.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    #[serde(default)]
    content: Option<GoogleContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}
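
// For reference, an illustrative sketch of a streamed chunk that the response
// types above deserialize:
//
//   data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},
//          "finishReason":"STOP"}],
//          "usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":2}}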

fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<(GoogleRequest, String), String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let supports_thinking = get_model_supports_thinking(model_id);

    let mut contents: Vec<GoogleContent> = Vec::new();
    let mut system_parts: Vec<GooglePart> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text.is_empty() {
                            system_parts
                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                        }
                    }
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
                                inline_data: GoogleBlob {
                                    // NOTE: the mime type is hardcoded; image
                                    // data arriving through the extension API
                                    // is assumed to be PNG.
                                    mime_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            }));
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let response_value = match &result.content {
                                LlmToolResultContent::Text(t) => {
                                    serde_json::json!({ "output": t })
                                }
                                LlmToolResultContent::Image(_) => {
                                    serde_json::json!({ "output": "Tool responded with an image" })
                                }
                            };
                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
                                function_response: GoogleFunctionResponse {
                                    name: result.tool_name.clone(),
                                    response: response_value,
                                },
                            }));
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("user".to_string()),
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            let thought_signature =
                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());

                            let args: serde_json::Value =
                                serde_json::from_str(&tool_use.input).unwrap_or_default();

                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
                                function_call: GoogleFunctionCall {
                                    name: tool_use.name.clone(),
                                    args,
                                },
                                thought_signature,
                            }));
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            if let Some(ref signature) = thinking.signature {
                                if !signature.is_empty() {
                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
                                        thought: true,
                                        thought_signature: signature.clone(),
                                    }));
                                }
                            }
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("model".to_string()),
                    });
                }
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GoogleSystemInstruction {
            parts: system_parts,
        })
    };

    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
        None
    } else {
        let declarations: Vec<GoogleFunctionDeclaration> = request
            .tools
            .iter()
            .map(|t| {
                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default()));
                adapt_schema_for_google(&mut parameters);
                GoogleFunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters,
                }
            })
            .collect();
        Some(vec![GoogleTool {
            function_declarations: declarations,
        }])
    };

    let tool_config = request.tool_choice.as_ref().map(|tc| {
        let mode = match tc {
            LlmToolChoice::Auto => "AUTO",
            LlmToolChoice::Any => "ANY",
            LlmToolChoice::None => "NONE",
        };
        GoogleToolConfig {
            function_calling_config: GoogleFunctionCallingConfig {
                mode: mode.to_string(),
                allowed_function_names: None,
            },
        }
    });

    let thinking_config = if supports_thinking && request.thinking_allowed {
        Some(GoogleThinkingConfig {
            // Fixed reasoning budget; Google interprets thinkingBudget as a
            // token count.
            thinking_budget: 8192,
        })
    } else {
        None
    };

    let generation_config = Some(GoogleGenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: None,
        // Fall back to a temperature of 1.0 when the request does not set one.
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        thinking_config,
    });

    Ok((
        GoogleRequest {
            contents,
            system_instruction,
            generation_config,
            tools,
            tool_config,
        },
        real_model_id.to_string(),
    ))
}
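
// For reference, a request serialized by convert_request has this shape
// (illustrative sketch; field names follow the camelCase renames above):
//
//   {
//     "contents": [{"parts": [{"text": "Hello"}], "role": "user"}],
//     "systemInstruction": {"parts": [{"text": "..."}]},
//     "generationConfig": {"candidateCount": 1, "temperature": 1.0,
//                          "thinkingConfig": {"thinkingBudget": 8192}},
//     "tools": [{"functionDeclarations": [{"name": "...", "description": "...",
//                                          "parameters": {}}]}]
//   }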

/// Parses one line of the streaming response body. Handles both SSE framing
/// (`data: {...}`) and the bare JSON-array format, skipping the `[`, `]`, and
/// `,` lines that delimit the latter.
fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
        return None;
    }

    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
    let json_str = json_str.trim_start_matches(',').trim();

    if json_str.is_empty() {
        return None;
    }

    serde_json::from_str(json_str).ok()
}
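
// A small check of the line parser above: JSON-array framing lines are
// skipped, and SSE `data:` lines deserialize into a stream response.
#[cfg(test)]
mod parse_stream_line_tests {
    use super::*;

    #[test]
    fn skips_framing_and_parses_data_lines() {
        assert!(parse_stream_line("[").is_none());
        assert!(parse_stream_line(",").is_none());
        assert!(parse_stream_line("").is_none());

        let parsed = parse_stream_line(r#"data: {"candidates": []}"#)
            .expect("data line should parse");
        assert!(parsed.candidates.is_empty());
    }
}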

impl zed::Extension for GoogleAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "google-ai".into(),
            name: "Google AI".into(),
            icon: Some("icons/google-ai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("google-ai").is_some()
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        if llm_get_credential("google-ai").is_some() {
            Ok(())
        } else {
            Err("No API key configured".to_string())
        }
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Google AI, you need an API key. You can create one [here](https://aistudio.google.com/apikey).".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("google-ai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
            "No API key configured. Please add your Google AI API key in settings.".to_string()
        })?;

        let (google_request, real_model_id) = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&google_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        // `alt=sse` requests Server-Sent Events framing (`data: {...}` lines)
        // instead of one large JSON array.
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            real_model_id, api_key
        );

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url,
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("google-ai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                stop_reason: None,
                wants_tool_use: false,
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();
                if let Some(response) = parse_stream_line(&line) {
                    for candidate in response.candidates {
                        if let Some(finish_reason) = &candidate.finish_reason {
                            state.stop_reason = Some(match finish_reason.as_str() {
                                "STOP" => {
                                    if state.wants_tool_use {
                                        LlmStopReason::ToolUse
                                    } else {
                                        LlmStopReason::EndTurn
                                    }
                                }
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "SAFETY" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            });
                        }

                        if let Some(content) = candidate.content {
                            for part in content.parts {
                                match part {
                                    GooglePart::Text(text_part) => {
                                        if !text_part.text.is_empty() {
                                            return Ok(Some(LlmCompletionEvent::Text(
                                                text_part.text,
                                            )));
                                        }
                                    }
                                    GooglePart::FunctionCall(fc_part) => {
                                        state.wants_tool_use = true;
                                        let next_tool_id =
                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                        let id = format!(
                                            "{}-{}",
                                            fc_part.function_call.name, next_tool_id
                                        );

                                        let thought_signature =
                                            fc_part.thought_signature.filter(|s| !s.is_empty());

                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                            id,
                                            name: fc_part.function_call.name,
                                            input: fc_part.function_call.args.to_string(),
                                            thought_signature,
                                        })));
                                    }
                                    GooglePart::Thought(thought_part) => {
                                        return Ok(Some(LlmCompletionEvent::Thinking(
                                            LlmThinkingContent {
                                                text: "(Encrypted thought)".to_string(),
                                                signature: Some(thought_part.thought_signature),
                                            },
                                        )));
                                    }
                                    _ => {}
                                }
                            }
                        }
                    }

                    if let Some(usage) = response.usage_metadata {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_token_count,
                            output_tokens: usage.candidates_token_count,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    // NOTE: from_utf8_lossy is applied per chunk, so a
                    // multi-byte character split across chunk boundaries would
                    // be replaced rather than reassembled.
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended - check if we have a stop reason
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }

                    // No stop reason - this is unexpected. Include any
                    // remaining buffer, which may contain an error payload.
                    let mut error_msg = String::from("Stream ended unexpectedly.");

                    if !state.buffer.is_empty() {
                        // Take at most 1000 characters; slicing by bytes could
                        // panic on a non-UTF-8 char boundary.
                        let preview: String = state.buffer.chars().take(1000).collect();
                        error_msg.push_str(&format!("\nRemaining buffer: {}", preview));
                    }

                    return Err(error_msg);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(GoogleAiProvider);