1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use isahc::config::Configurable;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
10use strum::EnumIter;
11
12pub use supported_countries::*;
13
/// Base URL of the OpenAI REST API, written without a trailing slash.
///
/// The request helpers in this file build endpoints by appending a path such as
/// `/chat/completions` or `/embeddings` to the `api_url` they receive, so a value of
/// this shape can be passed to them directly.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
15
/// Returns `true` when `opt` is `None` or holds a collection that views as an empty slice.
///
/// Used as a serde `skip_serializing_if` predicate for optional list fields.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
19
/// Author of a chat message.
///
/// Serde (de)serializes each variant under its lowercase name (`"user"`,
/// `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
28
29impl TryFrom<String> for Role {
30 type Error = anyhow::Error;
31
32 fn try_from(value: String) -> Result<Self> {
33 match value.as_str() {
34 "user" => Ok(Self::User),
35 "assistant" => Ok(Self::Assistant),
36 "system" => Ok(Self::System),
37 "tool" => Ok(Self::Tool),
38 _ => Err(anyhow!("invalid role '{value}'")),
39 }
40 }
41}
42
43impl From<Role> for String {
44 fn from(val: Role) -> Self {
45 match val {
46 Role::User => "user".to_owned(),
47 Role::Assistant => "assistant".to_owned(),
48 Role::System => "system".to_owned(),
49 Role::Tool => "tool".to_owned(),
50 }
51 }
52}
53
/// Chat models known to this crate, plus a user-described `Custom` entry.
///
/// Each built-in variant (de)serializes under the id given in its `rename` attribute
/// and additionally deserializes from the dated id given in `alias`.
/// `FourOmni` is the `Default` variant.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
    FourOmniMini,
    /// A model that is not built in; all of its properties are supplied by the caller.
    #[serde(rename = "custom")]
    Custom {
        /// Returned by both `Model::id` and `Model::display_name`.
        name: String,
        /// Returned by `Model::max_token_count`.
        max_tokens: usize,
        /// Returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u32>,
    },
}
75
76impl Model {
77 pub fn from_id(id: &str) -> Result<Self> {
78 match id {
79 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
80 "gpt-4" => Ok(Self::Four),
81 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
82 "gpt-4o" => Ok(Self::FourOmni),
83 "gpt-4o-mini" => Ok(Self::FourOmniMini),
84 _ => Err(anyhow!("invalid model id")),
85 }
86 }
87
88 pub fn id(&self) -> &str {
89 match self {
90 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
91 Self::Four => "gpt-4",
92 Self::FourTurbo => "gpt-4-turbo-preview",
93 Self::FourOmni => "gpt-4o",
94 Self::FourOmniMini => "gpt-4o-mini",
95 Self::Custom { name, .. } => name,
96 }
97 }
98
99 pub fn display_name(&self) -> &str {
100 match self {
101 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
102 Self::Four => "gpt-4",
103 Self::FourTurbo => "gpt-4-turbo",
104 Self::FourOmni => "gpt-4o",
105 Self::FourOmniMini => "gpt-4o-mini",
106 Self::Custom { name, .. } => name,
107 }
108 }
109
110 pub fn max_token_count(&self) -> usize {
111 match self {
112 Self::ThreePointFiveTurbo => 4096,
113 Self::Four => 8192,
114 Self::FourTurbo => 128000,
115 Self::FourOmni => 128000,
116 Self::FourOmniMini => 128000,
117 Self::Custom { max_tokens, .. } => *max_tokens,
118 }
119 }
120
121 pub fn max_output_tokens(&self) -> Option<u32> {
122 match self {
123 Self::Custom {
124 max_output_tokens, ..
125 } => *max_output_tokens,
126 _ => None,
127 }
128 }
129}
130
/// JSON body that `stream_completion` serializes and POSTs to
/// `{api_url}/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id string (see `Model::id`).
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Always serialized, even when empty.
    pub stop: Vec<String>,
    pub temperature: f32,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
145
146#[derive(Debug, Serialize, Deserialize)]
147#[serde(untagged)]
148pub enum ToolChoice {
149 Auto,
150 Required,
151 None,
152 Other(ToolDefinition),
153}
154
/// A tool offered to the model in `Request::tools`.
///
/// Internally tagged: the variant name is written in snake case to a `type` field, so
/// `Function` (de)serializes as `{ "type": "function", "function": { ... } }`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
161
/// Payload of `ToolDefinition::Function`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// Arbitrary JSON value; presumably a JSON Schema for the function's arguments —
    /// TODO confirm against the API reference.
    pub parameters: Option<Value>,
}
168
/// One entry of `Request::messages`.
///
/// Internally tagged: the lowercase variant name is written to a `role` field next to
/// the variant's own fields.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}
188
/// A tool invocation attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened: the fields of `ToolCallContent` appear alongside `id` in the same
    /// JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
