1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use util::http::HttpClient;
5
/// Default base URL of the Generative Language API; callers pass it (or an
/// override) as `api_url` to the request helpers in this module.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
7
8pub async fn stream_generate_content<T: HttpClient>(
9 client: &T,
10 api_url: &str,
11 api_key: &str,
12 request: GenerateContentRequest,
13) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
14 let uri = format!(
15 "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
16 api_url, api_key
17 );
18
19 let request = serde_json::to_string(&request)?;
20 let mut response = client.post_json(&uri, request.into()).await?;
21 if response.status().is_success() {
22 let reader = BufReader::new(response.into_body());
23 Ok(reader
24 .lines()
25 .filter_map(|line| async move {
26 match line {
27 Ok(line) => {
28 if let Some(line) = line.strip_prefix("data: ") {
29 match serde_json::from_str(line) {
30 Ok(response) => Some(Ok(response)),
31 Err(error) => Some(Err(anyhow!(error))),
32 }
33 } else {
34 None
35 }
36 }
37 Err(error) => Some(Err(anyhow!(error))),
38 }
39 })
40 .boxed())
41 } else {
42 let mut text = String::new();
43 response.body_mut().read_to_string(&mut text).await?;
44 Err(anyhow!(
45 "error during streamGenerateContent, status code: {:?}, body: {}",
46 response.status(),
47 text
48 ))
49 }
50}
51
52pub async fn count_tokens<T: HttpClient>(
53 client: &T,
54 api_url: &str,
55 api_key: &str,
56 request: CountTokensRequest,
57) -> Result<CountTokensResponse> {
58 let uri = format!(
59 "{}/v1beta/models/gemini-pro:countTokens?key={}",
60 api_url, api_key
61 );
62 let request = serde_json::to_string(&request)?;
63 let mut response = client.post_json(&uri, request.into()).await?;
64 let mut text = String::new();
65 response.body_mut().read_to_string(&mut text).await?;
66 if response.status().is_success() {
67 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
68 } else {
69 Err(anyhow!(
70 "error during countTokens, status code: {:?}, body: {}",
71 response.status(),
72 text
73 ))
74 }
75}
76
77#[derive(Debug, Serialize, Deserialize)]
78pub enum Task {
79 #[serde(rename = "generateContent")]
80 GenerateContent,
81 #[serde(rename = "streamGenerateContent")]
82 StreamGenerateContent,
83 #[serde(rename = "countTokens")]
84 CountTokens,
85 #[serde(rename = "embedContent")]
86 EmbedContent,
87 #[serde(rename = "batchEmbedContents")]
88 BatchEmbedContents,
89}
90
91#[derive(Debug, Serialize)]
92#[serde(rename_all = "camelCase")]
93pub struct GenerateContentRequest {
94 pub contents: Vec<Content>,
95 pub generation_config: Option<GenerationConfig>,
96 pub safety_settings: Option<Vec<SafetySetting>>,
97}
98
99#[derive(Debug, Deserialize)]
100#[serde(rename_all = "camelCase")]
101pub struct GenerateContentResponse {
102 pub candidates: Option<Vec<GenerateContentCandidate>>,
103 pub prompt_feedback: Option<PromptFeedback>,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(rename_all = "camelCase")]
108pub struct GenerateContentCandidate {
109 pub index: usize,
110 pub content: Content,
111 pub finish_reason: Option<String>,
112 pub finish_message: Option<String>,
113 pub safety_ratings: Option<Vec<SafetyRating>>,
114 pub citation_metadata: Option<CitationMetadata>,
115}
116
117#[derive(Debug, Serialize, Deserialize)]
118#[serde(rename_all = "camelCase")]
119pub struct Content {
120 pub parts: Vec<Part>,
121 pub role: Role,
122}
123
124#[derive(Debug, Deserialize, Serialize)]
125#[serde(rename_all = "camelCase")]
126pub enum Role {
127 User,
128 Model,
129}
130
131#[derive(Debug, Serialize, Deserialize)]
132#[serde(untagged)]
133pub enum Part {
134 TextPart(TextPart),
135 InlineDataPart(InlineDataPart),
136}
137
138#[derive(Debug, Serialize, Deserialize)]
139#[serde(rename_all = "camelCase")]
140pub struct TextPart {
141 pub text: String,
142}
143
144#[derive(Debug, Serialize, Deserialize)]
145#[serde(rename_all = "camelCase")]
146pub struct InlineDataPart {
147 pub inline_data: GenerativeContentBlob,
148}
149
150#[derive(Debug, Serialize, Deserialize)]
151#[serde(rename_all = "camelCase")]
152pub struct GenerativeContentBlob {
153 pub mime_type: String,
154 pub data: String,
155}
156
157#[derive(Debug, Deserialize)]
158#[serde(rename_all = "camelCase")]
159pub struct CitationSource {
160 pub start_index: Option<usize>,
161 pub end_index: Option<usize>,
162 pub uri: Option<String>,
163 pub license: Option<String>,
164}
165
166#[derive(Debug, Deserialize)]
167#[serde(rename_all = "camelCase")]
168pub struct CitationMetadata {
169 pub citation_sources: Vec<CitationSource>,
170}
171
172#[derive(Debug, Deserialize)]
173#[serde(rename_all = "camelCase")]
174pub struct PromptFeedback {
175 pub block_reason: Option<String>,
176 pub safety_ratings: Vec<SafetyRating>,
177 pub block_reason_message: Option<String>,
178}
179
180#[derive(Debug, Serialize)]
181#[serde(rename_all = "camelCase")]
182pub struct GenerationConfig {
183 pub candidate_count: Option<usize>,
184 pub stop_sequences: Option<Vec<String>>,
185 pub max_output_tokens: Option<usize>,
186 pub temperature: Option<f64>,
187 pub top_p: Option<f64>,
188 pub top_k: Option<usize>,
189}
190
191#[derive(Debug, Serialize)]
192#[serde(rename_all = "camelCase")]
193pub struct SafetySetting {
194 pub category: HarmCategory,
195 pub threshold: HarmBlockThreshold,
196}
197
198#[derive(Debug, Serialize, Deserialize)]
199pub enum HarmCategory {
200 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
201 Unspecified,
202 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
203 Derogatory,
204 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
205 Toxicity,
206 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
207 Violence,
208 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
209 Sexual,
210 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
211 Medical,
212 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
213 Dangerous,
214 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
215 Harassment,
216 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
217 HateSpeech,
218 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
219 SexuallyExplicit,
220 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
221 DangerousContent,
222}
223
224#[derive(Debug, Serialize)]
225pub enum HarmBlockThreshold {
226 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
227 Unspecified,
228 #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
229 BlockLowAndAbove,
230 #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
231 BlockMediumAndAbove,
232 #[serde(rename = "BLOCK_ONLY_HIGH")]
233 BlockOnlyHigh,
234 #[serde(rename = "BLOCK_NONE")]
235 BlockNone,
236}
237
238#[derive(Debug, Deserialize)]
239#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
240pub enum HarmProbability {
241 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
242 Unspecified,
243 Negligible,
244 Low,
245 Medium,
246 High,
247}
248
249#[derive(Debug, Deserialize)]
250#[serde(rename_all = "camelCase")]
251pub struct SafetyRating {
252 pub category: HarmCategory,
253 pub probability: HarmProbability,
254}
255
256#[derive(Debug, Serialize)]
257#[serde(rename_all = "camelCase")]
258pub struct CountTokensRequest {
259 pub contents: Vec<Content>,
260}
261
262#[derive(Debug, Deserialize)]
263#[serde(rename_all = "camelCase")]
264pub struct CountTokensResponse {
265 pub total_tokens: usize,
266}