use std::collections::HashMap;

use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zed_extension_api::{
    self as zed, http_client::HttpMethod, http_client::HttpRequest,
    llm_get_env_var, llm_get_provider_settings,
    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmCustomModelConfig,
    LlmMessageContent, LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo,
    LlmStopReason, LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
};

pub const DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";

fn get_api_url() -> String {
    llm_get_provider_settings(PROVIDER_ID)
        .and_then(|s| s.api_url)
        .unwrap_or_else(|| DEFAULT_API_URL.to_string())
}

fn get_custom_models() -> Vec<LlmCustomModelConfig> {
    llm_get_provider_settings(PROVIDER_ID)
        .map(|s| s.available_models)
        .unwrap_or_default()
}
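
// Custom models come from the user's Zed settings for this provider. A minimal
// illustrative sketch of such an entry, assuming the JSON keys mirror the
// `LlmCustomModelConfig` fields (`name`, `display_name`, `max_tokens`,
// `max_output_tokens`, `thinking_budget`); consult Zed's settings schema for
// the authoritative shape:
//
//   "available_models": [{
//     "name": "my-tuned-gemini",           // hypothetical model id
//     "display_name": "My Tuned Gemini",
//     "max_tokens": 1048576,
//     "max_output_tokens": 65536,
//     "thinking_budget": 8192
//   }]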

fn stream_generate_content(
    model_id: &str,
    request: &LlmCompletionRequest,
    streams: &mut HashMap<String, StreamState>,
    next_stream_id: &mut u64,
) -> Result<String, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&generate_content_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response_stream = http_request.fetch_stream()?;

    let stream_id = format!("stream-{}", *next_stream_id);
    *next_stream_id += 1;

    streams.insert(
        stream_id.clone(),
        StreamState {
            response_stream,
            buffer: String::new(),
            usage: None,
            pending_events: Vec::new(),
            wants_to_use_tool: false,
        },
    );

    Ok(stream_id)
}
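
// With the default API URL and `gemini-2.5-flash`, the request built above
// targets (API key elided):
//
//   POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse&key=...
//
// `alt=sse` makes the endpoint frame its streamed JSON chunks as server-sent
// events, which is what the `data: ` line handling in
// `stream_generate_content_next` below relies on.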

fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;
    let count_request = CountTokensRequest {
        generate_content_request,
    };

    let api_url = get_api_url();
    let uri = format!(
        "{}/v1beta/models/{}:countTokens?key={}",
        api_url, model_id, api_key
    );

    let body = serde_json::to_vec(&count_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response = http_request.fetch()?;
    let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
        .map_err(|e| format!("Failed to parse response: {}", e))?;

    Ok(response_body.total_tokens)
}
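
// The countTokens exchange is small; given the serde attributes on
// `CountTokensRequest` and `CountTokensResponse`, the wire format is roughly:
//
//   request:  {"generateContentRequest": {"model": "models/gemini-2.5-flash", "contents": [...]}}
//   response: {"totalTokens": 42}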

fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
    if request.model.is_empty() {
        return Err("Model must be specified".to_string());
    }

    if request.contents.is_empty() {
        return Err("Request must contain at least one content item".to_string());
    }

    // Check every user content item, not just the first one found.
    if request
        .contents
        .iter()
        .any(|content| content.role == Role::User && content.parts.is_empty())
    {
        return Err("User content must contain at least one part".to_string());
    }

    Ok(())
}

// Extension implementation

const PROVIDER_ID: &str = "google-ai";
const PROVIDER_NAME: &str = "Google AI";

struct GoogleAiExtension {
    streams: HashMap<String, StreamState>,
    next_stream_id: u64,
}

struct StreamState {
    response_stream: zed::http_client::HttpResponseStream,
    buffer: String,
    usage: Option<UsageMetadata>,
    pending_events: Vec<LlmCompletionEvent>,
    wants_to_use_tool: bool,
}

impl zed::Extension for GoogleAiExtension {
    fn new() -> Self {
        Self {
            streams: HashMap::new(),
            next_stream_id: 0,
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: PROVIDER_ID.to_string(),
            name: PROVIDER_NAME.to_string(),
            icon: Some("icons/google-ai.svg".to_string()),
        }]
    }

    fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(get_models())
    }

    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(
            r#"## Google AI Setup

To use Google AI models in Zed, you need a Gemini API key.

1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
2. Create or select a project
3. Generate an API key
4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable

You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
"#
            .to_string(),
        )
    }

    fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
        if provider_id != PROVIDER_ID {
            return false;
        }
        get_api_key().is_some()
    }

    fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        // Credentials are read from environment variables on every call, so
        // there is nothing stored here to clear.
        Ok(())
    }

    fn llm_count_tokens(
        &self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<u64, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        count_tokens(model_id, request)
    }

    fn llm_stream_completion_start(
        &mut self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        stream_generate_content_next(stream_id, &mut self.streams)
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.remove(stream_id);
    }

    fn llm_cache_configuration(
        &self,
        provider_id: &str,
        _model_id: &str,
    ) -> Option<LlmCacheConfiguration> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(LlmCacheConfiguration {
            max_cache_anchors: 1,
            should_cache_tool_definitions: false,
            min_total_token_count: 32768,
        })
    }
}

zed::register_extension!(GoogleAiExtension);

// Helper functions

fn get_api_key() -> Option<String> {
    llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
}

fn get_default_models() -> Vec<LlmModelInfo> {
    // All of the default Gemini models share the same limits and capability
    // set, so build the entries from a small helper.
    fn model(id: &str, name: &str, is_default: bool, is_default_fast: bool) -> LlmModelInfo {
        LlmModelInfo {
            id: id.to_string(),
            name: name.to_string(),
            max_token_count: 1_048_576,
            max_output_tokens: Some(65_536),
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: true,
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default,
            is_default_fast,
        }
    }

    vec![
        model("gemini-2.5-flash-lite", "Gemini 2.5 Flash-Lite", false, true),
        model("gemini-2.5-flash", "Gemini 2.5 Flash", true, false),
        model("gemini-2.5-pro", "Gemini 2.5 Pro", false, false),
        model("gemini-3-pro-preview", "Gemini 3 Pro", false, false),
        model("gemini-3-flash-preview", "Gemini 3 Flash", false, false),
    ]
}

/// Model aliases for backward compatibility with old model names.
/// Maps old names to canonical model IDs.
fn get_model_aliases() -> Vec<(&'static str, &'static str)> {
    vec![
        // Gemini 2.5 Flash-Lite aliases
        ("gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-lite"),
        ("gemini-2.0-flash-lite-preview", "gemini-2.5-flash-lite"),
        // Gemini 2.5 Flash aliases
        ("gemini-2.0-flash-thinking-exp", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-04-17", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-05-20", "gemini-2.5-flash"),
        ("gemini-2.5-flash-preview-latest", "gemini-2.5-flash"),
        ("gemini-2.0-flash", "gemini-2.5-flash"),
        // Gemini 2.5 Pro aliases
        ("gemini-2.0-pro-exp", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-latest", "gemini-2.5-pro"),
        ("gemini-2.5-pro-exp-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-03-25", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"),
        ("gemini-2.5-pro-preview-06-05", "gemini-2.5-pro"),
    ]
}

fn get_models() -> Vec<LlmModelInfo> {
    let mut models: HashMap<String, LlmModelInfo> = HashMap::new();

    // Add default models
    for model in get_default_models() {
        models.insert(model.id.clone(), model);
    }

    // Add aliases as separate model entries (pointing to the same underlying model)
    for (alias, canonical_id) in get_model_aliases() {
        if let Some(canonical_model) = models.get(canonical_id) {
            let mut alias_model = canonical_model.clone();
            alias_model.id = alias.to_string();
            alias_model.is_default = false;
            alias_model.is_default_fast = false;
            models.insert(alias.to_string(), alias_model);
        }
    }

    // Add/override with custom models from settings
    for custom_model in get_custom_models() {
        let model = LlmModelInfo {
            id: custom_model.name.clone(),
            name: custom_model.display_name.unwrap_or(custom_model.name.clone()),
            max_token_count: custom_model.max_tokens,
            max_output_tokens: custom_model.max_output_tokens,
            capabilities: LlmModelCapabilities {
                supports_images: true,
                supports_tools: true,
                supports_tool_choice_auto: true,
                supports_tool_choice_any: true,
                supports_tool_choice_none: true,
                supports_thinking: custom_model.thinking_budget.is_some(),
                tool_input_format: LlmToolInputFormat::JsonSchema,
            },
            is_default: false,
            is_default_fast: false,
        };
        models.insert(custom_model.name, model);
    }

    models.into_values().collect()
}

/// Get the thinking budget for a specific model from custom settings.
fn get_model_thinking_budget(model_id: &str) -> Option<u32> {
    get_custom_models()
        .into_iter()
        .find(|m| m.name == model_id)
        .and_then(|m| m.thinking_budget)
}
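
// When a budget is configured, it is forwarded verbatim inside
// `generationConfig`; for example, `thinking_budget: 8192` serializes as:
//
//   "generationConfig": { ..., "thinkingConfig": {"thinkingBudget": 8192} }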

fn stream_generate_content_next(
    stream_id: &str,
    streams: &mut HashMap<String, StreamState>,
) -> Result<Option<LlmCompletionEvent>, String> {
    let state = streams
        .get_mut(stream_id)
        .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

    loop {
        // Drain pending events first, in the order they were produced
        if !state.pending_events.is_empty() {
            return Ok(Some(state.pending_events.remove(0)));
        }

        if let Some(newline_pos) = state.buffer.find('\n') {
            let line = state.buffer[..newline_pos].to_string();
            state.buffer = state.buffer[newline_pos + 1..].to_string();

            if let Some(data) = line.strip_prefix("data: ") {
                if data.trim().is_empty() {
                    continue;
                }

                let response: GenerateContentResponse = match serde_json::from_str(data) {
                    Ok(response) => response,
                    Err(parse_error) => {
                        // Try to parse as an API error response
                        if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(data) {
                            let error_msg = api_error
                                .error
                                .message
                                .unwrap_or_else(|| "Unknown API error".to_string());
                            let status = api_error.error.status.unwrap_or_default();
                            let code = api_error.error.code.unwrap_or(0);
                            return Err(format!(
                                "Google AI API error ({}): {} [status: {}]",
                                code, error_msg, status
                            ));
                        }
                        // If it's not an error response, return the parse error
                        return Err(format!(
                            "Failed to parse SSE data: {} - {}",
                            parse_error, data
                        ));
                    }
                };

                // Handle prompt feedback (blocked prompts). Every block reason
                // (SAFETY, OTHER, BLOCKLIST, PROHIBITED_CONTENT, IMAGE_SAFETY)
                // maps to a refusal.
                if let Some(ref prompt_feedback) = response.prompt_feedback {
                    if prompt_feedback.block_reason.is_some() {
                        return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::Refusal)));
                    }
                }

                // Queue usage updates as soon as they are received
                if let Some(ref usage) = response.usage_metadata {
                    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
                    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
                    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
                    state
                        .pending_events
                        .push(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens,
                            output_tokens: usage.candidates_token_count.unwrap_or(0),
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: Some(cached_tokens).filter(|&c| c > 0),
                        }));
                    state.usage = Some(usage.clone());
                }

                if let Some(candidates) = response.candidates {
                    for candidate in candidates {
                        // Queue every part of the chunk rather than returning
                        // on the first one: the SSE line has already been
                        // consumed, so an early return would silently drop the
                        // remaining parts (Gemini can pack, e.g., a thought
                        // part and a text part into a single chunk).
                        for part in candidate.content.parts {
                            match part {
                                Part::TextPart(text_part) => {
                                    state
                                        .pending_events
                                        .push(LlmCompletionEvent::Text(text_part.text));
                                }
                                Part::ThoughtPart(thought_part) => {
                                    state.pending_events.push(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: "(Encrypted thought)".to_string(),
                                            signature: Some(thought_part.thought_signature),
                                        },
                                    ));
                                }
                                Part::FunctionCallPart(fc_part) => {
                                    state.wants_to_use_tool = true;
                                    // Normalize empty string signatures to None
                                    let thought_signature =
                                        fc_part.thought_signature.filter(|s| !s.is_empty());
                                    state.pending_events.push(LlmCompletionEvent::ToolUse(
                                        LlmToolUse {
                                            // Gemini assigns no call ids, so the
                                            // function name doubles as the id
                                            id: fc_part.function_call.name.clone(),
                                            name: fc_part.function_call.name,
                                            input: serde_json::to_string(
                                                &fc_part.function_call.args,
                                            )
                                            .unwrap_or_default(),
                                            is_input_complete: true,
                                            thought_signature,
                                        },
                                    ));
                                }
                                _ => {}
                            }
                        }

                        if let Some(finish_reason) = candidate.finish_reason {
                            // Even when Gemini wants to use a Tool, the API
                            // responds with `finish_reason: STOP`, so we check
                            // wants_to_use_tool to override
                            let stop_reason = if state.wants_to_use_tool {
                                LlmStopReason::ToolUse
                            } else {
                                match finish_reason.as_str() {
                                    "STOP" => LlmStopReason::EndTurn,
                                    "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                    "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
                                    "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
                                    _ => LlmStopReason::EndTurn,
                                }
                            };

                            state
                                .pending_events
                                .push(LlmCompletionEvent::Stop(stop_reason));
                        }
                    }
                }
            }

            continue;
        }

        // Check if the buffer contains a non-SSE error response (no "data: " prefix)
        // This can happen when Google returns an immediate error without streaming
        if !state.buffer.is_empty()
            && !state.buffer.contains("data: ")
            && state.buffer.contains("\"error\"")
        {
            // Try to parse the entire buffer as an error response
            if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&state.buffer) {
                let error_msg = api_error
                    .error
                    .message
                    .unwrap_or_else(|| "Unknown API error".to_string());
                let status = api_error.error.status.unwrap_or_default();
                let code = api_error.error.code.unwrap_or(0);
                streams.remove(stream_id);
                return Err(format!(
                    "Google AI API error ({}): {} [status: {}]",
                    code, error_msg, status
                ));
            }
        }

        match state.response_stream.next_chunk() {
            Ok(Some(chunk)) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                state.buffer.push_str(&chunk_str);
            }
            Ok(None) => {
                streams.remove(stream_id);
                return Ok(None);
            }
            Err(e) => {
                streams.remove(stream_id);
                return Err(e);
            }
        }
    }
}
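
// To make the parsing above concrete: `streamGenerateContent?alt=sse` delivers
// one complete JSON document per `data: ` line, roughly like this (abbreviated
// and wrapped here for readability; on the wire each frame is a single line):
//
//   data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}]}
//   data: {"candidates":[{"content":{"parts":[{"text":"!"}],"role":"model"},
//         "finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,
//         "candidatesTokenCount":2,"totalTokenCount":10}}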

fn build_generate_content_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<GenerateContentRequest, String> {
    let mut contents: Vec<Content> = Vec::new();
    let mut system_instruction: Option<SystemInstruction> = None;

    for message in &request.messages {
        match message.role {
            LlmMessageRole::System => {
                // Merge all system messages into one system instruction
                // instead of letting a later one overwrite an earlier one.
                let parts = convert_content_to_parts(&message.content)?;
                if let Some(instruction) = system_instruction.as_mut() {
                    instruction.parts.extend(parts);
                } else {
                    system_instruction = Some(SystemInstruction { parts });
                }
            }
            LlmMessageRole::User | LlmMessageRole::Assistant => {
                let role = match message.role {
                    LlmMessageRole::User => Role::User,
                    LlmMessageRole::Assistant => Role::Model,
                    _ => continue,
                };
                let parts = convert_content_to_parts(&message.content)?;
                contents.push(Content { parts, role });
            }
        }
    }

    let tools = if !request.tools.is_empty() {
        Some(vec![Tool {
            function_declarations: request
                .tools
                .iter()
                .map(|t| FunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
                })
                .collect(),
        }])
    } else {
        None
    };

    let tool_config = request.tool_choice.as_ref().map(|choice| {
        let mode = match choice {
            zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
            zed::LlmToolChoice::Any => FunctionCallingMode::Any,
            zed::LlmToolChoice::None => FunctionCallingMode::None,
        };
        ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode,
                allowed_function_names: None,
            },
        }
    });

    let generation_config = Some(GenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: request.max_tokens.map(|t| t as usize),
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        top_p: None,
        top_k: None,
        thinking_config: if request.thinking_allowed {
            // Check if this model has a custom thinking budget configured
            get_model_thinking_budget(model_id)
                .map(|thinking_budget| ThinkingConfig { thinking_budget })
        } else {
            None
        },
    });

    Ok(GenerateContentRequest {
        model: ModelName {
            model_id: model_id.to_string(),
        },
        contents,
        system_instruction,
        generation_config,
        safety_settings: None,
        tools,
        tool_config,
    })
}
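
// Putting the pieces together, a tool-enabled request serializes roughly as
// (abbreviated; field names follow the serde attributes on the structs below):
//
//   {
//     "model": "models/gemini-2.5-flash",
//     "contents": [{"role": "user", "parts": [{"text": "..."}]}],
//     "systemInstruction": {"parts": [{"text": "..."}]},
//     "generationConfig": {"candidateCount": 1, "temperature": 1.0},
//     "tools": [{"functionDeclarations": [{"name": "...", "description": "...", "parameters": {...}}]}],
//     "toolConfig": {"functionCallingConfig": {"mode": "AUTO"}}
//   }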

fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
    let mut parts = Vec::new();

    for item in content {
        match item {
            LlmMessageContent::Text(text) => {
                parts.push(Part::TextPart(TextPart { text: text.clone() }));
            }
            LlmMessageContent::Image(image) => {
                parts.push(Part::InlineDataPart(InlineDataPart {
                    inline_data: GenerativeContentBlob {
                        // The upstream image type carries no MIME type, so
                        // PNG is assumed here.
                        mime_type: "image/png".to_string(),
                        data: image.source.clone(),
                    },
                }));
            }
            LlmMessageContent::ToolUse(tool_use) => {
                // Normalize empty string signatures to None
                let thought_signature = tool_use
                    .thought_signature
                    .clone()
                    .filter(|s| !s.is_empty());
                parts.push(Part::FunctionCallPart(FunctionCallPart {
                    function_call: FunctionCall {
                        name: tool_use.name.clone(),
                        args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
                    },
                    thought_signature,
                }));
            }
            LlmMessageContent::ToolResult(tool_result) => {
                match &tool_result.content {
                    zed::LlmToolResultContent::Text(text) => {
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": text }),
                            },
                        }));
                    }
                    zed::LlmToolResultContent::Image(image) => {
                        // Send both the function response and the image inline
                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                            function_response: FunctionResponse {
                                name: tool_result.tool_name.clone(),
                                response: serde_json::json!({ "output": "Tool responded with an image" }),
                            },
                        }));
                        parts.push(Part::InlineDataPart(InlineDataPart {
                            inline_data: GenerativeContentBlob {
                                mime_type: "image/png".to_string(),
                                data: image.source.clone(),
                            },
                        }));
                    }
                }
            }
            LlmMessageContent::Thinking(thinking) => {
                // Thoughts are replayed to the API only via their signature;
                // unsigned thinking content is dropped.
                if let Some(signature) = &thinking.signature {
                    parts.push(Part::ThoughtPart(ThoughtPart {
                        thought: true,
                        thought_signature: signature.clone(),
                    }));
                }
            }
            LlmMessageContent::RedactedThinking(_) => {}
        }
    }

    Ok(parts)
}
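
// For reference, the serialized shapes of the part variants built above
// (the function name is illustrative):
//
//   {"text": "..."}
//   {"inlineData": {"mimeType": "image/png", "data": "<base64>"}}
//   {"functionCall": {"name": "get_weather", "args": {...}}, "thoughtSignature": "..."}
//   {"functionResponse": {"name": "get_weather", "response": {"output": "..."}}}
//   {"thought": true, "thoughtSignature": "..."}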

// Data structures for Google AI API

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}
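
// With `#[serde(untagged)]`, deserialization tries the variants in declaration
// order and picks the first whose required fields are all present, so an
// incoming part is classified purely by its keys (`text`, `inlineData`,
// `functionCall`, ...). Note that a part carrying both `text` and `thought`
// fields therefore matches `TextPart` first.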

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    pub data: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
    /// Thought signature returned by the model for function calls.
    /// Only present on the first function call in parallel call scenarios.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason_message: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

// The API expects the uppercase values "AUTO", "ANY", and "NONE" for the
// function-calling mode, not lowercase.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}

const MODEL_NAME_PREFIX: &str = "models/";

/// Google API error response structure
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Deserialize)]
pub struct ApiError {
    pub code: Option<u16>,
    pub message: Option<String>,
    pub status: Option<String>,
}

impl Serialize for ModelName {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
    }
}

impl<'de> Deserialize<'de> for ModelName {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let string = String::deserialize(deserializer)?;
        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
            Ok(Self {
                model_id: id.to_string(),
            })
        } else {
            Err(serde::de::Error::custom(format!(
                "Expected model name to begin with {}, got: {}",
                MODEL_NAME_PREFIX, string
            )))
        }
    }
}
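
// A minimal sanity-test sketch for the pure helpers above (`ModelName` serde
// round-tripping and alias bookkeeping). These avoid the host functions like
// `llm_get_provider_settings`, which are only available inside Zed, so they
// should run with a plain `cargo test` in a suitable build environment.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trips_with_models_prefix() {
        let json = serde_json::to_string(&ModelName {
            model_id: "gemini-2.5-flash".to_string(),
        })
        .unwrap();
        assert_eq!(json, "\"models/gemini-2.5-flash\"");

        let parsed: ModelName = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_id, "gemini-2.5-flash");

        // Names without the "models/" prefix are rejected.
        assert!(serde_json::from_str::<ModelName>("\"gemini-2.5-flash\"").is_err());
    }

    #[test]
    fn aliases_point_at_known_default_models() {
        let defaults = get_default_models();
        for (alias, canonical) in get_model_aliases() {
            assert_ne!(alias, canonical);
            assert!(
                defaults.iter().any(|m| m.id == canonical),
                "alias {} maps to unknown model {}",
                alias,
                canonical
            );
        }
    }
}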