1mod supported_countries;
2
3use anyhow::{anyhow, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::HttpClient;
6use serde::{Deserialize, Serialize};
7
8pub use supported_countries::*;
9
/// Default base URL of the Generative Language API.
///
/// The request helpers in this module take the base URL as a parameter, so
/// this constant is only a convenient default for callers.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
11
12pub async fn stream_generate_content(
13 client: &dyn HttpClient,
14 api_url: &str,
15 api_key: &str,
16 mut request: GenerateContentRequest,
17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
18 let uri = format!(
19 "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
20 model = request.model
21 );
22 request.model.clear();
23
24 let request = serde_json::to_string(&request)?;
25 let mut response = client.post_json(&uri, request.into()).await?;
26 if response.status().is_success() {
27 let reader = BufReader::new(response.into_body());
28 Ok(reader
29 .lines()
30 .filter_map(|line| async move {
31 match line {
32 Ok(line) => {
33 if let Some(line) = line.strip_prefix("data: ") {
34 match serde_json::from_str(line) {
35 Ok(response) => Some(Ok(response)),
36 Err(error) => Some(Err(anyhow!(error))),
37 }
38 } else {
39 None
40 }
41 }
42 Err(error) => Some(Err(anyhow!(error))),
43 }
44 })
45 .boxed())
46 } else {
47 let mut text = String::new();
48 response.body_mut().read_to_string(&mut text).await?;
49 Err(anyhow!(
50 "error during streamGenerateContent, status code: {:?}, body: {}",
51 response.status(),
52 text
53 ))
54 }
55}
56
57pub async fn count_tokens(
58 client: &dyn HttpClient,
59 api_url: &str,
60 api_key: &str,
61 request: CountTokensRequest,
62) -> Result<CountTokensResponse> {
63 let uri = format!(
64 "{}/v1beta/models/gemini-pro:countTokens?key={}",
65 api_url, api_key
66 );
67 let request = serde_json::to_string(&request)?;
68 let mut response = client.post_json(&uri, request.into()).await?;
69 let mut text = String::new();
70 response.body_mut().read_to_string(&mut text).await?;
71 if response.status().is_success() {
72 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
73 } else {
74 Err(anyhow!(
75 "error during countTokens, status code: {:?}, body: {}",
76 response.status(),
77 text
78 ))
79 }
80}
81
82#[derive(Debug, Serialize, Deserialize)]
83pub enum Task {
84 #[serde(rename = "generateContent")]
85 GenerateContent,
86 #[serde(rename = "streamGenerateContent")]
87 StreamGenerateContent,
88 #[serde(rename = "countTokens")]
89 CountTokens,
90 #[serde(rename = "embedContent")]
91 EmbedContent,
92 #[serde(rename = "batchEmbedContents")]
93 BatchEmbedContents,
94}
95
96#[derive(Debug, Serialize, Deserialize)]
97#[serde(rename_all = "camelCase")]
98pub struct GenerateContentRequest {
99 #[serde(default, skip_serializing_if = "String::is_empty")]
100 pub model: String,
101 pub contents: Vec<Content>,
102 pub generation_config: Option<GenerationConfig>,
103 pub safety_settings: Option<Vec<SafetySetting>>,
104}
105
106#[derive(Debug, Serialize, Deserialize)]
107#[serde(rename_all = "camelCase")]
108pub struct GenerateContentResponse {
109 pub candidates: Option<Vec<GenerateContentCandidate>>,
110 pub prompt_feedback: Option<PromptFeedback>,
111}
112
113#[derive(Debug, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct GenerateContentCandidate {
116 pub index: usize,
117 pub content: Content,
118 pub finish_reason: Option<String>,
119 pub finish_message: Option<String>,
120 pub safety_ratings: Option<Vec<SafetyRating>>,
121 pub citation_metadata: Option<CitationMetadata>,
122}
123
124#[derive(Debug, Serialize, Deserialize)]
125#[serde(rename_all = "camelCase")]
126pub struct Content {
127 pub parts: Vec<Part>,
128 pub role: Role,
129}
130
131#[derive(Debug, Deserialize, Serialize)]
132#[serde(rename_all = "camelCase")]
133pub enum Role {
134 User,
135 Model,
136}
137
138#[derive(Debug, Serialize, Deserialize)]
139#[serde(untagged)]
140pub enum Part {
141 TextPart(TextPart),
142 InlineDataPart(InlineDataPart),
143}
144
145#[derive(Debug, Serialize, Deserialize)]
146#[serde(rename_all = "camelCase")]
147pub struct TextPart {
148 pub text: String,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct InlineDataPart {
154 pub inline_data: GenerativeContentBlob,
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158#[serde(rename_all = "camelCase")]
159pub struct GenerativeContentBlob {
160 pub mime_type: String,
161 pub data: String,
162}
163
164#[derive(Debug, Serialize, Deserialize)]
165#[serde(rename_all = "camelCase")]
166pub struct CitationSource {
167 pub start_index: Option<usize>,
168 pub end_index: Option<usize>,
169 pub uri: Option<String>,
170 pub license: Option<String>,
171}
172
173#[derive(Debug, Serialize, Deserialize)]
174#[serde(rename_all = "camelCase")]
175pub struct CitationMetadata {
176 pub citation_sources: Vec<CitationSource>,
177}
178
179#[derive(Debug, Serialize, Deserialize)]
180#[serde(rename_all = "camelCase")]
181pub struct PromptFeedback {
182 pub block_reason: Option<String>,
183 pub safety_ratings: Vec<SafetyRating>,
184 pub block_reason_message: Option<String>,
185}
186
187#[derive(Debug, Deserialize, Serialize)]
188#[serde(rename_all = "camelCase")]
189pub struct GenerationConfig {
190 pub candidate_count: Option<usize>,
191 pub stop_sequences: Option<Vec<String>>,
192 pub max_output_tokens: Option<usize>,
193 pub temperature: Option<f64>,
194 pub top_p: Option<f64>,
195 pub top_k: Option<usize>,
196}
197
198#[derive(Debug, Serialize, Deserialize)]
199#[serde(rename_all = "camelCase")]
200pub struct SafetySetting {
201 pub category: HarmCategory,
202 pub threshold: HarmBlockThreshold,
203}
204
205#[derive(Debug, Serialize, Deserialize)]
206pub enum HarmCategory {
207 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
208 Unspecified,
209 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
210 Derogatory,
211 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
212 Toxicity,
213 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
214 Violence,
215 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
216 Sexual,
217 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
218 Medical,
219 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
220 Dangerous,
221 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
222 Harassment,
223 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
224 HateSpeech,
225 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
226 SexuallyExplicit,
227 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
228 DangerousContent,
229}
230
231#[derive(Debug, Serialize, Deserialize)]
232pub enum HarmBlockThreshold {
233 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
234 Unspecified,
235 #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
236 BlockLowAndAbove,
237 #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
238 BlockMediumAndAbove,
239 #[serde(rename = "BLOCK_ONLY_HIGH")]
240 BlockOnlyHigh,
241 #[serde(rename = "BLOCK_NONE")]
242 BlockNone,
243}
244
245#[derive(Debug, Serialize, Deserialize)]
246#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
247pub enum HarmProbability {
248 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
249 Unspecified,
250 Negligible,
251 Low,
252 Medium,
253 High,
254}
255
256#[derive(Debug, Serialize, Deserialize)]
257#[serde(rename_all = "camelCase")]
258pub struct SafetyRating {
259 pub category: HarmCategory,
260 pub probability: HarmProbability,
261}
262
263#[derive(Debug, Serialize, Deserialize)]
264#[serde(rename_all = "camelCase")]
265pub struct CountTokensRequest {
266 pub contents: Vec<Content>,
267}
268
269#[derive(Debug, Serialize, Deserialize)]
270#[serde(rename_all = "camelCase")]
271pub struct CountTokensResponse {
272 pub total_tokens: usize,
273}
274
275#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
276#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
277pub enum Model {
278 #[serde(rename = "gemini-1.5-pro")]
279 Gemini15Pro,
280 #[serde(rename = "gemini-1.5-flash")]
281 Gemini15Flash,
282 #[serde(rename = "custom")]
283 Custom { name: String, max_tokens: usize },
284}
285
286impl Model {
287 pub fn id(&self) -> &str {
288 match self {
289 Model::Gemini15Pro => "gemini-1.5-pro",
290 Model::Gemini15Flash => "gemini-1.5-flash",
291 Model::Custom { name, .. } => name,
292 }
293 }
294
295 pub fn display_name(&self) -> &str {
296 match self {
297 Model::Gemini15Pro => "Gemini 1.5 Pro",
298 Model::Gemini15Flash => "Gemini 1.5 Flash",
299 Model::Custom { name, .. } => name,
300 }
301 }
302
303 pub fn max_token_count(&self) -> usize {
304 match self {
305 Model::Gemini15Pro => 2_000_000,
306 Model::Gemini15Flash => 1_000_000,
307 Model::Custom { max_tokens, .. } => *max_tokens,
308 }
309 }
310}
311
312impl std::fmt::Display for Model {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 write!(f, "{}", self.id())
315 }
316}
317
318pub fn extract_text_from_events(
319 events: impl Stream<Item = Result<GenerateContentResponse>>,
320) -> impl Stream<Item = Result<String>> {
321 events.filter_map(|event| async move {
322 match event {
323 Ok(event) => event.candidates.and_then(|candidates| {
324 candidates.into_iter().next().and_then(|candidate| {
325 candidate.content.parts.into_iter().next().and_then(|part| {
326 if let Part::TextPart(TextPart { text }) = part {
327 Some(Ok(text))
328 } else {
329 None
330 }
331 })
332 })
333 }),
334 Err(error) => Some(Err(error)),
335 }
336 })
337}