1use std::collections::HashMap;
2
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use zed_extension_api::{
5 self as zed, http_client::HttpMethod, http_client::HttpRequest, llm_get_env_var,
6 LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmMessageContent,
7 LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo, LlmStopReason,
8 LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
9};
10
/// Base URL of the Gemini (Generative Language) REST API; endpoint paths such
/// as `/v1beta/models/...` are appended to it.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
12
13fn stream_generate_content(
14 model_id: &str,
15 request: &LlmCompletionRequest,
16 streams: &mut HashMap<String, StreamState>,
17 next_stream_id: &mut u64,
18) -> Result<String, String> {
19 let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;
20
21 let generate_content_request = build_generate_content_request(model_id, request)?;
22 validate_generate_content_request(&generate_content_request)?;
23
24 let uri = format!(
25 "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
26 API_URL, model_id, api_key
27 );
28
29 let body = serde_json::to_vec(&generate_content_request)
30 .map_err(|e| format!("Failed to serialize request: {}", e))?;
31
32 let http_request = HttpRequest::builder()
33 .method(HttpMethod::Post)
34 .url(&uri)
35 .header("Content-Type", "application/json")
36 .body(body)
37 .build()?;
38
39 let response_stream = http_request.fetch_stream()?;
40
41 let stream_id = format!("stream-{}", *next_stream_id);
42 *next_stream_id += 1;
43
44 streams.insert(
45 stream_id.clone(),
46 StreamState {
47 response_stream,
48 buffer: String::new(),
49 usage: None,
50 },
51 );
52
53 Ok(stream_id)
54}
55
56fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
57 let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;
58
59 let generate_content_request = build_generate_content_request(model_id, request)?;
60 validate_generate_content_request(&generate_content_request)?;
61 let count_request = CountTokensRequest {
62 generate_content_request,
63 };
64
65 let uri = format!(
66 "{}/v1beta/models/{}:countTokens?key={}",
67 API_URL, model_id, api_key
68 );
69
70 let body = serde_json::to_vec(&count_request)
71 .map_err(|e| format!("Failed to serialize request: {}", e))?;
72
73 let http_request = HttpRequest::builder()
74 .method(HttpMethod::Post)
75 .url(&uri)
76 .header("Content-Type", "application/json")
77 .body(body)
78 .build()?;
79
80 let response = http_request.fetch()?;
81 let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
82 .map_err(|e| format!("Failed to parse response: {}", e))?;
83
84 Ok(response_body.total_tokens)
85}
86
87fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
88 if request.model.is_empty() {
89 return Err("Model must be specified".to_string());
90 }
91
92 if request.contents.is_empty() {
93 return Err("Request must contain at least one content item".to_string());
94 }
95
96 if let Some(user_content) = request
97 .contents
98 .iter()
99 .find(|content| content.role == Role::User)
100 {
101 if user_content.parts.is_empty() {
102 return Err("User content must contain at least one part".to_string());
103 }
104 }
105
106 Ok(())
107}
108
109// Extension implementation
110
/// Provider id this extension registers and accepts in every `llm_*` callback.
const PROVIDER_ID: &str = "google-ai";
/// Display name reported in the provider's `LlmProviderInfo`.
const PROVIDER_NAME: &str = "Google AI";
113
/// Extension state: the in-flight completion streams, keyed by stream id.
struct GoogleAiExtension {
    streams: HashMap<String, StreamState>,
    /// Counter used to mint the next `stream-<n>` id.
    next_stream_id: u64,
}
118
/// Per-stream state kept between `llm_stream_completion_next` calls.
struct StreamState {
    /// The open HTTP response carrying the server-sent events.
    response_stream: zed::http_client::HttpResponseStream,
    /// Received text not yet consumed as complete SSE lines.
    buffer: String,
    /// Most recent usage metadata seen; reported when a candidate finishes.
    usage: Option<UsageMetadata>,
}
124
125impl zed::Extension for GoogleAiExtension {
126 fn new() -> Self {
127 Self {
128 streams: HashMap::new(),
129 next_stream_id: 0,
130 }
131 }
132
133 fn llm_providers(&self) -> Vec<LlmProviderInfo> {
134 vec![LlmProviderInfo {
135 id: PROVIDER_ID.to_string(),
136 name: PROVIDER_NAME.to_string(),
137 icon: Some("icons/google-ai.svg".to_string()),
138 }]
139 }
140
141 fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
142 if provider_id != PROVIDER_ID {
143 return Err(format!("Unknown provider: {}", provider_id));
144 }
145 Ok(get_models())
146 }
147
148 fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
149 if provider_id != PROVIDER_ID {
150 return None;
151 }
152
153 Some(
154 r#"## Google AI Setup
155
156To use Google AI models in Zed, you need a Gemini API key.
157
1581. Go to [Google AI Studio](https://aistudio.google.com/apikey)
1592. Create or select a project
1603. Generate an API key
1614. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable
162
163You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
164"#
165 .to_string(),
166 )
167 }
168
169 fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
170 if provider_id != PROVIDER_ID {
171 return false;
172 }
173 get_api_key().is_some()
174 }
175
176 fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
177 if provider_id != PROVIDER_ID {
178 return Err(format!("Unknown provider: {}", provider_id));
179 }
180 Ok(())
181 }
182
183 fn llm_count_tokens(
184 &self,
185 provider_id: &str,
186 model_id: &str,
187 request: &LlmCompletionRequest,
188 ) -> Result<u64, String> {
189 if provider_id != PROVIDER_ID {
190 return Err(format!("Unknown provider: {}", provider_id));
191 }
192 count_tokens(model_id, request)
193 }
194
195 fn llm_stream_completion_start(
196 &mut self,
197 provider_id: &str,
198 model_id: &str,
199 request: &LlmCompletionRequest,
200 ) -> Result<String, String> {
201 if provider_id != PROVIDER_ID {
202 return Err(format!("Unknown provider: {}", provider_id));
203 }
204 stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
205 }
206
207 fn llm_stream_completion_next(
208 &mut self,
209 stream_id: &str,
210 ) -> Result<Option<LlmCompletionEvent>, String> {
211 stream_generate_content_next(stream_id, &mut self.streams)
212 }
213
214 fn llm_stream_completion_close(&mut self, stream_id: &str) {
215 self.streams.remove(stream_id);
216 }
217
218 fn llm_cache_configuration(
219 &self,
220 provider_id: &str,
221 _model_id: &str,
222 ) -> Option<LlmCacheConfiguration> {
223 if provider_id != PROVIDER_ID {
224 return None;
225 }
226
227 Some(LlmCacheConfiguration {
228 max_cache_anchors: 1,
229 should_cache_tool_definitions: false,
230 min_total_token_count: 32768,
231 })
232 }
233}
234
// Expose `GoogleAiExtension` through the extension API's registration macro.
zed::register_extension!(GoogleAiExtension);
236
237// Helper functions
238
239fn get_api_key() -> Option<String> {
240 llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
241}
242
243fn get_models() -> Vec<LlmModelInfo> {
244 vec![
245 LlmModelInfo {
246 id: "gemini-2.5-flash-lite".to_string(),
247 name: "Gemini 2.5 Flash-Lite".to_string(),
248 max_token_count: 1_048_576,
249 max_output_tokens: Some(65_536),
250 capabilities: LlmModelCapabilities {
251 supports_images: true,
252 supports_tools: true,
253 supports_tool_choice_auto: true,
254 supports_tool_choice_any: true,
255 supports_tool_choice_none: true,
256 supports_thinking: true,
257 tool_input_format: LlmToolInputFormat::JsonSchema,
258 },
259 is_default: false,
260 is_default_fast: true,
261 },
262 LlmModelInfo {
263 id: "gemini-2.5-flash".to_string(),
264 name: "Gemini 2.5 Flash".to_string(),
265 max_token_count: 1_048_576,
266 max_output_tokens: Some(65_536),
267 capabilities: LlmModelCapabilities {
268 supports_images: true,
269 supports_tools: true,
270 supports_tool_choice_auto: true,
271 supports_tool_choice_any: true,
272 supports_tool_choice_none: true,
273 supports_thinking: true,
274 tool_input_format: LlmToolInputFormat::JsonSchema,
275 },
276 is_default: true,
277 is_default_fast: false,
278 },
279 LlmModelInfo {
280 id: "gemini-2.5-pro".to_string(),
281 name: "Gemini 2.5 Pro".to_string(),
282 max_token_count: 1_048_576,
283 max_output_tokens: Some(65_536),
284 capabilities: LlmModelCapabilities {
285 supports_images: true,
286 supports_tools: true,
287 supports_tool_choice_auto: true,
288 supports_tool_choice_any: true,
289 supports_tool_choice_none: true,
290 supports_thinking: true,
291 tool_input_format: LlmToolInputFormat::JsonSchema,
292 },
293 is_default: false,
294 is_default_fast: false,
295 },
296 LlmModelInfo {
297 id: "gemini-3-pro-preview".to_string(),
298 name: "Gemini 3 Pro".to_string(),
299 max_token_count: 1_048_576,
300 max_output_tokens: Some(65_536),
301 capabilities: LlmModelCapabilities {
302 supports_images: true,
303 supports_tools: true,
304 supports_tool_choice_auto: true,
305 supports_tool_choice_any: true,
306 supports_tool_choice_none: true,
307 supports_thinking: true,
308 tool_input_format: LlmToolInputFormat::JsonSchema,
309 },
310 is_default: false,
311 is_default_fast: false,
312 },
313 LlmModelInfo {
314 id: "gemini-3-flash-preview".to_string(),
315 name: "Gemini 3 Flash".to_string(),
316 max_token_count: 1_048_576,
317 max_output_tokens: Some(65_536),
318 capabilities: LlmModelCapabilities {
319 supports_images: true,
320 supports_tools: true,
321 supports_tool_choice_auto: true,
322 supports_tool_choice_any: true,
323 supports_tool_choice_none: true,
324 supports_thinking: true,
325 tool_input_format: LlmToolInputFormat::JsonSchema,
326 },
327 is_default: false,
328 is_default_fast: false,
329 },
330 ]
331}
332
333fn stream_generate_content_next(
334 stream_id: &str,
335 streams: &mut HashMap<String, StreamState>,
336) -> Result<Option<LlmCompletionEvent>, String> {
337 let state = streams
338 .get_mut(stream_id)
339 .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
340
341 loop {
342 if let Some(newline_pos) = state.buffer.find('\n') {
343 let line = state.buffer[..newline_pos].to_string();
344 state.buffer = state.buffer[newline_pos + 1..].to_string();
345
346 if let Some(data) = line.strip_prefix("data: ") {
347 if data.trim().is_empty() {
348 continue;
349 }
350
351 let response: GenerateContentResponse = serde_json::from_str(data)
352 .map_err(|e| format!("Failed to parse SSE data: {} - {}", e, data))?;
353
354 if let Some(usage) = response.usage_metadata {
355 state.usage = Some(usage);
356 }
357
358 if let Some(candidates) = response.candidates {
359 for candidate in candidates {
360 for part in candidate.content.parts {
361 match part {
362 Part::TextPart(text_part) => {
363 return Ok(Some(LlmCompletionEvent::Text(text_part.text)));
364 }
365 Part::ThoughtPart(thought_part) => {
366 return Ok(Some(LlmCompletionEvent::Thinking(
367 LlmThinkingContent {
368 text: String::new(),
369 signature: Some(thought_part.thought_signature),
370 },
371 )));
372 }
373 Part::FunctionCallPart(fc_part) => {
374 return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
375 id: fc_part.function_call.name.clone(),
376 name: fc_part.function_call.name,
377 input: serde_json::to_string(&fc_part.function_call.args)
378 .unwrap_or_default(),
379 is_input_complete: true,
380 thought_signature: fc_part.thought_signature,
381 })));
382 }
383 _ => {}
384 }
385 }
386
387 if let Some(finish_reason) = candidate.finish_reason {
388 let stop_reason = match finish_reason.as_str() {
389 "STOP" => LlmStopReason::EndTurn,
390 "MAX_TOKENS" => LlmStopReason::MaxTokens,
391 "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
392 "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
393 _ => LlmStopReason::EndTurn,
394 };
395
396 if let Some(usage) = state.usage.take() {
397 return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
398 input_tokens: usage.prompt_token_count.unwrap_or(0),
399 output_tokens: usage.candidates_token_count.unwrap_or(0),
400 cache_creation_input_tokens: None,
401 cache_read_input_tokens: usage.cached_content_token_count,
402 })));
403 }
404
405 return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
406 }
407 }
408 }
409 }
410
411 continue;
412 }
413
414 match state.response_stream.next_chunk() {
415 Ok(Some(chunk)) => {
416 let chunk_str = String::from_utf8_lossy(&chunk);
417 state.buffer.push_str(&chunk_str);
418 }
419 Ok(None) => {
420 streams.remove(stream_id);
421 return Ok(None);
422 }
423 Err(e) => {
424 streams.remove(stream_id);
425 return Err(e);
426 }
427 }
428 }
429}
430
431fn build_generate_content_request(
432 model_id: &str,
433 request: &LlmCompletionRequest,
434) -> Result<GenerateContentRequest, String> {
435 let mut contents: Vec<Content> = Vec::new();
436 let mut system_instruction: Option<SystemInstruction> = None;
437
438 for message in &request.messages {
439 match message.role {
440 LlmMessageRole::System => {
441 let parts = convert_content_to_parts(&message.content)?;
442 system_instruction = Some(SystemInstruction { parts });
443 }
444 LlmMessageRole::User | LlmMessageRole::Assistant => {
445 let role = match message.role {
446 LlmMessageRole::User => Role::User,
447 LlmMessageRole::Assistant => Role::Model,
448 _ => continue,
449 };
450 let parts = convert_content_to_parts(&message.content)?;
451 contents.push(Content { parts, role });
452 }
453 }
454 }
455
456 let tools = if !request.tools.is_empty() {
457 Some(vec![Tool {
458 function_declarations: request
459 .tools
460 .iter()
461 .map(|t| FunctionDeclaration {
462 name: t.name.clone(),
463 description: t.description.clone(),
464 parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
465 })
466 .collect(),
467 }])
468 } else {
469 None
470 };
471
472 let tool_config = request.tool_choice.as_ref().map(|choice| {
473 let mode = match choice {
474 zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
475 zed::LlmToolChoice::Any => FunctionCallingMode::Any,
476 zed::LlmToolChoice::None => FunctionCallingMode::None,
477 };
478 ToolConfig {
479 function_calling_config: FunctionCallingConfig {
480 mode,
481 allowed_function_names: None,
482 },
483 }
484 });
485
486 let generation_config = Some(GenerationConfig {
487 candidate_count: Some(1),
488 stop_sequences: if request.stop_sequences.is_empty() {
489 None
490 } else {
491 Some(request.stop_sequences.clone())
492 },
493 max_output_tokens: request.max_tokens.map(|t| t as usize),
494 temperature: request.temperature.map(|t| t as f64),
495 top_p: None,
496 top_k: None,
497 thinking_config: if request.thinking_allowed {
498 Some(ThinkingConfig {
499 thinking_budget: 8192,
500 })
501 } else {
502 None
503 },
504 });
505
506 Ok(GenerateContentRequest {
507 model: ModelName {
508 model_id: model_id.to_string(),
509 },
510 contents,
511 system_instruction,
512 generation_config,
513 safety_settings: None,
514 tools,
515 tool_config,
516 })
517}
518
519fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
520 let mut parts = Vec::new();
521
522 for item in content {
523 match item {
524 LlmMessageContent::Text(text) => {
525 parts.push(Part::TextPart(TextPart { text: text.clone() }));
526 }
527 LlmMessageContent::Image(image) => {
528 parts.push(Part::InlineDataPart(InlineDataPart {
529 inline_data: GenerativeContentBlob {
530 mime_type: "image/png".to_string(),
531 data: image.source.clone(),
532 },
533 }));
534 }
535 LlmMessageContent::ToolUse(tool_use) => {
536 parts.push(Part::FunctionCallPart(FunctionCallPart {
537 function_call: FunctionCall {
538 name: tool_use.name.clone(),
539 args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
540 },
541 thought_signature: tool_use.thought_signature.clone(),
542 }));
543 }
544 LlmMessageContent::ToolResult(tool_result) => {
545 let response_value = match &tool_result.content {
546 zed::LlmToolResultContent::Text(text) => {
547 serde_json::json!({ "result": text })
548 }
549 zed::LlmToolResultContent::Image(_) => {
550 serde_json::json!({ "error": "Image results not supported" })
551 }
552 };
553 parts.push(Part::FunctionResponsePart(FunctionResponsePart {
554 function_response: FunctionResponse {
555 name: tool_result.tool_name.clone(),
556 response: response_value,
557 },
558 }));
559 }
560 LlmMessageContent::Thinking(thinking) => {
561 if let Some(signature) = &thinking.signature {
562 parts.push(Part::ThoughtPart(ThoughtPart {
563 thought: true,
564 thought_signature: signature.clone(),
565 }));
566 }
567 }
568 LlmMessageContent::RedactedThinking(_) => {}
569 }
570 }
571
572 Ok(parts)
573}
574
575// Data structures for Google AI API
576
/// Request body for `models/{model}:streamGenerateContent`. The same body,
/// nested in [`CountTokensRequest`], is sent to `models/{model}:countTokens`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Serialized as `models/<id>`; left out of the JSON entirely when empty.
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    /// Conversation turns, in order.
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}
594
/// A response payload. When streaming, each SSE `data:` line carries one of these.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    /// Feedback about the prompt itself (e.g. a block reason).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}
605
606#[derive(Debug, Serialize, Deserialize)]
607#[serde(rename_all = "camelCase")]
608pub struct GenerateContentCandidate {
609 #[serde(skip_serializing_if = "Option::is_none")]
610 pub index: Option<usize>,
611 pub content: Content,
612 #[serde(skip_serializing_if = "Option::is_none")]
613 pub finish_reason: Option<String>,
614 #[serde(skip_serializing_if = "Option::is_none")]
615 pub finish_message: Option<String>,
616 #[serde(skip_serializing_if = "Option::is_none")]
617 pub safety_ratings: Option<Vec<SafetyRating>>,
618 #[serde(skip_serializing_if = "Option::is_none")]
619 pub citation_metadata: Option<CitationMetadata>,
620}
621
/// A single conversation turn.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Defaults to empty when a response omits `parts`.
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}
629
/// System prompt, sent as the request's `systemInstruction`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}
635
/// Author of a [`Content`] turn; serialized as `"user"` / `"model"`.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}
642
/// One piece of a turn's content.
///
/// Untagged: when deserializing, serde tries the variants in declaration order
/// and keeps the first one whose required fields are present. Variant order is
/// therefore significant. For example, any part with a `text` field becomes a
/// `TextPart`, and `ThoughtPart` is only chosen when none of the earlier shapes
/// match.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}
652
/// Plain text content (`{"text": ...}`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}
658
/// Inline binary content (used here for images), sent as `inlineData`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}
664
/// MIME type plus encoded payload of an inline data part.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    /// Encoded bytes; presumably base64, as filled from the image source — confirm.
    pub data: String,
}
671
/// A function (tool) call made by the model, or replayed to it in history.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
    /// Thought signature returned by the model for function calls.
    /// Only present on the first function call in parallel call scenarios.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
