1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
3use http_client::HttpClient;
4use serde::{Deserialize, Serialize};
5
/// Base URL of the Google Generative Language API. The request functions in this
/// module take the base URL as their `api_url` argument and append the
/// `/v1beta/models/...` path to it.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
7
8pub async fn stream_generate_content(
9 client: &dyn HttpClient,
10 api_url: &str,
11 api_key: &str,
12 mut request: GenerateContentRequest,
13) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
14 let uri = format!(
15 "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
16 model = request.model
17 );
18 request.model.clear();
19
20 let request = serde_json::to_string(&request)?;
21 let mut response = client.post_json(&uri, request.into()).await?;
22 if response.status().is_success() {
23 let reader = BufReader::new(response.into_body());
24 Ok(reader
25 .lines()
26 .filter_map(|line| async move {
27 match line {
28 Ok(line) => {
29 if let Some(line) = line.strip_prefix("data: ") {
30 match serde_json::from_str(line) {
31 Ok(response) => Some(Ok(response)),
32 Err(error) => Some(Err(anyhow!(error))),
33 }
34 } else {
35 None
36 }
37 }
38 Err(error) => Some(Err(anyhow!(error))),
39 }
40 })
41 .boxed())
42 } else {
43 let mut text = String::new();
44 response.body_mut().read_to_string(&mut text).await?;
45 Err(anyhow!(
46 "error during streamGenerateContent, status code: {:?}, body: {}",
47 response.status(),
48 text
49 ))
50 }
51}
52
53pub async fn count_tokens(
54 client: &dyn HttpClient,
55 api_url: &str,
56 api_key: &str,
57 request: CountTokensRequest,
58) -> Result<CountTokensResponse> {
59 let uri = format!(
60 "{}/v1beta/models/gemini-pro:countTokens?key={}",
61 api_url, api_key
62 );
63 let request = serde_json::to_string(&request)?;
64 let mut response = client.post_json(&uri, request.into()).await?;
65 let mut text = String::new();
66 response.body_mut().read_to_string(&mut text).await?;
67 if response.status().is_success() {
68 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
69 } else {
70 Err(anyhow!(
71 "error during countTokens, status code: {:?}, body: {}",
72 response.status(),
73 text
74 ))
75 }
76}
77
78#[derive(Debug, Serialize, Deserialize)]
79pub enum Task {
80 #[serde(rename = "generateContent")]
81 GenerateContent,
82 #[serde(rename = "streamGenerateContent")]
83 StreamGenerateContent,
84 #[serde(rename = "countTokens")]
85 CountTokens,
86 #[serde(rename = "embedContent")]
87 EmbedContent,
88 #[serde(rename = "batchEmbedContents")]
89 BatchEmbedContents,
90}
91
92#[derive(Debug, Serialize, Deserialize)]
93#[serde(rename_all = "camelCase")]
94pub struct GenerateContentRequest {
95 #[serde(default, skip_serializing_if = "String::is_empty")]
96 pub model: String,
97 pub contents: Vec<Content>,
98 pub generation_config: Option<GenerationConfig>,
99 pub safety_settings: Option<Vec<SafetySetting>>,
100}
101
102#[derive(Debug, Serialize, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub struct GenerateContentResponse {
105 pub candidates: Option<Vec<GenerateContentCandidate>>,
106 pub prompt_feedback: Option<PromptFeedback>,
107}
108
109#[derive(Debug, Serialize, Deserialize)]
110#[serde(rename_all = "camelCase")]
111pub struct GenerateContentCandidate {
112 pub index: usize,
113 pub content: Content,
114 pub finish_reason: Option<String>,
115 pub finish_message: Option<String>,
116 pub safety_ratings: Option<Vec<SafetyRating>>,
117 pub citation_metadata: Option<CitationMetadata>,
118}
119
120#[derive(Debug, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct Content {
123 pub parts: Vec<Part>,
124 pub role: Role,
125}
126
127#[derive(Debug, Deserialize, Serialize)]
128#[serde(rename_all = "camelCase")]
129pub enum Role {
130 User,
131 Model,
132}
133
/// One piece of a [`Content`] message.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// during deserialization. An object with a `text` field becomes `TextPart`;
/// otherwise one with `inlineData` becomes `InlineDataPart`. An object that
/// matches neither variant fails to deserialize.
// NOTE(review): any other part shape the API returns would fail the whole
// enclosing response. Adding a variant for it would be a breaking change
// for exhaustive matches.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
}
140
141#[derive(Debug, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct TextPart {
144 pub text: String,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct InlineDataPart {
150 pub inline_data: GenerativeContentBlob,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154#[serde(rename_all = "camelCase")]
155pub struct GenerativeContentBlob {
156 pub mime_type: String,
157 pub data: String,
158}
159
160#[derive(Debug, Serialize, Deserialize)]
161#[serde(rename_all = "camelCase")]
162pub struct CitationSource {
163 pub start_index: Option<usize>,
164 pub end_index: Option<usize>,
165 pub uri: Option<String>,
166 pub license: Option<String>,
167}
168
169#[derive(Debug, Serialize, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct CitationMetadata {
172 pub citation_sources: Vec<CitationSource>,
173}
174
175#[derive(Debug, Serialize, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct PromptFeedback {
178 pub block_reason: Option<String>,
179 pub safety_ratings: Vec<SafetyRating>,
180 pub block_reason_message: Option<String>,
181}
182
183#[derive(Debug, Deserialize, Serialize)]
184#[serde(rename_all = "camelCase")]
185pub struct GenerationConfig {
186 pub candidate_count: Option<usize>,
187 pub stop_sequences: Option<Vec<String>>,
188 pub max_output_tokens: Option<usize>,
189 pub temperature: Option<f64>,
190 pub top_p: Option<f64>,
191 pub top_k: Option<usize>,
192}
193
194#[derive(Debug, Serialize, Deserialize)]
195#[serde(rename_all = "camelCase")]
196pub struct SafetySetting {
197 pub category: HarmCategory,
198 pub threshold: HarmBlockThreshold,
199}
200
/// Harm category used by [`SafetySetting`] and [`SafetyRating`], serialized as
/// its `HARM_CATEGORY_*` name.
// NOTE(review): there is no catch-all variant (e.g. `#[serde(other)]`), so any
// category string not listed here fails to deserialize, and that fails the
// whole enclosing response.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
226
227#[derive(Debug, Serialize, Deserialize)]
228pub enum HarmBlockThreshold {
229 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
230 Unspecified,
231 #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
232 BlockLowAndAbove,
233 #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
234 BlockMediumAndAbove,
235 #[serde(rename = "BLOCK_ONLY_HIGH")]
236 BlockOnlyHigh,
237 #[serde(rename = "BLOCK_NONE")]
238 BlockNone,
239}
240
241#[derive(Debug, Serialize, Deserialize)]
242#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
243pub enum HarmProbability {
244 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
245 Unspecified,
246 Negligible,
247 Low,
248 Medium,
249 High,
250}
251
252#[derive(Debug, Serialize, Deserialize)]
253#[serde(rename_all = "camelCase")]
254pub struct SafetyRating {
255 pub category: HarmCategory,
256 pub probability: HarmProbability,
257}
258
259#[derive(Debug, Serialize, Deserialize)]
260#[serde(rename_all = "camelCase")]
261pub struct CountTokensRequest {
262 pub contents: Vec<Content>,
263}
264
265#[derive(Debug, Serialize, Deserialize)]
266#[serde(rename_all = "camelCase")]
267pub struct CountTokensResponse {
268 pub total_tokens: usize,
269}
270
/// A Gemini model choice.
///
/// Built-in variants serialize to their model id (e.g. `"gemini-1.5-pro"`).
/// `JsonSchema` is derived only when the `schemars` feature is enabled.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(rename = "gemini-1.5-pro")]
    Gemini15Pro,
    #[serde(rename = "gemini-1.5-flash")]
    Gemini15Flash,
    /// A model not listed above. `name` is returned as both its id and display
    /// name, and `max_tokens` is returned by `max_token_count`.
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}
281
282impl Model {
283 pub fn id(&self) -> &str {
284 match self {
285 Model::Gemini15Pro => "gemini-1.5-pro",
286 Model::Gemini15Flash => "gemini-1.5-flash",
287 Model::Custom { name, .. } => name,
288 }
289 }
290
291 pub fn display_name(&self) -> &str {
292 match self {
293 Model::Gemini15Pro => "Gemini 1.5 Pro",
294 Model::Gemini15Flash => "Gemini 1.5 Flash",
295 Model::Custom { name, .. } => name,
296 }
297 }
298
299 pub fn max_token_count(&self) -> usize {
300 match self {
301 Model::Gemini15Pro => 2_000_000,
302 Model::Gemini15Flash => 1_000_000,
303 Model::Custom { max_tokens, .. } => *max_tokens,
304 }
305 }
306}
307
308impl std::fmt::Display for Model {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 write!(f, "{}", self.id())
311 }
312}
313
314pub fn extract_text_from_events(
315 events: impl Stream<Item = Result<GenerateContentResponse>>,
316) -> impl Stream<Item = Result<String>> {
317 events.filter_map(|event| async move {
318 match event {
319 Ok(event) => event.candidates.and_then(|candidates| {
320 candidates.into_iter().next().and_then(|candidate| {
321 candidate.content.parts.into_iter().next().and_then(|part| {
322 if let Part::TextPart(TextPart { text }) = part {
323 Some(Ok(text))
324 } else {
325 None
326 }
327 })
328 })
329 }),
330 Err(error) => Some(Err(error)),
331 }
332 })
333}