use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

struct GoogleAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    stop_reason: Option<LlmStopReason>,
    wants_tool_use: bool,
}

struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_thinking: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

fn get_model_supports_thinking(display_name: &str) -> bool {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.supports_thinking)
        .unwrap_or(false)
}
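
// Illustrative checks for the lookups above (an added sketch, not part of
// the original source); the expected values mirror the `MODELS` table.
#[cfg(test)]
mod model_lookup_tests {
    use super::*;

    #[test]
    fn maps_display_names_to_real_ids() {
        assert_eq!(get_real_model_id("Gemini 2.5 Pro"), Some("gemini-2.5-pro"));
        assert_eq!(get_real_model_id("nonexistent"), None);
        assert!(get_model_supports_thinking("Gemini 2.5 Flash"));
        assert!(!get_model_supports_thinking("nonexistent"));
    }
}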

/// Adapts a JSON schema to be compatible with Google's API subset.
/// Google only supports a specific subset of JSON Schema fields.
/// See: https://ai.google.dev/api/caching#Schema
fn adapt_schema_for_google(json: &mut serde_json::Value) {
    if let serde_json::Value::Object(obj) = json {
        // Google's Schema only supports these fields:
        // type, format, title, description, nullable, enum, maxItems, minItems,
        // properties, required, minProperties, maxProperties, minLength, maxLength,
        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
        const ALLOWED_KEYS: &[&str] = &[
            "type",
            "format",
            "title",
            "description",
            "nullable",
            "enum",
            "maxItems",
            "minItems",
            "properties",
            "required",
            "minProperties",
            "maxProperties",
            "minLength",
            "maxLength",
            "pattern",
            "example",
            "anyOf",
            "propertyOrdering",
            "default",
            "items",
            "minimum",
            "maximum",
        ];

        // Convert oneOf to anyOf before filtering keys.
        if let Some(one_of) = obj.remove("oneOf") {
            obj.insert("anyOf".to_string(), one_of);
        }

        // If type is an array (e.g., ["string", "null"]), take just the first type.
        if let Some(type_field) = obj.get_mut("type") {
            if let serde_json::Value::Array(types) = type_field {
                if let Some(first_type) = types.first().cloned() {
                    *type_field = first_type;
                }
            }
        }

        obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));

        // Recursively process nested schemas. "properties" maps property names
        // to schemas (the map itself is not a schema, so only its values are
        // recursed into), while "items" and "anyOf" contain schemas directly.
        for (key, value) in obj.iter_mut() {
            if key == "properties" {
                if let serde_json::Value::Object(props) = value {
                    for (_, prop_schema) in props.iter_mut() {
                        adapt_schema_for_google(prop_schema);
                    }
                }
            } else if key == "items" {
                adapt_schema_for_google(value);
            } else if key == "anyOf" {
                if let serde_json::Value::Array(arr) = value {
                    for item in arr.iter_mut() {
                        adapt_schema_for_google(item);
                    }
                }
            }
        }
    } else if let serde_json::Value::Array(arr) = json {
        for item in arr.iter_mut() {
            adapt_schema_for_google(item);
        }
    }
}
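
// A minimal illustration of the adaptation above (an added sketch); the
// expected values follow directly from `adapt_schema_for_google` as
// implemented in this file.
#[cfg(test)]
mod adapt_schema_tests {
    use super::*;

    #[test]
    fn adapts_schema_to_google_subset() {
        let mut schema = serde_json::json!({
            "type": ["object", "null"],
            "additionalProperties": false,
            "oneOf": [{ "type": "string" }],
            "properties": {
                "query": { "type": "string", "$comment": "will be dropped" }
            }
        });
        adapt_schema_for_google(&mut schema);

        // A `type` array collapses to its first entry.
        assert_eq!(schema["type"], "object");
        // `oneOf` is renamed to `anyOf`; unsupported keys are stripped.
        assert!(schema.get("oneOf").is_none());
        assert!(schema.get("anyOf").is_some());
        assert!(schema.get("additionalProperties").is_none());
        // Nested property schemas are filtered recursively.
        assert!(schema["properties"]["query"].get("$comment").is_none());
    }
}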

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    #[serde(default)]
    content: Option<GoogleContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}
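
// For orientation, a `GoogleRequest` serialized from the types above takes
// roughly this shape (an illustrative sketch: field names follow the
// `rename_all = "camelCase"` attributes, and `None` fields are omitted):
//
// {
//   "contents": [
//     { "parts": [{ "text": "Hello" }], "role": "user" }
//   ],
//   "systemInstruction": { "parts": [{ "text": "Be concise." }] },
//   "generationConfig": { "candidateCount": 1, "temperature": 1.0 },
//   "tools": [{ "functionDeclarations": [{ "name": "...", ... }] }]
// }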

fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<(GoogleRequest, String), String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let supports_thinking = get_model_supports_thinking(model_id);

    let mut contents: Vec<GoogleContent> = Vec::new();
    let mut system_parts: Vec<GooglePart> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text.is_empty() {
                            system_parts
                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                        }
                    }
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
                                inline_data: GoogleBlob {
                                    mime_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            }));
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let response_value = match &result.content {
                                LlmToolResultContent::Text(t) => {
                                    serde_json::json!({ "output": t })
                                }
                                LlmToolResultContent::Image(_) => {
                                    serde_json::json!({ "output": "Tool responded with an image" })
                                }
                            };
                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
                                function_response: GoogleFunctionResponse {
                                    name: result.tool_name.clone(),
                                    response: response_value,
                                },
                            }));
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("user".to_string()),
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut parts: Vec<GooglePart> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            let thought_signature =
                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());

                            // Fall back to an empty object (not `null`) if the
                            // stored input is not valid JSON, matching the
                            // fallback used for tool parameter schemas below.
                            let args: serde_json::Value = serde_json::from_str(&tool_use.input)
                                .unwrap_or_else(|_| serde_json::Value::Object(Default::default()));

                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
                                function_call: GoogleFunctionCall {
                                    name: tool_use.name.clone(),
                                    args,
                                },
                                thought_signature,
                            }));
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            if let Some(ref signature) = thinking.signature {
                                if !signature.is_empty() {
                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
                                        thought: true,
                                        thought_signature: signature.clone(),
                                    }));
                                }
                            }
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("model".to_string()),
                    });
                }
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GoogleSystemInstruction {
            parts: system_parts,
        })
    };

    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
        None
    } else {
        let declarations: Vec<GoogleFunctionDeclaration> = request
            .tools
            .iter()
            .map(|t| {
                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default()));
                adapt_schema_for_google(&mut parameters);
                GoogleFunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters,
                }
            })
            .collect();
        Some(vec![GoogleTool {
            function_declarations: declarations,
        }])
    };

    let tool_config = request.tool_choice.as_ref().map(|tc| {
        let mode = match tc {
            LlmToolChoice::Auto => "AUTO",
            LlmToolChoice::Any => "ANY",
            LlmToolChoice::None => "NONE",
        };
        GoogleToolConfig {
            function_calling_config: GoogleFunctionCallingConfig {
                mode: mode.to_string(),
                allowed_function_names: None,
            },
        }
    });

    let thinking_config = if supports_thinking && request.thinking_allowed {
        Some(GoogleThinkingConfig {
            thinking_budget: 8192,
        })
    } else {
        None
    };

    let generation_config = Some(GoogleGenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: None,
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        thinking_config,
    });

    Ok((
        GoogleRequest {
            contents,
            system_instruction,
            generation_config,
            tools,
            tool_config,
        },
        real_model_id.to_string(),
    ))
}

fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
        return None;
    }

    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
    let json_str = json_str.trim_start_matches(',').trim();

    if json_str.is_empty() {
        return None;
    }

    serde_json::from_str(json_str).ok()
}
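
// An added sketch of how `parse_stream_line` treats SSE payload lines versus
// JSON-array framing noise; the inputs below follow the formats the function
// already handles.
#[cfg(test)]
mod parse_stream_tests {
    use super::*;

    #[test]
    fn parses_sse_data_lines_and_skips_framing() {
        // Array/comma framing lines from JSON-array style streams are ignored.
        assert!(parse_stream_line("[").is_none());
        assert!(parse_stream_line(",").is_none());
        assert!(parse_stream_line("").is_none());

        // An `alt=sse` data line parses into a response chunk.
        let line = r#"data: {"candidates":[{"content":{"parts":[{"text":"hi"}]},"finishReason":"STOP"}]}"#;
        let response = parse_stream_line(line).expect("should parse");
        assert_eq!(response.candidates.len(), 1);
        assert_eq!(
            response.candidates[0].finish_reason.as_deref(),
            Some("STOP")
        );
    }
}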

impl zed::Extension for GoogleAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "google-ai".into(),
            name: "Google AI".into(),
            icon: Some("icons/google-ai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("google-ai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Google AI Setup

Welcome to **Google AI**! This extension provides access to Google Gemini models.

## Configuration

Enter your Google AI API key below. You can get your API key at [aistudio.google.com/apikey](https://aistudio.google.com/apikey).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| Gemini 2.5 Flash-Lite | gemini-2.5-flash-lite | 1M | 65K |
| Gemini 2.5 Flash | gemini-2.5-flash | 1M | 65K |
| Gemini 2.5 Pro | gemini-2.5-pro | 1M | 65K |
| Gemini 3 Pro | gemini-3-pro-preview | 1M | 65K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling with thought signatures
- ✅ Vision (image inputs)
- ✅ Extended thinking support
- ✅ All Gemini models

## Pricing

Uses your Google AI API credits. See [Google AI pricing](https://ai.google.dev/pricing) for details.
"#
            .to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("google-ai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
            "No API key configured. Please add your Google AI API key in settings.".to_string()
        })?;

        let (google_request, real_model_id) = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&google_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            real_model_id, api_key
        );

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url,
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("google-ai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                stop_reason: None,
                wants_tool_use: false,
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        // Emit at most one event per call: drain buffered lines first, and
        // pull more chunks from the HTTP stream only when the buffer runs dry.
        // Note that returning on the first part of a line drops any further
        // parts in that same line.
        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if let Some(response) = parse_stream_line(&line) {
                    for candidate in response.candidates {
                        if let Some(finish_reason) = &candidate.finish_reason {
                            state.stop_reason = Some(match finish_reason.as_str() {
                                "STOP" => {
                                    if state.wants_tool_use {
                                        LlmStopReason::ToolUse
                                    } else {
                                        LlmStopReason::EndTurn
                                    }
                                }
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "SAFETY" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            });
                        }

                        if let Some(content) = candidate.content {
                            for part in content.parts {
                                match part {
                                    GooglePart::Text(text_part) => {
                                        if !text_part.text.is_empty() {
                                            return Ok(Some(LlmCompletionEvent::Text(
                                                text_part.text,
                                            )));
                                        }
                                    }
                                    GooglePart::FunctionCall(fc_part) => {
                                        state.wants_tool_use = true;
                                        let next_tool_id =
                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                        let id = format!(
                                            "{}-{}",
                                            fc_part.function_call.name, next_tool_id
                                        );

                                        let thought_signature =
                                            fc_part.thought_signature.filter(|s| !s.is_empty());

                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                            id,
                                            name: fc_part.function_call.name,
                                            input: fc_part.function_call.args.to_string(),
                                            thought_signature,
                                        })));
                                    }
                                    GooglePart::Thought(thought_part) => {
                                        return Ok(Some(LlmCompletionEvent::Thinking(
                                            LlmThinkingContent {
                                                text: "(Encrypted thought)".to_string(),
                                                signature: Some(thought_part.thought_signature),
                                            },
                                        )));
                                    }
                                    _ => {}
                                }
                            }
                        }
                    }

                    if let Some(usage) = response.usage_metadata {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_token_count,
                            output_tokens: usage.candidates_token_count,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended - emit the stop reason if we saw one.
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }

                    // No stop reason - this is unexpected. Surface any leftover
                    // buffer contents, which may hold an error payload.
                    let mut error_msg = String::from("Stream ended unexpectedly.");

                    if !state.buffer.is_empty() {
                        // Truncate by characters, not bytes, so slicing cannot
                        // panic on a non-UTF-8 character boundary.
                        let preview: String = state.buffer.chars().take(1000).collect();
                        error_msg.push_str(&format!("\nRemaining buffer: {}", preview));
                    }

                    return Err(error_msg);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(GoogleAiProvider);