1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
5use http::HttpClient;
6use serde::{Deserialize, Serialize};
7
/// Default base URL of the Google Generative Language API; pass it as `api_url`
/// to the request functions in this module.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
9
10pub async fn stream_generate_content(
11 client: Arc<dyn HttpClient>,
12 api_url: &str,
13 api_key: &str,
14 model: &str,
15 request: GenerateContentRequest,
16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
17 let uri = format!(
18 "{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}",
19 api_url, api_key
20 );
21
22 let request = serde_json::to_string(&request)?;
23 let mut response = client.post_json(&uri, request.into()).await?;
24 if response.status().is_success() {
25 let reader = BufReader::new(response.into_body());
26 Ok(reader
27 .lines()
28 .filter_map(|line| async move {
29 match line {
30 Ok(line) => {
31 if let Some(line) = line.strip_prefix("data: ") {
32 match serde_json::from_str(line) {
33 Ok(response) => Some(Ok(response)),
34 Err(error) => Some(Err(anyhow!(error))),
35 }
36 } else {
37 None
38 }
39 }
40 Err(error) => Some(Err(anyhow!(error))),
41 }
42 })
43 .boxed())
44 } else {
45 let mut text = String::new();
46 response.body_mut().read_to_string(&mut text).await?;
47 Err(anyhow!(
48 "error during streamGenerateContent, status code: {:?}, body: {}",
49 response.status(),
50 text
51 ))
52 }
53}
54
55pub async fn count_tokens<T: HttpClient>(
56 client: &T,
57 api_url: &str,
58 api_key: &str,
59 request: CountTokensRequest,
60) -> Result<CountTokensResponse> {
61 let uri = format!(
62 "{}/v1beta/models/gemini-pro:countTokens?key={}",
63 api_url, api_key
64 );
65 let request = serde_json::to_string(&request)?;
66 let mut response = client.post_json(&uri, request.into()).await?;
67 let mut text = String::new();
68 response.body_mut().read_to_string(&mut text).await?;
69 if response.status().is_success() {
70 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
71 } else {
72 Err(anyhow!(
73 "error during countTokens, status code: {:?}, body: {}",
74 response.status(),
75 text
76 ))
77 }
78}
79
80#[derive(Debug, Serialize, Deserialize)]
81pub enum Task {
82 #[serde(rename = "generateContent")]
83 GenerateContent,
84 #[serde(rename = "streamGenerateContent")]
85 StreamGenerateContent,
86 #[serde(rename = "countTokens")]
87 CountTokens,
88 #[serde(rename = "embedContent")]
89 EmbedContent,
90 #[serde(rename = "batchEmbedContents")]
91 BatchEmbedContents,
92}
93
94#[derive(Debug, Serialize)]
95#[serde(rename_all = "camelCase")]
96pub struct GenerateContentRequest {
97 pub contents: Vec<Content>,
98 pub generation_config: Option<GenerationConfig>,
99 pub safety_settings: Option<Vec<SafetySetting>>,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub struct GenerateContentResponse {
105 pub candidates: Option<Vec<GenerateContentCandidate>>,
106 pub prompt_feedback: Option<PromptFeedback>,
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(rename_all = "camelCase")]
111pub struct GenerateContentCandidate {
112 pub index: usize,
113 pub content: Content,
114 pub finish_reason: Option<String>,
115 pub finish_message: Option<String>,
116 pub safety_ratings: Option<Vec<SafetyRating>>,
117 pub citation_metadata: Option<CitationMetadata>,
118}
119
120#[derive(Debug, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct Content {
123 pub parts: Vec<Part>,
124 pub role: Role,
125}
126
127#[derive(Debug, Deserialize, Serialize)]
128#[serde(rename_all = "camelCase")]
129pub enum Role {
130 User,
131 Model,
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135#[serde(untagged)]
136pub enum Part {
137 TextPart(TextPart),
138 InlineDataPart(InlineDataPart),
139}
140
141#[derive(Debug, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct TextPart {
144 pub text: String,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct InlineDataPart {
150 pub inline_data: GenerativeContentBlob,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154#[serde(rename_all = "camelCase")]
155pub struct GenerativeContentBlob {
156 pub mime_type: String,
157 pub data: String,
158}
159
160#[derive(Debug, Deserialize)]
161#[serde(rename_all = "camelCase")]
162pub struct CitationSource {
163 pub start_index: Option<usize>,
164 pub end_index: Option<usize>,
165 pub uri: Option<String>,
166 pub license: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct CitationMetadata {
172 pub citation_sources: Vec<CitationSource>,
173}
174
175#[derive(Debug, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct PromptFeedback {
178 pub block_reason: Option<String>,
179 pub safety_ratings: Vec<SafetyRating>,
180 pub block_reason_message: Option<String>,
181}
182
183#[derive(Debug, Serialize)]
184#[serde(rename_all = "camelCase")]
185pub struct GenerationConfig {
186 pub candidate_count: Option<usize>,
187 pub stop_sequences: Option<Vec<String>>,
188 pub max_output_tokens: Option<usize>,
189 pub temperature: Option<f64>,
190 pub top_p: Option<f64>,
191 pub top_k: Option<usize>,
192}
193
194#[derive(Debug, Serialize)]
195#[serde(rename_all = "camelCase")]
196pub struct SafetySetting {
197 pub category: HarmCategory,
198 pub threshold: HarmBlockThreshold,
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202pub enum HarmCategory {
203 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
204 Unspecified,
205 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
206 Derogatory,
207 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
208 Toxicity,
209 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
210 Violence,
211 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
212 Sexual,
213 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
214 Medical,
215 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
216 Dangerous,
217 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
218 Harassment,
219 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
220 HateSpeech,
221 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
222 SexuallyExplicit,
223 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
224 DangerousContent,
225}
226
227#[derive(Debug, Serialize)]
228pub enum HarmBlockThreshold {
229 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
230 Unspecified,
231 #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
232 BlockLowAndAbove,
233 #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
234 BlockMediumAndAbove,
235 #[serde(rename = "BLOCK_ONLY_HIGH")]
236 BlockOnlyHigh,
237 #[serde(rename = "BLOCK_NONE")]
238 BlockNone,
239}
240
241#[derive(Debug, Deserialize)]
242#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
243pub enum HarmProbability {
244 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
245 Unspecified,
246 Negligible,
247 Low,
248 Medium,
249 High,
250}
251
252#[derive(Debug, Deserialize)]
253#[serde(rename_all = "camelCase")]
254pub struct SafetyRating {
255 pub category: HarmCategory,
256 pub probability: HarmProbability,
257}
258
259#[derive(Debug, Serialize)]
260#[serde(rename_all = "camelCase")]
261pub struct CountTokensRequest {
262 pub contents: Vec<Content>,
263}
264
265#[derive(Debug, Deserialize)]
266#[serde(rename_all = "camelCase")]
267pub struct CountTokensResponse {
268 pub total_tokens: usize,
269}