1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use isahc::config::Configurable;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
10use strum::EnumIter;
11
12pub use supported_countries::*;
13
14pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
15
/// Returns `true` when `opt` is `None` or holds a collection with no elements.
///
/// Intended for use as a serde `skip_serializing_if` predicate on optional
/// list-like fields.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
19
/// The author role attached to a chat message.
///
/// Serialized and deserialized using the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
28
29impl TryFrom<String> for Role {
30 type Error = anyhow::Error;
31
32 fn try_from(value: String) -> Result<Self> {
33 match value.as_str() {
34 "user" => Ok(Self::User),
35 "assistant" => Ok(Self::Assistant),
36 "system" => Ok(Self::System),
37 "tool" => Ok(Self::Tool),
38 _ => Err(anyhow!("invalid role '{value}'")),
39 }
40 }
41}
42
43impl From<Role> for String {
44 fn from(val: Role) -> Self {
45 match val {
46 Role::User => "user".to_owned(),
47 Role::Assistant => "assistant".to_owned(),
48 Role::System => "system".to_owned(),
49 Role::Tool => "tool".to_owned(),
50 }
51 }
52}
53
/// The chat models this module knows about, plus a user-defined escape hatch.
///
/// Each built-in variant serializes to the name given in its `rename`
/// attribute and additionally deserializes from the listed `alias`.
/// [`Model::FourOmni`] is the `Default`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
    FourOmniMini,
    /// A model not listed above: `name` is returned by both `id()` and
    /// `display_name()`, and `max_tokens` by `max_token_count()`.
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}
71
72impl Model {
73 pub fn from_id(id: &str) -> Result<Self> {
74 match id {
75 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
76 "gpt-4" => Ok(Self::Four),
77 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
78 "gpt-4o" => Ok(Self::FourOmni),
79 "gpt-4o-mini" => Ok(Self::FourOmniMini),
80 _ => Err(anyhow!("invalid model id")),
81 }
82 }
83
84 pub fn id(&self) -> &str {
85 match self {
86 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
87 Self::Four => "gpt-4",
88 Self::FourTurbo => "gpt-4-turbo-preview",
89 Self::FourOmni => "gpt-4o",
90 Self::FourOmniMini => "gpt-4o-mini",
91 Self::Custom { name, .. } => name,
92 }
93 }
94
95 pub fn display_name(&self) -> &str {
96 match self {
97 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
98 Self::Four => "gpt-4",
99 Self::FourTurbo => "gpt-4-turbo",
100 Self::FourOmni => "gpt-4o",
101 Self::FourOmniMini => "gpt-4o-mini",
102 Self::Custom { name, .. } => name,
103 }
104 }
105
106 pub fn max_token_count(&self) -> usize {
107 match self {
108 Self::ThreePointFiveTurbo => 4096,
109 Self::Four => 8192,
110 Self::FourTurbo => 128000,
111 Self::FourOmni => 128000,
112 Self::FourOmniMini => 128000,
113 Self::Custom { max_tokens, .. } => *max_tokens,
114 }
115 }
116}
117
/// JSON body sent by `stream_completion` to the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    /// Omitted from the serialized JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    pub stop: Vec<String>,
    pub temperature: f32,
    /// Omitted from the serialized JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Omitted from the serialized JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
132
133#[derive(Debug, Serialize, Deserialize)]
134#[serde(untagged)]
135pub enum ToolChoice {
136 Auto,
137 Required,
138 None,
139 Other(ToolDefinition),
140}
141
/// A tool made available to the model.
///
/// Internally tagged by a `type` field in snake_case, so the only variant is
/// written as `{"type": "function", "function": { ... }}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
148
/// Describes a callable function exposed to the model as a tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // NOTE(review): presumably a JSON Schema describing the arguments — confirm
    // against the API documentation.
    pub parameters: Option<Value>,
}
155
/// A single message in a request's `messages` list.
///
/// Internally tagged by a lowercase `role` field, so each variant is written
/// as an object such as `{"role": "user", "content": "..."}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        /// Omitted from the serialized JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}
175
/// A complete tool invocation carried on an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened: the fields of `content` appear alongside `id` in the same
    /// JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
