use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};

static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zed_extension_api::{
    self as zed, http_client::HttpMethod, http_client::HttpRequest,
    llm_get_env_var, llm_get_provider_settings,
    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmCustomModelConfig,
    LlmMessageContent, LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo,
    LlmStopReason, LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
};

pub const DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
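/// Returns the API URL from the provider settings, falling back to the public
/// Generative Language endpoint.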
fn get_api_url() -> String {
    llm_get_provider_settings(PROVIDER_ID)
        .and_then(|s| s.api_url)
        .unwrap_or_else(|| DEFAULT_API_URL.to_string())
}
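/// Returns any custom models configured for this provider in the user's settings.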
fn get_custom_models() -> Vec<LlmCustomModelConfig> {
    llm_get_provider_settings(PROVIDER_ID)
        .map(|s| s.available_models)
        .unwrap_or_default()
}
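/// Starts a streaming `streamGenerateContent` request and registers its state,
/// returning the stream ID used by the subsequent `next`/`close` calls.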
fn stream_generate_content(
    model_id: &str,
    request: &LlmCompletionRequest,
    streams: &mut HashMap<String, StreamState>,
    next_stream_id: &mut u64,
) -> Result<String, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&generate_content_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response_stream = http_request.fetch_stream()?;

    let stream_id = format!("stream-{}", *next_stream_id);
    *next_stream_id += 1;

    streams.insert(
        stream_id.clone(),
        StreamState {
            response_stream,
            buffer: String::new(),
            usage: None,
            pending_events: Vec::new(),
            wants_to_use_tool: false,
        },
    );

    Ok(stream_id)
}
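/// Counts the tokens a completion request would consume via the `countTokens` endpoint.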
fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;
    let count_request = CountTokensRequest {
        generate_content_request,
    };

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:countTokens?key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&count_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response = http_request.fetch()?;
    let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
        .map_err(|e| format!("Failed to parse response: {}", e))?;

    Ok(response_body.total_tokens)
}
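/// Performs basic sanity checks on a request before it is sent to the API.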
fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
    if request.model.is_empty() {
        return Err("Model must be specified".to_string());
    }

    if request.contents.is_empty() {
        return Err("Request must contain at least one content item".to_string());
    }

    if let Some(user_content) = request
        .contents
        .iter()
        .find(|content| content.role == Role::User)
    {
        if user_content.parts.is_empty() {
            return Err("User content must contain at least one part".to_string());
        }
    }

    Ok(())
}

// Extension implementation

const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";
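/// Extension state: all in-flight completion streams, keyed by stream ID.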
struct GoogleAiExtension {
    streams: HashMap<String, StreamState>,
    next_stream_id: u64,
}
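/// Per-stream state: the underlying HTTP response stream, a buffer of
/// partially received SSE data, and any events waiting to be delivered.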
struct StreamState {
    response_stream: zed::http_client::HttpResponseStream,
    buffer: String,
    usage: Option<UsageMetadata>,
    pending_events: Vec<LlmCompletionEvent>,
    wants_to_use_tool: bool,
}

impl zed::Extension for GoogleAiExtension {
    fn new() -> Self {
        Self {
            streams: HashMap::new(),
            next_stream_id: 0,
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: PROVIDER_ID.to_string(),
            name: PROVIDER_NAME.to_string(),
            icon: Some("icons/google-ai.svg".to_string()),
        }]
    }

    fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(get_models())
    }

    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(
            r#"## Google AI Setup

To use Google AI models in Zed, you need a Gemini API key.

1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
2. Create or select a project
3. Generate an API key
4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable

You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
"#
            .to_string(),
        )
    }

    fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
        if provider_id != PROVIDER_ID {
            return false;
        }
        get_api_key().is_some()
    }

    fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(())
    }

    fn llm_count_tokens(
        &self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<u64, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        count_tokens(model_id, request)
    }

    fn llm_stream_completion_start(
        &mut self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        stream_generate_content_next(stream_id, &mut self.streams)
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.remove(stream_id);
    }

    fn llm_cache_configuration(
        &self,
        provider_id: &str,
        _model_id: &str,
    ) -> Option<LlmCacheConfiguration> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(LlmCacheConfiguration {
            max_cache_anchors: 1,
            should_cache_tool_definitions: false,
            min_total_token_count: 32768,
        })
    }
}

zed::register_extension!(GoogleAiExtension);

// Helper functions
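/// Reads the API key from `GEMINI_API_KEY`, falling back to `GOOGLE_AI_API_KEY`.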
fn get_api_key() -> Option<String> {
    llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
}
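/// The built-in Gemini models exposed by this provider.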
fn get_default_models() -> Vec<LlmModelInfo> {
    vec![
        LlmModelInfo {
            id: "gemini-2.5-flash-lite".to_string(),
            name: "Gemini 2.5 Flash-Lite".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: true,
        },
        LlmModelInfo {
            id: "gemini-2.5-flash".to_string(),
            name: "Gemini 2.5 Flash".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: true,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-2.5-pro".to_string(),
            name: "Gemini 2.5 Pro".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-3-pro-preview".to_string(),
            name: "Gemini 3 Pro".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        },
        LlmModelInfo {
            id: "gemini-3-flash-preview".to_string(),
            name: "Gemini 3 Flash".to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: false,
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        },
    ]
}

/// Model aliases for backward compatibility with old model names.
/// Maps old names to canonical model IDs.
fn get_model_aliases() -> Vec<(&'static str, &'static str)> {
    vec![
        // Gemini 2.5 Flash-Lite aliases
        ("gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-lite"),
        ("gemini-2.0-flash-lite-preview", "gemini-2.5-flash-lite"),
        // Gemini 2.5 Flash aliases
        ("gemini-2.0-flash-thinking-exp", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-04-17", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-05-20", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-latest", "gemini-2.5-flash"),
        ("gemini-2.0-flash", "gemini-2.5-flash"),
        // Gemini 2.5 Pro aliases
        ("gemini-2.0-pro-exp", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-latest", "gemini-2.5-pro"),
        ("gemini-2.5-pro-exp-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-06-05", "gemini-2.5-pro"),
    ]
}
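/// Builds the full model list: built-in defaults, their aliases, and any
/// custom models from settings (which override entries with the same ID).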
fn get_models() -> Vec<LlmModelInfo> {
    let mut models: HashMap<String, LlmModelInfo> = HashMap::new();

    // Add default models
    for model in get_default_models() {
        models.insert(model.id.clone(), model);
    }

    // Add aliases as separate model entries (pointing to the same underlying model)
    for (alias, canonical_id) in get_model_aliases() {
        if let Some(canonical_model) = models.get(canonical_id) {
            let mut alias_model = canonical_model.clone();
            alias_model.id = alias.to_string();
            alias_model.is_default = false;
            alias_model.is_default_fast = false;
            models.insert(alias.to_string(), alias_model);
        }
    }

    // Add/override with custom models from settings
    for custom_model in get_custom_models() {
        let model = LlmModelInfo {
            id: custom_model.name.clone(),
            name: custom_model.display_name.unwrap_or(custom_model.name.clone()),
            max_token_count: custom_model.max_tokens,
            max_output_tokens: custom_model.max_output_tokens,
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: custom_model.thinking_budget.is_some(),
                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
            },
            is_default: false,
            is_default_fast: false,
        };
        models.insert(custom_model.name, model);
    }

    models.into_values().collect()
}

/// Get the thinking budget for a specific model from custom settings.
fn get_model_thinking_budget(model_id: &str) -> Option<u32> {
    get_custom_models()
        .into_iter()
        .find(|m| m.name == model_id)
        .and_then(|m| m.thinking_budget)
}
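/// Pulls the next completion event from a stream, parsing buffered SSE lines
/// and fetching more chunks from the HTTP response as needed.
/// Returns `Ok(None)` once the stream is exhausted.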
fn stream_generate_content_next(
    stream_id: &str,
    streams: &mut HashMap<String, StreamState>,
) -> Result<Option<LlmCompletionEvent>, String> {
    let state = streams
        .get_mut(stream_id)
        .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

    loop {
        // Return any pending events first
        if let Some(event) = state.pending_events.pop() {
            return Ok(Some(event));
        }

        if let Some(newline_pos) = state.buffer.find('\n') {
            let line = state.buffer[..newline_pos].to_string();
            state.buffer = state.buffer[newline_pos + 1..].to_string();

            if let Some(data) = line.strip_prefix("data: ") {
                if data.trim().is_empty() {
                    continue;
                }

                let response: GenerateContentResponse = match serde_json::from_str(data) {
                    Ok(response) => response,
                    Err(parse_error) => {
                        // Try to parse as an API error response
                        if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(data) {
                            let error_msg = api_error
                                .error
                                .message
                                .unwrap_or_else(|| "Unknown API error".to_string());
                            let status = api_error.error.status.unwrap_or_default();
                            let code = api_error.error.code.unwrap_or(0);
                            return Err(format!(
                                "Google AI API error ({}): {} [status: {}]",
                                code, error_msg, status
                            ));
                        }
                        // If it's not an error response, return the parse error
                        return Err(format!(
                            "Failed to parse SSE data: {} - {}",
                            parse_error, data
                        ));
                    }
                };

                // Handle prompt feedback (blocked prompts)
                if let Some(ref prompt_feedback) = response.prompt_feedback {
                    if let Some(ref block_reason) = prompt_feedback.block_reason {
                        let _stop_reason = match block_reason.as_str() {
                            "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT"
                            | "IMAGE_SAFETY" => LlmStopReason::Refusal,
                            _ => LlmStopReason::Refusal,
                        };
                        return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::Refusal)));
                    }
                }

                // Send usage updates immediately when received
                if let Some(ref usage) = response.usage_metadata {
                    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
                    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
                    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
                    state.pending_events.push(LlmCompletionEvent::Usage(LlmTokenUsage {
                        input_tokens,
                        output_tokens: usage.candidates_token_count.unwrap_or(0),
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: Some(cached_tokens).filter(|&c| c > 0),
                    }));
                    state.usage = Some(usage.clone());
                }

                if let Some(candidates) = response.candidates {
                    for candidate in candidates {
                        for part in candidate.content.parts {
                            match part {
                                Part::TextPart(text_part) => {
                                    return Ok(Some(LlmCompletionEvent::Text(text_part.text)));
                                }
                                Part::ThoughtPart(thought_part) => {
                                    return Ok(Some(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: "(Encrypted thought)".to_string(),
                                            signature: Some(thought_part.thought_signature),
                                        },
                                    )));
                                }
                                Part::FunctionCallPart(fc_part) => {
                                    state.wants_to_use_tool = true;
                                    // Normalize empty string signatures to None
                                    let thought_signature =
                                        fc_part.thought_signature.filter(|s| !s.is_empty());
                                    // Generate a unique tool use ID (function name plus a
                                    // global counter), matching the built-in provider behavior
                                    let next_tool_id =
                                        TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                    let tool_use_id =
                                        format!("{}-{}", fc_part.function_call.name, next_tool_id);
                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                        id: tool_use_id,
                                        name: fc_part.function_call.name,
                                        input: serde_json::to_string(&fc_part.function_call.args)
                                            .unwrap_or_default(),
                                        is_input_complete: true,
                                        thought_signature,
                                    })));
                                }
                                _ => {}
                            }
                        }

                        if let Some(finish_reason) = candidate.finish_reason {
                            // Even when Gemini wants to use a Tool, the API
                            // responds with `finish_reason: STOP`, so we check
                            // wants_to_use_tool to override
                            let stop_reason = if state.wants_to_use_tool {
                                LlmStopReason::ToolUse
                            } else {
                                match finish_reason.as_str() {
                                    "STOP" => LlmStopReason::EndTurn,
                                    "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                    "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
                                    "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
                                    _ => LlmStopReason::EndTurn,
                                }
                            };

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }
                }
            }

            continue;
        }

        // Check if the buffer contains a non-SSE error response (no "data: " prefix)
        // This can happen when Google returns an immediate error without streaming
        if !state.buffer.is_empty()
            && !state.buffer.contains("data: ")
            && state.buffer.contains("\"error\"")
        {
            // Try to parse the entire buffer as an error response
            if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&state.buffer) {
                let error_msg = api_error
                    .error
                    .message
                    .unwrap_or_else(|| "Unknown API error".to_string());
                let status = api_error.error.status.unwrap_or_default();
                let code = api_error.error.code.unwrap_or(0);
                streams.remove(stream_id);
                return Err(format!(
                    "Google AI API error ({}): {} [status: {}]",
                    code, error_msg, status
                ));
            }
        }

        match state.response_stream.next_chunk() {
            Ok(Some(chunk)) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                state.buffer.push_str(&chunk_str);
            }
            Ok(None) => {
                streams.remove(stream_id);
                return Ok(None);
            }
            Err(e) => {
                streams.remove(stream_id);
                return Err(e);
            }
        }
    }
}
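/// Translates a Zed completion request into the Gemini `generateContent` request shape.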
fn build_generate_content_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<GenerateContentRequest, String> {
    let mut contents: Vec<Content> = Vec::new();
    let mut system_instruction: Option<SystemInstruction> = None;

    for message in &request.messages {
        match message.role {
            LlmMessageRole::System => {
                let parts = convert_content_to_parts(&message.content)?;
                system_instruction = Some(SystemInstruction { parts });
            }
            LlmMessageRole::User | LlmMessageRole::Assistant => {
                let role = match message.role {
                    LlmMessageRole::User => Role::User,
                    LlmMessageRole::Assistant => Role::Model,
                    _ => continue,
                };
                let parts = convert_content_to_parts(&message.content)?;
                contents.push(Content { parts, role });
            }
        }
    }

    let tools = if !request.tools.is_empty() {
        Some(vec![Tool {
            function_declarations: request
                .tools
                .iter()
                .map(|t| FunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
                })
                .collect(),
        }])
    } else {
        None
    };

    let tool_config = request.tool_choice.as_ref().map(|choice| {
        let mode = match choice {
            zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
            zed::LlmToolChoice::Any => FunctionCallingMode::Any,
            zed::LlmToolChoice::None => FunctionCallingMode::None,
        };
        ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode,
                allowed_function_names: None,
            },
        }
    });

    let generation_config = Some(GenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: request.max_tokens.map(|t| t as usize),
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        top_p: None,
        top_k: None,
        thinking_config: if request.thinking_allowed {
            // Check if this model has a custom thinking budget configured
            get_model_thinking_budget(model_id).map(|thinking_budget| ThinkingConfig {
                thinking_budget,
            })
        } else {
            None
        },
    });

    Ok(GenerateContentRequest {
        model: ModelName {
            model_id: model_id.to_string(),
        },
        contents,
        system_instruction,
        generation_config,
        safety_settings: None,
        tools,
        tool_config,
    })
}
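/// Converts Zed message content items into Gemini request `Part`s.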
fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
    let mut parts = Vec::new();

    for item in content {
        match item {
            LlmMessageContent::Text(text) => {
                parts.push(Part::TextPart(TextPart { text: text.clone() }));
            }
            LlmMessageContent::Image(image) => {
                parts.push(Part::InlineDataPart(InlineDataPart {
                    inline_data: GenerativeContentBlob {
                        mime_type: "image/png".to_string(),
                        data: image.source.clone(),
                    },
                }));
            }
            LlmMessageContent::ToolUse(tool_use) => {
                // Normalize empty string signatures to None
                let thought_signature = tool_use
                    .thought_signature
                    .clone()
                    .filter(|s| !s.is_empty());
                parts.push(Part::FunctionCallPart(FunctionCallPart {
                    function_call: FunctionCall {
                        name: tool_use.name.clone(),
                        args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
                    },
                    thought_signature,
                }));
            }
            LlmMessageContent::ToolResult(tool_result) => {
                match &tool_result.content {
                    zed::LlmToolResultContent::Text(text) => {
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": text }),
                            },
                        }));
                    }
                    zed::LlmToolResultContent::Image(image) => {
                        // Send both the function response and the image inline
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": "Tool responded with an image" }),
                            },
                        }));
                        parts.push(Part::InlineDataPart(InlineDataPart {
                            inline_data: GenerativeContentBlob {
                                mime_type: "image/png".to_string(),
                                data: image.source.clone(),
                            },
                        }));
                    }
                }
            }
            LlmMessageContent::Thinking(thinking) => {
                if let Some(signature) = &thinking.signature {
                    parts.push(Part::ThoughtPart(ThoughtPart {
                        thought: true,
                        thought_signature: signature.clone(),
                    }));
                }
            }
            LlmMessageContent::RedactedThinking(_) => {}
        }
    }

    Ok(parts)
}

// Data structures for Google AI API

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    pub data: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
    /// Thought signature returned by the model for function calls.
    /// Only present on the first function call in parallel call scenarios.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason_message: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
// The API expects the uppercase enum values "AUTO", "ANY", and "NONE".
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
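/// Model identifier that serializes to and from the `models/<id>` form used by the API.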
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}

const MODEL_NAME_PREFIX: &str = "models/";

/// Google API error response structure
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Deserialize)]
pub struct ApiError {
    pub code: Option<u16>,
    pub message: Option<String>,
    pub status: Option<String>,
}

impl Serialize for ModelName {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
    }
}

impl<'de> Deserialize<'de> for ModelName {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let string = String::deserialize(deserializer)?;
        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
            Ok(Self {
                model_id: id.to_string(),
            })
        } else {
            Err(serde::de::Error::custom(format!(
                "Expected model name to begin with {}, got: {}",
                MODEL_NAME_PREFIX, string
            )))
        }
    }
}
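
// A minimal sketch of unit tests for the pure helpers above. It assumes the crate
// also compiles for the host target when running `cargo test` (the zed_extension_api
// host calls are not exercised here) and that serde_json is available in tests.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trips_with_models_prefix() {
        let name = ModelName {
            model_id: "gemini-2.5-flash".to_string(),
        };
        let json = serde_json::to_string(&name).unwrap();
        assert_eq!(json, "\"models/gemini-2.5-flash\"");
        let parsed: ModelName = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_id, "gemini-2.5-flash");
    }

    #[test]
    fn validation_rejects_requests_without_contents() {
        let request = GenerateContentRequest {
            model: ModelName {
                model_id: "gemini-2.5-flash".to_string(),
            },
            contents: Vec::new(),
            system_instruction: None,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        };
        assert!(validate_generate_content_request(&request).is_err());
    }

    #[test]
    fn aliases_map_to_known_default_models() {
        let default_ids: Vec<String> = get_default_models().into_iter().map(|m| m.id).collect();
        for (_alias, canonical) in get_model_aliases() {
            assert!(default_ids.contains(&canonical.to_string()));
        }
    }
}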