681
/// The result of a function call, sent back to the model.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}
687
/// A thought marker that carries only a signature (`thought` + `thoughtSignature`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}
694
/// A cited source attributed to a span of the generated text.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}
707
/// Citation information attached to a candidate.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}
713
/// Feedback about the prompt, such as why it was blocked.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason_message: Option<String>,
}
723
/// Token accounting attached to responses. The prompt, candidates and
/// cached-content counts are what the extension reports as usage.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}
740
/// Reasoning configuration, sent as `generationConfig.thinkingConfig`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Maximum number of tokens the model may spend on reasoning.
    pub thinking_budget: u32,
}
746
/// Sampling and output limits for a request. Unset fields are omitted so the
/// API applies its own defaults.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}
765
/// Per-category blocking threshold for a request's `safetySettings`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}
772
773#[derive(Debug, Serialize, Deserialize)]
774pub enum HarmCategory {
775 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
776 Unspecified,
777 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
778 Derogatory,
779 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
780 Toxicity,
781 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
782 Violence,
783 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
784 Sexual,
785 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
786 Medical,
787 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
788 Dangerous,
789 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
790 Harassment,
791 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
792 HateSpeech,
793 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
794 SexuallyExplicit,
795 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
796 DangerousContent,
797}
798
/// How aggressively a [`SafetySetting`] blocks its category.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}
809
/// Likelihood level reported in a [`SafetyRating`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}
820
/// Safety assessment for one harm category.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}
827
/// Body of a `countTokens` call: the full generate request nested under
/// `generateContentRequest`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}
833
/// Response of a `countTokens` call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}
839
/// Name and JSON arguments of a function call.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}
845
/// Name of the called function and the JSON value it produced.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}
851
/// A group of function declarations offered to the model.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}
857
/// Request-level tool configuration (`toolConfig`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}
863
/// Controls whether and which functions the model may call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    /// Restricts calls to these function names; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}
871
/// Function-calling mode; serialized in lowercase (`"auto"`, `"any"`, `"none"`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}
879
/// A function the model may call, with its parameter schema as raw JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
886
/// A model identifier stored without its `models/` prefix. The hand-written
/// `Serialize`/`Deserialize` impls add and strip that prefix.
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    /// True when no model id is set; used to omit the field when serializing.
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}
897
/// Resource-name prefix the API uses for model identifiers.
const MODEL_NAME_PREFIX: &str = "models/";
899
900impl Serialize for ModelName {
901 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
902 where
903 S: Serializer,
904 {
905 serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
906 }
907}
908
909impl<'de> Deserialize<'de> for ModelName {
910 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
911 where
912 D: Deserializer<'de>,
913 {
914 let string = String::deserialize(deserializer)?;
915 if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
916 Ok(Self {
917 model_id: id.to_string(),
918 })
919 } else {
920 Err(serde::de::Error::custom(format!(
921 "Expected model name to begin with {}, got: {}",
922 MODEL_NAME_PREFIX, string
923 )))
924 }
925 }
926}