use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

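// Monotonic counter used to mint unique tool-call ids, since Google's API
// does not assign ids to function calls.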
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

struct GoogleAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

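/// Book-keeping for one in-flight streaming completion. `buffer` accumulates
/// raw HTTP chunks until a full newline-terminated line is available, and
/// `wants_tool_use` records whether a function call was emitted so a final
/// "STOP" can be mapped to `LlmStopReason::ToolUse`.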
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    stop_reason: Option<LlmStopReason>,
    wants_tool_use: bool,
}

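/// Static description of a supported model. The human-readable `display_name`
/// doubles as the model id exposed to Zed; `real_id` is the identifier sent
/// to the Google API (see `get_real_model_id`).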
struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_thinking: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

fn get_model_supports_thinking(display_name: &str) -> bool {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.supports_thinking)
        .unwrap_or(false)
}

/// Adapts a JSON schema to be compatible with Google's API subset.
/// Google only supports a specific subset of JSON Schema fields.
/// See: https://ai.google.dev/api/caching#Schema
fn adapt_schema_for_google(json: &mut serde_json::Value) {
    adapt_schema_for_google_impl(json, true);
}

fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
    if let serde_json::Value::Object(obj) = json {
        // Google's Schema only supports these fields:
        // type, format, title, description, nullable, enum, maxItems, minItems,
        // properties, required, minProperties, maxProperties, minLength, maxLength,
        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
        const ALLOWED_KEYS: &[&str] = &[
            "type",
            "format",
            "title",
            "description",
            "nullable",
            "enum",
            "maxItems",
            "minItems",
            "properties",
            "required",
            "minProperties",
            "maxProperties",
            "minLength",
            "maxLength",
            "pattern",
            "example",
            "anyOf",
            "propertyOrdering",
            "default",
            "items",
            "minimum",
            "maximum",
        ];

        // Convert oneOf to anyOf before filtering keys
        if let Some(one_of) = obj.remove("oneOf") {
            obj.insert("anyOf".to_string(), one_of);
        }

        // If type is an array (e.g., ["string", "null"]), take just the first type
        if let Some(type_field) = obj.get_mut("type") {
            if let serde_json::Value::Array(types) = type_field {
                if let Some(first_type) = types.first().cloned() {
                    *type_field = first_type;
                }
            }
        }

        // Only filter keys if this is a schema object, not a properties map
        if is_schema {
            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
        }

        // Recursively process nested values
        // "properties" contains a map of property names -> schemas
        // "items" and "anyOf" contain schemas directly
        for (key, value) in obj.iter_mut() {
            if key == "properties" {
                // properties is a map of property_name -> schema
                if let serde_json::Value::Object(props) = value {
                    for (_, prop_schema) in props.iter_mut() {
                        adapt_schema_for_google_impl(prop_schema, true);
                    }
                }
            } else if key == "items" {
                // items is a schema
                adapt_schema_for_google_impl(value, true);
            } else if key == "anyOf" {
                // anyOf is an array of schemas
                if let serde_json::Value::Array(arr) = value {
                    for item in arr.iter_mut() {
                        adapt_schema_for_google_impl(item, true);
                    }
                }
            }
        }
    } else if let serde_json::Value::Array(arr) = json {
        for item in arr.iter_mut() {
            adapt_schema_for_google_impl(item, true);
        }
    }
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

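// Serialized/deserialized with `#[serde(untagged)]`: serde tries the variants
// in declaration order, and structs ignore unknown fields, so ordering
// matters. Note that a streamed part carrying both `text` and `thought`
// fields would match `Text` here rather than `Thought`.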
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}

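// Shapes of the streamed `streamGenerateContent` response chunks we care
// about; all other fields are ignored during deserialization.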
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    #[serde(default)]
    content: Option<GoogleContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}

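/// Translates Zed's provider-agnostic `LlmCompletionRequest` into a Google
/// `generateContent` request body, returning it together with the real model
/// id to put in the URL.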
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<(GoogleRequest, String), String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let supports_thinking = get_model_supports_thinking(model_id);

    let mut contents: Vec<GoogleContent> = Vec::new();
    let mut system_parts: Vec<GooglePart> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text.is_empty() {
                            system_parts
                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                        }
                    }
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
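                        // Note: assumes the incoming image source is base64
                        // data and hard-codes the MIME type as PNG.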
                        LlmMessageContent::Image(img) => {
                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
                                inline_data: GoogleBlob {
                                    mime_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            }));
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let response_value = match &result.content {
                                LlmToolResultContent::Text(t) => {
                                    serde_json::json!({ "output": t })
                                }
                                LlmToolResultContent::Image(_) => {
                                    serde_json::json!({ "output": "Tool responded with an image" })
                                }
                            };
                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
                                function_response: GoogleFunctionResponse {
                                    name: result.tool_name.clone(),
                                    response: response_value,
                                },
                            }));
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("user".to_string()),
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            let thought_signature =
                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());

                            let args: serde_json::Value =
                                serde_json::from_str(&tool_use.input).unwrap_or_default();

                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
                                function_call: GoogleFunctionCall {
                                    name: tool_use.name.clone(),
                                    args,
                                },
                                thought_signature,
                            }));
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            if let Some(ref signature) = thinking.signature {
                                if !signature.is_empty() {
                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
                                        thought: true,
                                        thought_signature: signature.clone(),
                                    }));
                                }
                            }
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("model".to_string()),
                    });
                }
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GoogleSystemInstruction {
            parts: system_parts,
        })
    };

    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
        None
    } else {
        let declarations: Vec<GoogleFunctionDeclaration> = request
            .tools
            .iter()
            .map(|t| {
                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default()));
                adapt_schema_for_google(&mut parameters);
                GoogleFunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters,
                }
            })
            .collect();
        Some(vec![GoogleTool {
            function_declarations: declarations,
        }])
    };

    let tool_config = request.tool_choice.as_ref().map(|tc| {
        let mode = match tc {
            LlmToolChoice::Auto => "AUTO",
            LlmToolChoice::Any => "ANY",
            LlmToolChoice::None => "NONE",
        };
        GoogleToolConfig {
            function_calling_config: GoogleFunctionCallingConfig {
                mode: mode.to_string(),
                allowed_function_names: None,
            },
        }
    });

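    // A fixed thinking budget (in tokens); 8192 is a middle-ground default
    // chosen here, not a value prescribed by the API.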
    let thinking_config = if supports_thinking && request.thinking_allowed {
        Some(GoogleThinkingConfig {
            thinking_budget: 8192,
        })
    } else {
        None
    };

    let generation_config = Some(GoogleGenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: None,
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        thinking_config,
    });

    Ok((
        GoogleRequest {
            contents,
            system_instruction,
            generation_config,
            tools,
            tool_config,
        },
        real_model_id.to_string(),
    ))
}

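/// Parses one line of the streamed response. Tolerates both SSE framing
/// ("data: {...}" lines, as requested via `alt=sse`) and plain JSON-array
/// framing ("[", "]", and "," delimiter lines), returning `None` for
/// non-payload lines or unparseable JSON.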
fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
        return None;
    }

    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
    let json_str = json_str.trim_start_matches(',').trim();

    if json_str.is_empty() {
        return None;
    }

    serde_json::from_str(json_str).ok()
}

impl zed::Extension for GoogleAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "google-ai".into(),
            name: "Google AI".into(),
            icon: Some("icons/google-ai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("google-ai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Google AI Setup

Welcome to **Google AI**! This extension provides access to Google Gemini models.

## Configuration

Enter your Google AI API key below. You can get your API key at [aistudio.google.com/apikey](https://aistudio.google.com/apikey).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| Gemini 2.5 Flash-Lite | gemini-2.5-flash-lite | 1M | 65K |
| Gemini 2.5 Flash | gemini-2.5-flash | 1M | 65K |
| Gemini 2.5 Pro | gemini-2.5-pro | 1M | 65K |
| Gemini 3 Pro | gemini-3-pro-preview | 1M | 65K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling with thought signatures
- ✅ Vision (image inputs)
- ✅ Extended thinking support
- ✅ All Gemini models listed above

## Pricing

Uses your Google AI API credits. See [Google AI pricing](https://ai.google.dev/pricing) for details.
"#
            .to_string(),
        )
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        let provided = llm_request_credential(
            "google-ai",
            LlmCredentialType::ApiKey,
            "Google AI API Key",
            "AIza...",
        )?;
        if provided {
            Ok(())
        } else {
            Err("Authentication cancelled".to_string())
        }
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("google-ai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
            "No API key configured. Please add your Google AI API key in settings.".to_string()
        })?;

        let (google_request, real_model_id) = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&google_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

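        // `alt=sse` switches the endpoint to newline-delimited "data: {...}"
        // events; the API key goes in the query string, which the Gemini REST
        // API accepts as an alternative to the `x-goog-api-key` header.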
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            real_model_id, api_key
        );

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url,
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("google-ai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                stop_reason: None,
                wants_tool_use: false,
            },
        );

        Ok(stream_id)
    }

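    /// Pumps the response stream, returning at most one completion event per
    /// call; the host is expected to poll repeatedly and close the stream
    /// after a `Stop` event.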
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

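        // Drain any complete line already buffered before pulling more bytes;
        // each iteration either returns an event, consumes a line, or reads
        // the next chunk from the HTTP stream.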
        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if let Some(response) = parse_stream_line(&line) {
                    for candidate in response.candidates {
                        if let Some(finish_reason) = &candidate.finish_reason {
                            state.stop_reason = Some(match finish_reason.as_str() {
                                "STOP" => {
                                    if state.wants_tool_use {
                                        LlmStopReason::ToolUse
                                    } else {
                                        LlmStopReason::EndTurn
                                    }
                                }
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "SAFETY" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            });
                        }

                        if let Some(content) = candidate.content {
                            for part in content.parts {
                                match part {
                                    GooglePart::Text(text_part) => {
                                        if !text_part.text.is_empty() {
                                            return Ok(Some(LlmCompletionEvent::Text(
                                                text_part.text,
                                            )));
                                        }
                                    }
                                    GooglePart::FunctionCall(fc_part) => {
                                        state.wants_tool_use = true;
                                        let next_tool_id =
                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                        let id = format!(
                                            "{}-{}",
                                            fc_part.function_call.name, next_tool_id
                                        );

                                        let thought_signature =
                                            fc_part.thought_signature.filter(|s| !s.is_empty());

                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                            id,
                                            name: fc_part.function_call.name,
                                            input: fc_part.function_call.args.to_string(),
                                            thought_signature,
                                        })));
                                    }
                                    GooglePart::Thought(thought_part) => {
                                        return Ok(Some(LlmCompletionEvent::Thinking(
                                            LlmThinkingContent {
                                                text: "(Encrypted thought)".to_string(),
                                                signature: Some(thought_part.thought_signature),
                                            },
                                        )));
                                    }
                                    _ => {}
                                }
                            }
                        }
                    }

                    if let Some(usage) = response.usage_metadata {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_token_count,
                            output_tokens: usage.candidates_token_count,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended - check if we have a stop reason
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }

                    // No stop reason - this is unexpected. Check if buffer contains error info
                    let mut error_msg = String::from("Stream ended unexpectedly.");

                    // Append any remaining buffer contents, which may hold an
                    // error payload from the API; truncate by characters so we
                    // never slice inside a multi-byte UTF-8 sequence.
                    if !state.buffer.is_empty() {
                        let preview: String = state.buffer.chars().take(1000).collect();
                        error_msg.push_str(&format!("\nRemaining buffer: {}", preview));
                    }

                    return Err(error_msg);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(GoogleAiProvider);
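
// A minimal sanity-check sketch for the pure helpers above. These tests are
// an illustrative addition (not part of the extension's runtime behavior) and
// assume the crate builds under a host toolchain where `cargo test` runs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn adapt_schema_strips_unsupported_keys() {
        let mut schema = serde_json::json!({
            "type": ["string", "null"],
            "oneOf": [{ "type": "string" }],
            "additionalProperties": false
        });
        adapt_schema_for_google(&mut schema);
        let obj = schema.as_object().unwrap();
        // Array types collapse to the first entry, oneOf becomes anyOf, and
        // keys outside Google's allowed subset are dropped.
        assert_eq!(obj.get("type"), Some(&serde_json::json!("string")));
        assert!(obj.contains_key("anyOf"));
        assert!(!obj.contains_key("additionalProperties"));
    }

    #[test]
    fn parse_stream_line_handles_sse_and_framing() {
        assert!(parse_stream_line(r#"data: {"candidates": []}"#).is_some());
        assert!(parse_stream_line("[").is_none());
        assert!(parse_stream_line("").is_none());
    }
}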