1mod supported_countries;
2
3use anyhow::{anyhow, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use serde::{Deserialize, Serialize};
7
8pub use supported_countries::*;
9
/// Base URL of the Google Generative Language API. The request functions in this
/// module take the base URL through their `api_url` argument.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
11
12pub async fn stream_generate_content(
13 client: &dyn HttpClient,
14 api_url: &str,
15 api_key: &str,
16 mut request: GenerateContentRequest,
17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
18 let uri = format!(
19 "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
20 model = request.model
21 );
22 request.model.clear();
23
24 let request_builder = HttpRequest::builder()
25 .method(Method::POST)
26 .uri(uri)
27 .header("Content-Type", "application/json");
28
29 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
30 let mut response = client.send(request).await?;
31 if response.status().is_success() {
32 let reader = BufReader::new(response.into_body());
33 Ok(reader
34 .lines()
35 .filter_map(|line| async move {
36 match line {
37 Ok(line) => {
38 if let Some(line) = line.strip_prefix("data: ") {
39 match serde_json::from_str(line) {
40 Ok(response) => Some(Ok(response)),
41 Err(error) => Some(Err(anyhow!(error))),
42 }
43 } else {
44 None
45 }
46 }
47 Err(error) => Some(Err(anyhow!(error))),
48 }
49 })
50 .boxed())
51 } else {
52 let mut text = String::new();
53 response.body_mut().read_to_string(&mut text).await?;
54 Err(anyhow!(
55 "error during streamGenerateContent, status code: {:?}, body: {}",
56 response.status(),
57 text
58 ))
59 }
60}
61
62pub async fn count_tokens(
63 client: &dyn HttpClient,
64 api_url: &str,
65 api_key: &str,
66 request: CountTokensRequest,
67) -> Result<CountTokensResponse> {
68 let uri = format!(
69 "{}/v1beta/models/gemini-pro:countTokens?key={}",
70 api_url, api_key
71 );
72 let request = serde_json::to_string(&request)?;
73
74 let request_builder = HttpRequest::builder()
75 .method(Method::POST)
76 .uri(&uri)
77 .header("Content-Type", "application/json");
78
79 let http_request = request_builder.body(AsyncBody::from(request))?;
80 let mut response = client.send(http_request).await?;
81 let mut text = String::new();
82 response.body_mut().read_to_string(&mut text).await?;
83 if response.status().is_success() {
84 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
85 } else {
86 Err(anyhow!(
87 "error during countTokens, status code: {:?}, body: {}",
88 response.status(),
89 text
90 ))
91 }
92}
93
94#[derive(Debug, Serialize, Deserialize)]
95pub enum Task {
96 #[serde(rename = "generateContent")]
97 GenerateContent,
98 #[serde(rename = "streamGenerateContent")]
99 StreamGenerateContent,
100 #[serde(rename = "countTokens")]
101 CountTokens,
102 #[serde(rename = "embedContent")]
103 EmbedContent,
104 #[serde(rename = "batchEmbedContents")]
105 BatchEmbedContents,
106}
107
108#[derive(Debug, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct GenerateContentRequest {
111 #[serde(default, skip_serializing_if = "String::is_empty")]
112 pub model: String,
113 pub contents: Vec<Content>,
114 pub generation_config: Option<GenerationConfig>,
115 pub safety_settings: Option<Vec<SafetySetting>>,
116}
117
118#[derive(Debug, Serialize, Deserialize)]
119#[serde(rename_all = "camelCase")]
120pub struct GenerateContentResponse {
121 pub candidates: Option<Vec<GenerateContentCandidate>>,
122 pub prompt_feedback: Option<PromptFeedback>,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
126#[serde(rename_all = "camelCase")]
127pub struct GenerateContentCandidate {
128 pub index: Option<usize>,
129 pub content: Content,
130 pub finish_reason: Option<String>,
131 pub finish_message: Option<String>,
132 pub safety_ratings: Option<Vec<SafetyRating>>,
133 pub citation_metadata: Option<CitationMetadata>,
134}
135
136#[derive(Debug, Serialize, Deserialize)]
137#[serde(rename_all = "camelCase")]
138pub struct Content {
139 pub parts: Vec<Part>,
140 pub role: Role,
141}
142
143#[derive(Debug, Deserialize, Serialize)]
144#[serde(rename_all = "camelCase")]
145pub enum Role {
146 User,
147 Model,
148}
149
150#[derive(Debug, Serialize, Deserialize)]
151#[serde(untagged)]
152pub enum Part {
153 TextPart(TextPart),
154 InlineDataPart(InlineDataPart),
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158#[serde(rename_all = "camelCase")]
159pub struct TextPart {
160 pub text: String,
161}
162
163#[derive(Debug, Serialize, Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct InlineDataPart {
166 pub inline_data: GenerativeContentBlob,
167}
168
169#[derive(Debug, Serialize, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct GenerativeContentBlob {
172 pub mime_type: String,
173 pub data: String,
174}
175
176#[derive(Debug, Serialize, Deserialize)]
177#[serde(rename_all = "camelCase")]
178pub struct CitationSource {
179 pub start_index: Option<usize>,
180 pub end_index: Option<usize>,
181 pub uri: Option<String>,
182 pub license: Option<String>,
183}
184
185#[derive(Debug, Serialize, Deserialize)]
186#[serde(rename_all = "camelCase")]
187pub struct CitationMetadata {
188 pub citation_sources: Vec<CitationSource>,
189}
190
191#[derive(Debug, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct PromptFeedback {
194 pub block_reason: Option<String>,
195 pub safety_ratings: Vec<SafetyRating>,
196 pub block_reason_message: Option<String>,
197}
198
199#[derive(Debug, Deserialize, Serialize)]
200#[serde(rename_all = "camelCase")]
201pub struct GenerationConfig {
202 pub candidate_count: Option<usize>,
203 pub stop_sequences: Option<Vec<String>>,
204 pub max_output_tokens: Option<usize>,
205 pub temperature: Option<f64>,
206 pub top_p: Option<f64>,
207 pub top_k: Option<usize>,
208}
209
210#[derive(Debug, Serialize, Deserialize)]
211#[serde(rename_all = "camelCase")]
212pub struct SafetySetting {
213 pub category: HarmCategory,
214 pub threshold: HarmBlockThreshold,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218pub enum HarmCategory {
219 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
220 Unspecified,
221 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
222 Derogatory,
223 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
224 Toxicity,
225 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
226 Violence,
227 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
228 Sexual,
229 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
230 Medical,
231 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
232 Dangerous,
233 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
234 Harassment,
235 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
236 HateSpeech,
237 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
238 SexuallyExplicit,
239 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
240 DangerousContent,
241}
242
243#[derive(Debug, Serialize, Deserialize)]
244pub enum HarmBlockThreshold {
245 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
246 Unspecified,
247 #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
248 BlockLowAndAbove,
249 #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
250 BlockMediumAndAbove,
251 #[serde(rename = "BLOCK_ONLY_HIGH")]
252 BlockOnlyHigh,
253 #[serde(rename = "BLOCK_NONE")]
254 BlockNone,
255}
256
257#[derive(Debug, Serialize, Deserialize)]
258#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
259pub enum HarmProbability {
260 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
261 Unspecified,
262 Negligible,
263 Low,
264 Medium,
265 High,
266}
267
268#[derive(Debug, Serialize, Deserialize)]
269#[serde(rename_all = "camelCase")]
270pub struct SafetyRating {
271 pub category: HarmCategory,
272 pub probability: HarmProbability,
273}
274
275#[derive(Debug, Serialize, Deserialize)]
276#[serde(rename_all = "camelCase")]
277pub struct CountTokensRequest {
278 pub contents: Vec<Content>,
279}
280
281#[derive(Debug, Serialize, Deserialize)]
282#[serde(rename_all = "camelCase")]
283pub struct CountTokensResponse {
284 pub total_tokens: usize,
285}
286
287#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
288#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
289pub enum Model {
290 #[serde(rename = "gemini-1.5-pro")]
291 Gemini15Pro,
292 #[serde(rename = "gemini-1.5-flash")]
293 Gemini15Flash,
294 #[serde(rename = "custom")]
295 Custom {
296 name: String,
297 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
298 display_name: Option<String>,
299 max_tokens: usize,
300 },
301}
302
303impl Model {
304 pub fn id(&self) -> &str {
305 match self {
306 Model::Gemini15Pro => "gemini-1.5-pro",
307 Model::Gemini15Flash => "gemini-1.5-flash",
308 Model::Custom { name, .. } => name,
309 }
310 }
311
312 pub fn display_name(&self) -> &str {
313 match self {
314 Model::Gemini15Pro => "Gemini 1.5 Pro",
315 Model::Gemini15Flash => "Gemini 1.5 Flash",
316 Self::Custom {
317 name, display_name, ..
318 } => display_name.as_ref().unwrap_or(name),
319 }
320 }
321
322 pub fn max_token_count(&self) -> usize {
323 match self {
324 Model::Gemini15Pro => 2_000_000,
325 Model::Gemini15Flash => 1_000_000,
326 Model::Custom { max_tokens, .. } => *max_tokens,
327 }
328 }
329}
330
331impl std::fmt::Display for Model {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 write!(f, "{}", self.id())
334 }
335}
336
337pub fn extract_text_from_events(
338 events: impl Stream<Item = Result<GenerateContentResponse>>,
339) -> impl Stream<Item = Result<String>> {
340 events.filter_map(|event| async move {
341 match event {
342 Ok(event) => event.candidates.and_then(|candidates| {
343 candidates.into_iter().next().and_then(|candidate| {
344 candidate.content.parts.into_iter().next().and_then(|part| {
345 if let Part::TextPart(TextPart { text }) = part {
346 Some(Ok(text))
347 } else {
348 None
349 }
350 })
351 })
352 }),
353 Err(error) => Some(Err(error)),
354 }
355 })
356}