1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
5use http::HttpClient;
6use serde::{Deserialize, Serialize};
7
/// Default base URL of the Google Generative Language API; endpoint paths such as
/// `/v1beta/models/...` are appended to it.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
9
10pub async fn stream_generate_content(
11 client: Arc<dyn HttpClient>,
12 api_url: &str,
13 api_key: &str,
14 request: GenerateContentRequest,
15) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
16 let uri = format!(
17 "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
18 api_url, api_key
19 );
20
21 let request = serde_json::to_string(&request)?;
22 let mut response = client.post_json(&uri, request.into()).await?;
23 if response.status().is_success() {
24 let reader = BufReader::new(response.into_body());
25 Ok(reader
26 .lines()
27 .filter_map(|line| async move {
28 match line {
29 Ok(line) => {
30 if let Some(line) = line.strip_prefix("data: ") {
31 match serde_json::from_str(line) {
32 Ok(response) => Some(Ok(response)),
33 Err(error) => Some(Err(anyhow!(error))),
34 }
35 } else {
36 None
37 }
38 }
39 Err(error) => Some(Err(anyhow!(error))),
40 }
41 })
42 .boxed())
43 } else {
44 let mut text = String::new();
45 response.body_mut().read_to_string(&mut text).await?;
46 Err(anyhow!(
47 "error during streamGenerateContent, status code: {:?}, body: {}",
48 response.status(),
49 text
50 ))
51 }
52}
53
54pub async fn count_tokens<T: HttpClient>(
55 client: &T,
56 api_url: &str,
57 api_key: &str,
58 request: CountTokensRequest,
59) -> Result<CountTokensResponse> {
60 let uri = format!(
61 "{}/v1beta/models/gemini-pro:countTokens?key={}",
62 api_url, api_key
63 );
64 let request = serde_json::to_string(&request)?;
65 let mut response = client.post_json(&uri, request.into()).await?;
66 let mut text = String::new();
67 response.body_mut().read_to_string(&mut text).await?;
68 if response.status().is_success() {
69 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
70 } else {
71 Err(anyhow!(
72 "error during countTokens, status code: {:?}, body: {}",
73 response.status(),
74 text
75 ))
76 }
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80pub enum Task {
81 #[serde(rename = "generateContent")]
82 GenerateContent,
83 #[serde(rename = "streamGenerateContent")]
84 StreamGenerateContent,
85 #[serde(rename = "countTokens")]
86 CountTokens,
87 #[serde(rename = "embedContent")]
88 EmbedContent,
89 #[serde(rename = "batchEmbedContents")]
90 BatchEmbedContents,
91}
92
93#[derive(Debug, Serialize)]
94#[serde(rename_all = "camelCase")]
95pub struct GenerateContentRequest {
96 pub contents: Vec<Content>,
97 pub generation_config: Option<GenerationConfig>,
98 pub safety_settings: Option<Vec<SafetySetting>>,
99}
100
101#[derive(Debug, Deserialize)]
102#[serde(rename_all = "camelCase")]
103pub struct GenerateContentResponse {
104 pub candidates: Option<Vec<GenerateContentCandidate>>,
105 pub prompt_feedback: Option<PromptFeedback>,
106}
107
108#[derive(Debug, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct GenerateContentCandidate {
111 pub index: usize,
112 pub content: Content,
113 pub finish_reason: Option<String>,
114 pub finish_message: Option<String>,
115 pub safety_ratings: Option<Vec<SafetyRating>>,
116 pub citation_metadata: Option<CitationMetadata>,
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120#[serde(rename_all = "camelCase")]
121pub struct Content {
122 pub parts: Vec<Part>,
123 pub role: Role,
124}
125
126#[derive(Debug, Deserialize, Serialize)]
127#[serde(rename_all = "camelCase")]
128pub enum Role {
129 User,
130 Model,
131}
132
133#[derive(Debug, Serialize, Deserialize)]
134#[serde(untagged)]
135pub enum Part {
136 TextPart(TextPart),
137 InlineDataPart(InlineDataPart),
138}
139
140#[derive(Debug, Serialize, Deserialize)]
141#[serde(rename_all = "camelCase")]
142pub struct TextPart {
143 pub text: String,
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147#[serde(rename_all = "camelCase")]
148pub struct InlineDataPart {
149 pub inline_data: GenerativeContentBlob,
150}
151
152#[derive(Debug, Serialize, Deserialize)]
153#[serde(rename_all = "camelCase")]
154pub struct GenerativeContentBlob {
155 pub mime_type: String,
156 pub data: String,
157}
158
159#[derive(Debug, Deserialize)]
160#[serde(rename_all = "camelCase")]
161pub struct CitationSource {
162 pub start_index: Option<usize>,
163 pub end_index: Option<usize>,
164 pub uri: Option<String>,
165 pub license: Option<String>,
166}
167
168#[derive(Debug, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct CitationMetadata {
171 pub citation_sources: Vec<CitationSource>,
172}
173
174#[derive(Debug, Deserialize)]
175#[serde(rename_all = "camelCase")]
176pub struct PromptFeedback {
177 pub block_reason: Option<String>,
178 pub safety_ratings: Vec<SafetyRating>,
179 pub block_reason_message: Option<String>,
180}
181
182#[derive(Debug, Serialize)]
183#[serde(rename_all = "camelCase")]
184pub struct GenerationConfig {
185 pub candidate_count: Option<usize>,
186 pub stop_sequences: Option<Vec<String>>,
187 pub max_output_tokens: Option<usize>,
188 pub temperature: Option<f64>,
189 pub top_p: Option<f64>,
190 pub top_k: Option<usize>,
191}
192
193#[derive(Debug, Serialize)]
194#[serde(rename_all = "camelCase")]
195pub struct SafetySetting {
196 pub category: HarmCategory,
197 pub threshold: HarmBlockThreshold,
198}
199
200#[derive(Debug, Serialize, Deserialize)]
201pub enum HarmCategory {
202 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
203 Unspecified,
204 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
205 Derogatory,
206 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
207 Toxicity,
208 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
209 Violence,
210 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
211 Sexual,
212 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
213 Medical,
214 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
215 Dangerous,
216 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
217 Harassment,
218 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
219 HateSpeech,
220 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
221 SexuallyExplicit,
222 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
223 DangerousContent,
224}
225
226#[derive(Debug, Serialize)]
227pub enum HarmBlockThreshold {
228 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
229 Unspecified,
230 #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
231 BlockLowAndAbove,
232 #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
233 BlockMediumAndAbove,
234 #[serde(rename = "BLOCK_ONLY_HIGH")]
235 BlockOnlyHigh,
236 #[serde(rename = "BLOCK_NONE")]
237 BlockNone,
238}
239
240#[derive(Debug, Deserialize)]
241#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
242pub enum HarmProbability {
243 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
244 Unspecified,
245 Negligible,
246 Low,
247 Medium,
248 High,
249}
250
251#[derive(Debug, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct SafetyRating {
254 pub category: HarmCategory,
255 pub probability: HarmProbability,
256}
257
258#[derive(Debug, Serialize)]
259#[serde(rename_all = "camelCase")]
260pub struct CountTokensRequest {
261 pub contents: Vec<Content>,
262}
263
264#[derive(Debug, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct CountTokensResponse {
267 pub total_tokens: usize,
268}