182
/// Payload of a [`ToolCall`], internally tagged by a lowercase `type` field
/// (`"function"` for the only variant).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
188
/// The function name and its arguments for a function tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a string. NOTE(review): presumably JSON-encoded — confirm.
    pub arguments: String,
}
194
/// The incremental message fragment carried by one streamed choice.
///
/// Every field is optional; a given delta may carry any subset of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted from the serialized JSON when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
202
/// A streamed fragment of a tool call.
///
/// `index` identifies which tool call the fragment belongs to;
/// `extract_tool_args_from_events` uses it to follow one call across events.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
213
/// The function part of a [`ToolCallChunk`]; either field may be absent in
/// any given fragment.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
219
/// Token counts optionally attached to a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
226
/// One entry of a stream event's `choices` list.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
233
/// The decoded payload of one `data: ` line of the completion stream.
///
/// Untagged: serde tries `Ok` first and falls back to `Err`, which matches an
/// object with a string-valued `error` field.
// NOTE(review): `error` is typed as a plain string here, while the non-success
// path of `stream_completion` parses `error` as an object with a `message`
// field — confirm which shape in-stream errors actually use.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
240
/// A successfully decoded event from the completion stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
248
249pub async fn stream_completion(
250 client: &dyn HttpClient,
251 api_url: &str,
252 api_key: &str,
253 request: Request,
254 low_speed_timeout: Option<Duration>,
255) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
256 let uri = format!("{api_url}/chat/completions");
257 let mut request_builder = HttpRequest::builder()
258 .method(Method::POST)
259 .uri(uri)
260 .header("Content-Type", "application/json")
261 .header("Authorization", format!("Bearer {}", api_key));
262
263 if let Some(low_speed_timeout) = low_speed_timeout {
264 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
265 };
266
267 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
268 let mut response = client.send(request).await?;
269 if response.status().is_success() {
270 let reader = BufReader::new(response.into_body());
271 Ok(reader
272 .lines()
273 .filter_map(|line| async move {
274 match line {
275 Ok(line) => {
276 let line = line.strip_prefix("data: ")?;
277 if line == "[DONE]" {
278 None
279 } else {
280 match serde_json::from_str(line) {
281 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
282 Ok(ResponseStreamResult::Err { error }) => {
283 Some(Err(anyhow!(error)))
284 }
285 Err(error) => Some(Err(anyhow!(error))),
286 }
287 }
288 }
289 Err(error) => Some(Err(anyhow!(error))),
290 }
291 })
292 .boxed())
293 } else {
294 let mut body = String::new();
295 response.body_mut().read_to_string(&mut body).await?;
296
297 #[derive(Deserialize)]
298 struct OpenAiResponse {
299 error: OpenAiError,
300 }
301
302 #[derive(Deserialize)]
303 struct OpenAiError {
304 message: String,
305 }
306
307 match serde_json::from_str::<OpenAiResponse>(&body) {
308 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
309 "Failed to connect to OpenAI API: {}",
310 response.error.message,
311 )),
312
313 _ => Err(anyhow!(
314 "Failed to connect to OpenAI API: {} {}",
315 response.status(),
316 body,
317 )),
318 }
319 }
320}
321
/// Embedding models selectable for [`embed`]; each serializes to the model
/// name given in its `rename` attribute.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
329
/// JSON body sent by [`embed`]; borrows the input texts rather than copying them.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
335
/// Successful response body of the embeddings endpoint, as parsed by [`embed`].
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
340
/// A single embedding vector from an [`OpenAiEmbeddingResponse`].
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
345
346pub fn embed<'a>(
347 client: &dyn HttpClient,
348 api_url: &str,
349 api_key: &str,
350 model: OpenAiEmbeddingModel,
351 texts: impl IntoIterator<Item = &'a str>,
352) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
353 let uri = format!("{api_url}/embeddings");
354
355 let request = OpenAiEmbeddingRequest {
356 model,
357 input: texts.into_iter().collect(),
358 };
359 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
360 let request = HttpRequest::builder()
361 .method(Method::POST)
362 .uri(uri)
363 .header("Content-Type", "application/json")
364 .header("Authorization", format!("Bearer {}", api_key))
365 .body(body)
366 .map(|request| client.send(request));
367
368 async move {
369 let mut response = request?.await?;
370 let mut body = String::new();
371 response.body_mut().read_to_string(&mut body).await?;
372
373 if response.status().is_success() {
374 let response: OpenAiEmbeddingResponse =
375 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
376 Ok(response)
377 } else {
378 Err(anyhow!(
379 "error during embedding, status: {:?}, body: {:?}",
380 response.status(),
381 body
382 ))
383 }
384 }
385}
386
387pub async fn extract_tool_args_from_events(
388 tool_name: String,
389 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
390) -> Result<impl Send + Stream<Item = Result<String>>> {
391 let mut tool_use_index = None;
392 let mut first_chunk = None;
393 while let Some(event) = events.next().await {
394 let call = event?.choices.into_iter().find_map(|choice| {
395 choice.delta.tool_calls?.into_iter().find_map(|call| {
396 if call.function.as_ref()?.name.as_deref()? == tool_name {
397 Some(call)
398 } else {
399 None
400 }
401 })
402 });
403 if let Some(call) = call {
404 tool_use_index = Some(call.index);
405 first_chunk = call.function.and_then(|func| func.arguments);
406 break;
407 }
408 }
409
410 let Some(tool_use_index) = tool_use_index else {
411 return Err(anyhow!("tool not used"));
412 };
413
414 Ok(events.filter_map(move |event| {
415 let result = match event {
416 Err(error) => Some(Err(error)),
417 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
418 choice.delta.tool_calls?.into_iter().find_map(|call| {
419 if call.index == tool_use_index {
420 let func = call.function?;
421 let mut arguments = func.arguments?;
422 if let Some(mut first_chunk) = first_chunk.take() {
423 first_chunk.push_str(&arguments);
424 arguments = first_chunk
425 }
426 Some(Ok(arguments))
427 } else {
428 None
429 }
430 })
431 }),
432 };
433
434 async move { result }
435 }))
436}
437
438pub fn extract_text_from_events(
439 response: impl Stream<Item = Result<ResponseStreamEvent>>,
440) -> impl Stream<Item = Result<String>> {
441 response.filter_map(|response| async move {
442 match response {
443 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
444 Err(error) => Some(Err(error)),
445 }
446 })
447}