195
/// Body of a `ToolCall`, internally tagged by a lowercase `type` field
/// (`"function"` for the only variant).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
201
/// Name and arguments of a function-type tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments carried as a string; presumably JSON-encoded — TODO confirm.
    pub arguments: String,
}
207
/// Incremental message fragment carried by a `ChoiceDelta`; every field is optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Defaults to `None` when absent; skipped on serialization when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
215
/// Fragment of a streamed tool call.
///
/// `extract_tool_args_from_events` uses `index` to correlate the chunks that belong to
/// the same call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
226
/// Function part of a `ToolCallChunk`; both fields are optional because a chunk may
/// carry only one of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    /// A piece of the argument string; `extract_tool_args_from_events` yields these
    /// pieces in arrival order.
    pub arguments: Option<String>,
}
232
/// Token counters optionally attached to a `ResponseStreamEvent`.
// NOTE(review): presumably `total_tokens` is the sum of the other two — confirm
// against the API reference; nothing in this file computes or checks it.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
239
/// One entry of `ResponseStreamEvent::choices`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
246
/// What a single `data: ` line of the completion stream is parsed into.
///
/// Untagged: serde tries the variants in order, so a line is first parsed as a
/// `ResponseStreamEvent` and otherwise as an object with a string `error` field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
253
/// A successfully parsed event from the completion stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
261
262pub async fn stream_completion(
263 client: &dyn HttpClient,
264 api_url: &str,
265 api_key: &str,
266 request: Request,
267 low_speed_timeout: Option<Duration>,
268) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
269 let uri = format!("{api_url}/chat/completions");
270 let mut request_builder = HttpRequest::builder()
271 .method(Method::POST)
272 .uri(uri)
273 .header("Content-Type", "application/json")
274 .header("Authorization", format!("Bearer {}", api_key));
275
276 if let Some(low_speed_timeout) = low_speed_timeout {
277 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
278 };
279
280 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
281 let mut response = client.send(request).await?;
282 if response.status().is_success() {
283 let reader = BufReader::new(response.into_body());
284 Ok(reader
285 .lines()
286 .filter_map(|line| async move {
287 match line {
288 Ok(line) => {
289 let line = line.strip_prefix("data: ")?;
290 if line == "[DONE]" {
291 None
292 } else {
293 match serde_json::from_str(line) {
294 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
295 Ok(ResponseStreamResult::Err { error }) => {
296 Some(Err(anyhow!(error)))
297 }
298 Err(error) => Some(Err(anyhow!(error))),
299 }
300 }
301 }
302 Err(error) => Some(Err(anyhow!(error))),
303 }
304 })
305 .boxed())
306 } else {
307 let mut body = String::new();
308 response.body_mut().read_to_string(&mut body).await?;
309
310 #[derive(Deserialize)]
311 struct OpenAiResponse {
312 error: OpenAiError,
313 }
314
315 #[derive(Deserialize)]
316 struct OpenAiError {
317 message: String,
318 }
319
320 match serde_json::from_str::<OpenAiResponse>(&body) {
321 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
322 "Failed to connect to OpenAI API: {}",
323 response.error.message,
324 )),
325
326 _ => Err(anyhow!(
327 "Failed to connect to OpenAI API: {} {}",
328 response.status(),
329 body,
330 )),
331 }
332 }
333}
334
/// Embedding models selectable in `embed`; each variant (de)serializes as the model
/// id given in its `rename` attribute.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
342
/// JSON body that `embed` sends to `{api_url}/embeddings`; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
348
/// Successful response of the embeddings endpoint as parsed by `embed`.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    // NOTE(review): presumably one entry per input text, in input order — confirm
    // against the API reference; this file does not check it.
    pub data: Vec<OpenAiEmbedding>,
}
353
/// A single embedding vector from an `OpenAiEmbeddingResponse`.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
358
359pub fn embed<'a>(
360 client: &dyn HttpClient,
361 api_url: &str,
362 api_key: &str,
363 model: OpenAiEmbeddingModel,
364 texts: impl IntoIterator<Item = &'a str>,
365) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
366 let uri = format!("{api_url}/embeddings");
367
368 let request = OpenAiEmbeddingRequest {
369 model,
370 input: texts.into_iter().collect(),
371 };
372 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
373 let request = HttpRequest::builder()
374 .method(Method::POST)
375 .uri(uri)
376 .header("Content-Type", "application/json")
377 .header("Authorization", format!("Bearer {}", api_key))
378 .body(body)
379 .map(|request| client.send(request));
380
381 async move {
382 let mut response = request?.await?;
383 let mut body = String::new();
384 response.body_mut().read_to_string(&mut body).await?;
385
386 if response.status().is_success() {
387 let response: OpenAiEmbeddingResponse =
388 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
389 Ok(response)
390 } else {
391 Err(anyhow!(
392 "error during embedding, status: {:?}, body: {:?}",
393 response.status(),
394 body
395 ))
396 }
397 }
398}
399
400pub async fn extract_tool_args_from_events(
401 tool_name: String,
402 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
403) -> Result<impl Send + Stream<Item = Result<String>>> {
404 let mut tool_use_index = None;
405 let mut first_chunk = None;
406 while let Some(event) = events.next().await {
407 let call = event?.choices.into_iter().find_map(|choice| {
408 choice.delta.tool_calls?.into_iter().find_map(|call| {
409 if call.function.as_ref()?.name.as_deref()? == tool_name {
410 Some(call)
411 } else {
412 None
413 }
414 })
415 });
416 if let Some(call) = call {
417 tool_use_index = Some(call.index);
418 first_chunk = call.function.and_then(|func| func.arguments);
419 break;
420 }
421 }
422
423 let Some(tool_use_index) = tool_use_index else {
424 return Err(anyhow!("tool not used"));
425 };
426
427 Ok(events.filter_map(move |event| {
428 let result = match event {
429 Err(error) => Some(Err(error)),
430 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
431 choice.delta.tool_calls?.into_iter().find_map(|call| {
432 if call.index == tool_use_index {
433 let func = call.function?;
434 let mut arguments = func.arguments?;
435 if let Some(mut first_chunk) = first_chunk.take() {
436 first_chunk.push_str(&arguments);
437 arguments = first_chunk
438 }
439 Some(Ok(arguments))
440 } else {
441 None
442 }
443 })
444 }),
445 };
446
447 async move { result }
448 }))
449}
450
451pub fn extract_text_from_events(
452 response: impl Stream<Item = Result<ResponseStreamEvent>>,
453) -> impl Stream<Item = Result<String>> {
454 response.filter_map(|response| async move {
455 match response {
456 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
457 Err(error) => Some(Err(error)),
458 }
459 })
460}