google_ai.rs

  1use std::collections::HashMap;
  2use std::sync::atomic::{AtomicU64, Ordering};
  3use std::sync::Mutex;
  4
  5use serde::{Deserialize, Serialize};
  6use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
  7use zed_extension_api::{self as zed, *};
  8
/// Process-wide counter used to build unique tool-call ids. The `GoogleFunctionCall` wire
/// type carries only a name and arguments, so ids are synthesized as `{name}-{counter}`.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Extension state for the Google AI (Gemini) language-model provider.
struct GoogleAiProvider {
    /// In-flight completion streams, keyed by the id handed out when a stream starts.
    streams: Mutex<HashMap<String, StreamState>>,
    /// Counter used to mint `google-ai-stream-N` stream ids.
    next_stream_id: Mutex<u64>,
}

/// Parsing state for one streaming completion response.
struct StreamState {
    /// HTTP body being read; `None` is treated as a closed stream.
    response_stream: Option<HttpResponseStream>,
    /// Received text not yet split into complete lines.
    buffer: String,
    /// Whether the `Started` event has already been emitted for this stream.
    started: bool,
    /// Stop reason mapped from the most recent `finishReason`, emitted at end of stream.
    stop_reason: Option<LlmStopReason>,
    /// Set once a function call has been seen, so that `STOP` can be reported as tool use.
    wants_tool_use: bool,
}
 23
/// Static description of one Gemini model offered by this provider.
struct ModelDefinition {
    /// Model id used in the Gemini REST URL (`models/{real_id}:streamGenerateContent`).
    real_id: &'static str,
    /// Human-readable name; also reported to Zed as the model id.
    display_name: &'static str,
    /// Context window size in tokens (reported as `max_token_count`).
    max_tokens: u64,
    /// Maximum number of output tokens reported to Zed, if known.
    max_output_tokens: Option<u64>,
    /// Whether image inputs are advertised for this model.
    supports_images: bool,
    /// Whether a `thinkingConfig` may be attached to requests for this model.
    supports_thinking: bool,
    /// Whether Zed should treat this as the default model.
    is_default: bool,
    /// Whether Zed should treat this as the default "fast" model.
    is_default_fast: bool,
}

/// All models exposed by the provider, in the order they are listed to Zed.
/// Lookups are done by `display_name`, so display names must be unique.
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];
 77
 78fn get_real_model_id(display_name: &str) -> Option<&'static str> {
 79    MODELS
 80        .iter()
 81        .find(|m| m.display_name == display_name)
 82        .map(|m| m.real_id)
 83}
 84
 85fn get_model_supports_thinking(display_name: &str) -> bool {
 86    MODELS
 87        .iter()
 88        .find(|m| m.display_name == display_name)
 89        .map(|m| m.supports_thinking)
 90        .unwrap_or(false)
 91}
 92
 93/// Adapts a JSON schema to be compatible with Google's API subset.
 94/// Google only supports a specific subset of JSON Schema fields.
 95/// See: https://ai.google.dev/api/caching#Schema
 96fn adapt_schema_for_google(json: &mut serde_json::Value) {
 97    adapt_schema_for_google_impl(json, true);
 98}
 99
100fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
101    if let serde_json::Value::Object(obj) = json {
102        // Google's Schema only supports these fields:
103        // type, format, title, description, nullable, enum, maxItems, minItems,
104        // properties, required, minProperties, maxProperties, minLength, maxLength,
105        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
106        const ALLOWED_KEYS: &[&str] = &[
107            "type",
108            "format",
109            "title",
110            "description",
111            "nullable",
112            "enum",
113            "maxItems",
114            "minItems",
115            "properties",
116            "required",
117            "minProperties",
118            "maxProperties",
119            "minLength",
120            "maxLength",
121            "pattern",
122            "example",
123            "anyOf",
124            "propertyOrdering",
125            "default",
126            "items",
127            "minimum",
128            "maximum",
129        ];
130
131        // Convert oneOf to anyOf before filtering keys
132        if let Some(one_of) = obj.remove("oneOf") {
133            obj.insert("anyOf".to_string(), one_of);
134        }
135
136        // If type is an array (e.g., ["string", "null"]), take just the first type
137        if let Some(type_field) = obj.get_mut("type") {
138            if let serde_json::Value::Array(types) = type_field {
139                if let Some(first_type) = types.first().cloned() {
140                    *type_field = first_type;
141                }
142            }
143        }
144
145        // Only filter keys if this is a schema object, not a properties map
146        if is_schema {
147            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
148        }
149
150        // Recursively process nested values
151        // "properties" contains a map of property names -> schemas
152        // "items" and "anyOf" contain schemas directly
153        for (key, value) in obj.iter_mut() {
154            if key == "properties" {
155                // properties is a map of property_name -> schema
156                if let serde_json::Value::Object(props) = value {
157                    for (_, prop_schema) in props.iter_mut() {
158                        adapt_schema_for_google_impl(prop_schema, true);
159                    }
160                }
161            } else if key == "items" {
162                // items is a schema
163                adapt_schema_for_google_impl(value, true);
164            } else if key == "anyOf" {
165                // anyOf is an array of schemas
166                if let serde_json::Value::Array(arr) = value {
167                    for item in arr.iter_mut() {
168                        adapt_schema_for_google_impl(item, true);
169                    }
170                }
171            }
172        }
173    } else if let serde_json::Value::Array(arr) = json {
174        for item in arr.iter_mut() {
175            adapt_schema_for_google_impl(item, true);
176        }
177    }
178}
179
/// Request body sent to `models/{model}:streamGenerateContent`.
/// Field names are serialized in camelCase; `None` sections are omitted from the JSON.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    /// Conversation turns, in order.
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

/// `systemInstruction` section: the parts collected from system messages.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}
199
/// One conversation turn: an optional role plus its ordered parts.
/// Used both in outgoing requests and in decoded stream chunks.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    /// `"user"` or `"model"` when built by this provider; may be absent in responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

/// A single content part.
///
/// The enum is `untagged`: when deserializing, serde tries the variants in declaration order
/// and keeps the first one whose required fields are present (unknown fields are ignored).
/// The variant order is therefore significant, e.g. any part carrying `text` decodes as
/// `Text` even if it also has other fields.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

/// Plain text part (`{"text": ...}`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

/// Inline binary data part (`{"inlineData": {...}}`), used for images.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

/// Inline payload plus its MIME type. `data` is filled from the image's `source` string
/// (NOTE(review): presumably base64 — confirm against the extension API).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

/// Function call emitted by the model, optionally with an opaque thought signature that is
/// passed through to Zed and echoed back on later requests.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

/// Name and JSON arguments of a function call.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

/// Result of a tool invocation sent back to the model.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

/// Tool name plus its result value (this provider wraps the output as `{"output": ...}`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

/// Thought marker carrying only an opaque signature; no thought text is represented here.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}
271
/// `generationConfig` section: sampling and output settings. `None` fields are omitted.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

/// `thinkingConfig` section: token budget for the model's thinking.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

/// One `tools` entry wrapping the function declarations offered to the model.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

/// A callable function; `parameters` is the tool's JSON schema after adaptation to
/// Google's schema subset.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

/// `toolConfig` section controlling whether/how the model may call functions.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

/// Function-calling mode (`"AUTO"`, `"ANY"` or `"NONE"` as built by this provider) and an
/// optional allow-list of function names, which this provider leaves unset.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}
320
/// One decoded chunk of the streaming response. Missing fields decode as empty / `None`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

/// A response candidate (requests from this provider ask for a single candidate).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    /// Content carried by this chunk, if any.
    #[serde(default)]
    content: Option<GoogleContent>,
    /// Finish reason string; the stream handler distinguishes `STOP`, `MAX_TOKENS` and
    /// `SAFETY` and treats anything else as an end of turn.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Token accounting attached to a chunk; absent counts decode as 0.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}
347
348fn convert_request(
349    model_id: &str,
350    request: &LlmCompletionRequest,
351) -> Result<(GoogleRequest, String), String> {
352    let real_model_id =
353        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
354
355    let supports_thinking = get_model_supports_thinking(model_id);
356
357    let mut contents: Vec<GoogleContent> = Vec::new();
358    let mut system_parts: Vec<GooglePart> = Vec::new();
359
360    for msg in &request.messages {
361        match msg.role {
362            LlmMessageRole::System => {
363                for content in &msg.content {
364                    if let LlmMessageContent::Text(text) = content {
365                        if !text.is_empty() {
366                            system_parts
367                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
368                        }
369                    }
370                }
371            }
372            LlmMessageRole::User => {
373                let mut parts: Vec<GooglePart> = Vec::new();
374
375                for content in &msg.content {
376                    match content {
377                        LlmMessageContent::Text(text) => {
378                            if !text.is_empty() {
379                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
380                            }
381                        }
382                        LlmMessageContent::Image(img) => {
383                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
384                                inline_data: GoogleBlob {
385                                    mime_type: "image/png".to_string(),
386                                    data: img.source.clone(),
387                                },
388                            }));
389                        }
390                        LlmMessageContent::ToolResult(result) => {
391                            let response_value = match &result.content {
392                                LlmToolResultContent::Text(t) => {
393                                    serde_json::json!({ "output": t })
394                                }
395                                LlmToolResultContent::Image(_) => {
396                                    serde_json::json!({ "output": "Tool responded with an image" })
397                                }
398                            };
399                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
400                                function_response: GoogleFunctionResponse {
401                                    name: result.tool_name.clone(),
402                                    response: response_value,
403                                },
404                            }));
405                        }
406                        _ => {}
407                    }
408                }
409
410                if !parts.is_empty() {
411                    contents.push(GoogleContent {
412                        parts,
413                        role: Some("user".to_string()),
414                    });
415                }
416            }
417            LlmMessageRole::Assistant => {
418                let mut parts: Vec<GooglePart> = Vec::new();
419
420                for content in &msg.content {
421                    match content {
422                        LlmMessageContent::Text(text) => {
423                            if !text.is_empty() {
424                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
425                            }
426                        }
427                        LlmMessageContent::ToolUse(tool_use) => {
428                            let thought_signature =
429                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());
430
431                            let args: serde_json::Value =
432                                serde_json::from_str(&tool_use.input).unwrap_or_default();
433
434                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
435                                function_call: GoogleFunctionCall {
436                                    name: tool_use.name.clone(),
437                                    args,
438                                },
439                                thought_signature,
440                            }));
441                        }
442                        LlmMessageContent::Thinking(thinking) => {
443                            if let Some(ref signature) = thinking.signature {
444                                if !signature.is_empty() {
445                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
446                                        thought: true,
447                                        thought_signature: signature.clone(),
448                                    }));
449                                }
450                            }
451                        }
452                        _ => {}
453                    }
454                }
455
456                if !parts.is_empty() {
457                    contents.push(GoogleContent {
458                        parts,
459                        role: Some("model".to_string()),
460                    });
461                }
462            }
463        }
464    }
465
466    let system_instruction = if system_parts.is_empty() {
467        None
468    } else {
469        Some(GoogleSystemInstruction {
470            parts: system_parts,
471        })
472    };
473
474    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
475        None
476    } else {
477        let declarations: Vec<GoogleFunctionDeclaration> = request
478            .tools
479            .iter()
480            .map(|t| {
481                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
482                    .unwrap_or(serde_json::Value::Object(Default::default()));
483                adapt_schema_for_google(&mut parameters);
484                GoogleFunctionDeclaration {
485                    name: t.name.clone(),
486                    description: t.description.clone(),
487                    parameters,
488                }
489            })
490            .collect();
491        Some(vec![GoogleTool {
492            function_declarations: declarations,
493        }])
494    };
495
496    let tool_config = request.tool_choice.as_ref().map(|tc| {
497        let mode = match tc {
498            LlmToolChoice::Auto => "AUTO",
499            LlmToolChoice::Any => "ANY",
500            LlmToolChoice::None => "NONE",
501        };
502        GoogleToolConfig {
503            function_calling_config: GoogleFunctionCallingConfig {
504                mode: mode.to_string(),
505                allowed_function_names: None,
506            },
507        }
508    });
509
510    let thinking_config = if supports_thinking && request.thinking_allowed {
511        Some(GoogleThinkingConfig {
512            thinking_budget: 8192,
513        })
514    } else {
515        None
516    };
517
518    let generation_config = Some(GoogleGenerationConfig {
519        candidate_count: Some(1),
520        stop_sequences: if request.stop_sequences.is_empty() {
521            None
522        } else {
523            Some(request.stop_sequences.clone())
524        },
525        max_output_tokens: None,
526        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
527        thinking_config,
528    });
529
530    Ok((
531        GoogleRequest {
532            contents,
533            system_instruction,
534            generation_config,
535            tools,
536            tool_config,
537        },
538        real_model_id.to_string(),
539    ))
540}
541
542fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
543    let trimmed = line.trim();
544    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
545        return None;
546    }
547
548    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
549    let json_str = json_str.trim_start_matches(',').trim();
550
551    if json_str.is_empty() {
552        return None;
553    }
554
555    serde_json::from_str(json_str).ok()
556}
557
558impl zed::Extension for GoogleAiProvider {
559    fn new() -> Self {
560        Self {
561            streams: Mutex::new(HashMap::new()),
562            next_stream_id: Mutex::new(0),
563        }
564    }
565
566    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
567        vec![LlmProviderInfo {
568            id: "google-ai".into(),
569            name: "Google AI".into(),
570            icon: Some("icons/google-ai.svg".into()),
571        }]
572    }
573
574    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
575        Ok(MODELS
576            .iter()
577            .map(|m| LlmModelInfo {
578                id: m.display_name.to_string(),
579                name: m.display_name.to_string(),
580                max_token_count: m.max_tokens,
581                max_output_tokens: m.max_output_tokens,
582                capabilities: LlmModelCapabilities {
583                    supports_images: m.supports_images,
584                    supports_tools: true,
585                    supports_tool_choice_auto: true,
586                    supports_tool_choice_any: true,
587                    supports_tool_choice_none: true,
588                    supports_thinking: m.supports_thinking,
589                    tool_input_format: LlmToolInputFormat::JsonSchema,
590                },
591                is_default: m.is_default,
592                is_default_fast: m.is_default_fast,
593            })
594            .collect())
595    }
596
597    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
598        llm_get_credential("google-ai").is_some()
599    }
600
601    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
602        Some(
603            "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(),
604        )
605    }
606
607    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
608        llm_delete_credential("google-ai")
609    }
610
611    fn llm_stream_completion_start(
612        &mut self,
613        _provider_id: &str,
614        model_id: &str,
615        request: &LlmCompletionRequest,
616    ) -> Result<String, String> {
617        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
618            "No API key configured. Please add your Google AI API key in settings.".to_string()
619        })?;
620
621        let (google_request, real_model_id) = convert_request(model_id, request)?;
622
623        let body = serde_json::to_vec(&google_request)
624            .map_err(|e| format!("Failed to serialize request: {}", e))?;
625
626        let url = format!(
627            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
628            real_model_id, api_key
629        );
630
631        let http_request = HttpRequest {
632            method: HttpMethod::Post,
633            url,
634            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
635            body: Some(body),
636            redirect_policy: RedirectPolicy::FollowAll,
637        };
638
639        let response_stream = http_request
640            .fetch_stream()
641            .map_err(|e| format!("HTTP request failed: {}", e))?;
642
643        let stream_id = {
644            let mut id_counter = self.next_stream_id.lock().unwrap();
645            let id = format!("google-ai-stream-{}", *id_counter);
646            *id_counter += 1;
647            id
648        };
649
650        self.streams.lock().unwrap().insert(
651            stream_id.clone(),
652            StreamState {
653                response_stream: Some(response_stream),
654                buffer: String::new(),
655                started: false,
656                stop_reason: None,
657                wants_tool_use: false,
658            },
659        );
660
661        Ok(stream_id)
662    }
663
664    fn llm_stream_completion_next(
665        &mut self,
666        stream_id: &str,
667    ) -> Result<Option<LlmCompletionEvent>, String> {
668        let mut streams = self.streams.lock().unwrap();
669        let state = streams
670            .get_mut(stream_id)
671            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
672
673        if !state.started {
674            state.started = true;
675            return Ok(Some(LlmCompletionEvent::Started));
676        }
677
678        let response_stream = state
679            .response_stream
680            .as_mut()
681            .ok_or_else(|| "Stream already closed".to_string())?;
682
683        loop {
684            if let Some(newline_pos) = state.buffer.find('\n') {
685                let line = state.buffer[..newline_pos].to_string();
686                state.buffer = state.buffer[newline_pos + 1..].to_string();
687
688                if let Some(response) = parse_stream_line(&line) {
689                    for candidate in response.candidates {
690                        if let Some(finish_reason) = &candidate.finish_reason {
691                            state.stop_reason = Some(match finish_reason.as_str() {
692                                "STOP" => {
693                                    if state.wants_tool_use {
694                                        LlmStopReason::ToolUse
695                                    } else {
696                                        LlmStopReason::EndTurn
697                                    }
698                                }
699                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
700                                "SAFETY" => LlmStopReason::Refusal,
701                                _ => LlmStopReason::EndTurn,
702                            });
703                        }
704
705                        if let Some(content) = candidate.content {
706                            for part in content.parts {
707                                match part {
708                                    GooglePart::Text(text_part) => {
709                                        if !text_part.text.is_empty() {
710                                            return Ok(Some(LlmCompletionEvent::Text(
711                                                text_part.text,
712                                            )));
713                                        }
714                                    }
715                                    GooglePart::FunctionCall(fc_part) => {
716                                        state.wants_tool_use = true;
717                                        let next_tool_id =
718                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
719                                        let id = format!(
720                                            "{}-{}",
721                                            fc_part.function_call.name, next_tool_id
722                                        );
723
724                                        let thought_signature =
725                                            fc_part.thought_signature.filter(|s| !s.is_empty());
726
727                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
728                                            id,
729                                            name: fc_part.function_call.name,
730                                            input: fc_part.function_call.args.to_string(),
731                                            is_input_complete: true,
732                                            thought_signature,
733                                        })));
734                                    }
735                                    GooglePart::Thought(thought_part) => {
736                                        return Ok(Some(LlmCompletionEvent::Thinking(
737                                            LlmThinkingContent {
738                                                text: "(Encrypted thought)".to_string(),
739                                                signature: Some(thought_part.thought_signature),
740                                            },
741                                        )));
742                                    }
743                                    _ => {}
744                                }
745                            }
746                        }
747                    }
748
749                    if let Some(usage) = response.usage_metadata {
750                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
751                            input_tokens: usage.prompt_token_count,
752                            output_tokens: usage.candidates_token_count,
753                            cache_creation_input_tokens: None,
754                            cache_read_input_tokens: None,
755                        })));
756                    }
757                }
758
759                continue;
760            }
761
762            match response_stream.next_chunk() {
763                Ok(Some(chunk)) => {
764                    let text = String::from_utf8_lossy(&chunk);
765                    state.buffer.push_str(&text);
766                }
767                Ok(None) => {
768                    // Stream ended - check if we have a stop reason
769                    if let Some(stop_reason) = state.stop_reason.take() {
770                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
771                    }
772
773                    // No stop reason - this is unexpected. Check if buffer contains error info
774                    let mut error_msg = String::from("Stream ended unexpectedly.");
775
776                    // Try to parse remaining buffer as potential error response
777                    if !state.buffer.is_empty() {
778                        error_msg.push_str(&format!(
779                            "\nRemaining buffer: {}",
780                            &state.buffer[..state.buffer.len().min(1000)]
781                        ));
782                    }
783
784                    return Err(error_msg);
785                }
786                Err(e) => {
787                    return Err(format!("Stream error: {}", e));
788                }
789            }
790        }
791    }
792
793    fn llm_stream_completion_close(&mut self, stream_id: &str) {
794        self.streams.lock().unwrap().remove(stream_id);
795    }
796}
797
// Registers `GoogleAiProvider` with the Zed extension API.
zed::register_extension!(GoogleAiProvider);