1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use isahc::config::Configurable;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
10use strum::EnumIter;
11
12pub use supported_countries::*;
13
/// Base URL of the public OpenAI API (v1).
///
/// The request functions in this module take the base URL as an `api_url`
/// argument and append endpoint paths such as `/chat/completions` or
/// `/embeddings` to it, so this is the value to pass when talking to OpenAI
/// itself rather than a compatible server.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
15
/// Returns `true` when `opt` is `None` or wraps an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate so that both an absent
/// and an empty collection are left out of the serialized output.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26 Tool,
27}
28
29impl TryFrom<String> for Role {
30 type Error = anyhow::Error;
31
32 fn try_from(value: String) -> Result<Self> {
33 match value.as_str() {
34 "user" => Ok(Self::User),
35 "assistant" => Ok(Self::Assistant),
36 "system" => Ok(Self::System),
37 "tool" => Ok(Self::Tool),
38 _ => Err(anyhow!("invalid role '{value}'")),
39 }
40 }
41}
42
43impl From<Role> for String {
44 fn from(val: Role) -> Self {
45 match val {
46 Role::User => "user".to_owned(),
47 Role::Assistant => "assistant".to_owned(),
48 Role::System => "system".to_owned(),
49 Role::Tool => "tool".to_owned(),
50 }
51 }
52}
53
// OpenAI chat models known to this crate, plus a user-configured `Custom` model.
//
// Each built-in variant serializes as its canonical model id (`rename`) and
// additionally deserializes from the dated id given as its `alias`.
// `FourOmni` ("gpt-4o") is the `Default`.
//
// Plain `//` comments are used on purpose: `///` doc comments would be picked
// up by the optional schemars derive and change the generated JSON schema.
// Variant order is significant: it is the order `EnumIter` yields models in.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
    FourOmniMini,
    // A model described entirely by configuration.
    #[serde(rename = "custom")]
    Custom {
        // Model id, returned as-is by `Model::id`.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Token limit returned by `Model::max_token_count`.
        max_tokens: usize,
        // Returned by `Model::max_output_tokens`; presumably a cap on generated tokens — confirm with callers.
        max_output_tokens: Option<u32>,
    },
}
77
78impl Model {
79 pub fn from_id(id: &str) -> Result<Self> {
80 match id {
81 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
82 "gpt-4" => Ok(Self::Four),
83 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
84 "gpt-4o" => Ok(Self::FourOmni),
85 "gpt-4o-mini" => Ok(Self::FourOmniMini),
86 _ => Err(anyhow!("invalid model id")),
87 }
88 }
89
90 pub fn id(&self) -> &str {
91 match self {
92 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
93 Self::Four => "gpt-4",
94 Self::FourTurbo => "gpt-4-turbo-preview",
95 Self::FourOmni => "gpt-4o",
96 Self::FourOmniMini => "gpt-4o-mini",
97 Self::Custom { name, .. } => name,
98 }
99 }
100
101 pub fn display_name(&self) -> &str {
102 match self {
103 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
104 Self::Four => "gpt-4",
105 Self::FourTurbo => "gpt-4-turbo",
106 Self::FourOmni => "gpt-4o",
107 Self::FourOmniMini => "gpt-4o-mini",
108 Self::Custom {
109 name, display_name, ..
110 } => display_name.as_ref().unwrap_or(name),
111 }
112 }
113
114 pub fn max_token_count(&self) -> usize {
115 match self {
116 Self::ThreePointFiveTurbo => 4096,
117 Self::Four => 8192,
118 Self::FourTurbo => 128000,
119 Self::FourOmni => 128000,
120 Self::FourOmniMini => 128000,
121 Self::Custom { max_tokens, .. } => *max_tokens,
122 }
123 }
124
125 pub fn max_output_tokens(&self) -> Option<u32> {
126 match self {
127 Self::Custom {
128 max_output_tokens, ..
129 } => *max_output_tokens,
130 _ => None,
131 }
132 }
133}
134
/// JSON body of a chat-completions request, as sent by [`stream_completion`].
///
/// `max_tokens` and `tool_choice` are omitted from the JSON when `None`, and
/// `tools` is omitted when empty; `stop` and `temperature` are always sent.
/// Field order is the key order of the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, e.g. the value of [`Model::id`].
    pub model: String,
    /// Conversation so far, in order.
    pub messages: Vec<RequestMessage>,
    /// Whether the server should stream the response.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; serialized even when empty.
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Tools offered to the model; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
149
150#[derive(Debug, Serialize, Deserialize)]
151#[serde(untagged)]
152pub enum ToolChoice {
153 Auto,
154 Required,
155 None,
156 Other(ToolDefinition),
157}
158
/// A tool offered to the model.
///
/// Internally tagged by `"type"`, so `Function` serializes as
/// `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow` is not evident from this file; confirm it is still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Declaration of a callable function.
///
/// `description` and `parameters` have no skip attribute, so they serialize
/// as `null` when `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// Arbitrary JSON; presumably a JSON Schema describing the arguments — confirm with callers.
    pub parameters: Option<Value>,
}
172
/// A message in the conversation sent with a [`Request`].
///
/// Internally tagged by `"role"`, whose value is the lowercase variant name.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A model turn. `content` is optional; `tool_calls` is omitted from the
    /// JSON when empty.
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// Output of a tool; `tool_call_id` presumably matches the `id` of the
    /// [`ToolCall`] being answered — confirm with callers.
    Tool {
        content: String,
        tool_call_id: String,
    },
}

/// A tool call made by the assistant.
///
/// `content` is flattened, so its `"type"`/`"function"` keys appear next to
/// `"id"` in the same JSON object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a [`ToolCall`], internally tagged by `"type"` (only
/// `"function"` exists).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name of the called function and its arguments as an unparsed string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Kept as raw text; presumably JSON-encoded — not validated here.
    pub arguments: String,
}
211
/// Incremental message data carried by one streamed choice; any field may be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted when serializing if `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a tool call.
///
/// Fragments belonging to the same call share `index`; that is how
/// `extract_tool_args_from_events` follows a single call across events.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// Partial function data; `name` and `arguments` may each be missing from
/// any individual fragment.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
236
/// Token counts reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Presumably `None` until the choice finishes — confirm against the API reference.
    pub finish_reason: Option<String>,
}

/// Decoded form of one `data:` payload in the completion stream.
///
/// Untagged: serde tries `Ok` first and falls back to an `{"error": "..."}`
/// object, so variant order matters.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

/// One successfully decoded event from the completion stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Presumably a Unix timestamp in seconds — confirm against the API reference.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
265
266pub async fn stream_completion(
267 client: &dyn HttpClient,
268 api_url: &str,
269 api_key: &str,
270 request: Request,
271 low_speed_timeout: Option<Duration>,
272) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
273 let uri = format!("{api_url}/chat/completions");
274 let mut request_builder = HttpRequest::builder()
275 .method(Method::POST)
276 .uri(uri)
277 .header("Content-Type", "application/json")
278 .header("Authorization", format!("Bearer {}", api_key));
279
280 if let Some(low_speed_timeout) = low_speed_timeout {
281 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
282 };
283
284 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
285 let mut response = client.send(request).await?;
286 if response.status().is_success() {
287 let reader = BufReader::new(response.into_body());
288 Ok(reader
289 .lines()
290 .filter_map(|line| async move {
291 match line {
292 Ok(line) => {
293 let line = line.strip_prefix("data: ")?;
294 if line == "[DONE]" {
295 None
296 } else {
297 match serde_json::from_str(line) {
298 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
299 Ok(ResponseStreamResult::Err { error }) => {
300 Some(Err(anyhow!(error)))
301 }
302 Err(error) => Some(Err(anyhow!(error))),
303 }
304 }
305 }
306 Err(error) => Some(Err(anyhow!(error))),
307 }
308 })
309 .boxed())
310 } else {
311 let mut body = String::new();
312 response.body_mut().read_to_string(&mut body).await?;
313
314 #[derive(Deserialize)]
315 struct OpenAiResponse {
316 error: OpenAiError,
317 }
318
319 #[derive(Deserialize)]
320 struct OpenAiError {
321 message: String,
322 }
323
324 match serde_json::from_str::<OpenAiResponse>(&body) {
325 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
326 "Failed to connect to OpenAI API: {}",
327 response.error.message,
328 )),
329
330 _ => Err(anyhow!(
331 "Failed to connect to OpenAI API: {} {}",
332 response.status(),
333 body,
334 )),
335 }
336 }
337}
338
/// Embedding models usable with [`embed`], serialized by their API ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// JSON body sent by [`embed`]: the model plus the texts to embed.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

/// Parsed body of a successful embeddings response.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// Presumably one entry per input text, in input order — confirm against the API reference.
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
362
363pub fn embed<'a>(
364 client: &dyn HttpClient,
365 api_url: &str,
366 api_key: &str,
367 model: OpenAiEmbeddingModel,
368 texts: impl IntoIterator<Item = &'a str>,
369) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
370 let uri = format!("{api_url}/embeddings");
371
372 let request = OpenAiEmbeddingRequest {
373 model,
374 input: texts.into_iter().collect(),
375 };
376 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
377 let request = HttpRequest::builder()
378 .method(Method::POST)
379 .uri(uri)
380 .header("Content-Type", "application/json")
381 .header("Authorization", format!("Bearer {}", api_key))
382 .body(body)
383 .map(|request| client.send(request));
384
385 async move {
386 let mut response = request?.await?;
387 let mut body = String::new();
388 response.body_mut().read_to_string(&mut body).await?;
389
390 if response.status().is_success() {
391 let response: OpenAiEmbeddingResponse =
392 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
393 Ok(response)
394 } else {
395 Err(anyhow!(
396 "error during embedding, status: {:?}, body: {:?}",
397 response.status(),
398 body
399 ))
400 }
401 }
402}
403
404pub async fn extract_tool_args_from_events(
405 tool_name: String,
406 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
407) -> Result<impl Send + Stream<Item = Result<String>>> {
408 let mut tool_use_index = None;
409 let mut first_chunk = None;
410 while let Some(event) = events.next().await {
411 let call = event?.choices.into_iter().find_map(|choice| {
412 choice.delta.tool_calls?.into_iter().find_map(|call| {
413 if call.function.as_ref()?.name.as_deref()? == tool_name {
414 Some(call)
415 } else {
416 None
417 }
418 })
419 });
420 if let Some(call) = call {
421 tool_use_index = Some(call.index);
422 first_chunk = call.function.and_then(|func| func.arguments);
423 break;
424 }
425 }
426
427 let Some(tool_use_index) = tool_use_index else {
428 return Err(anyhow!("tool not used"));
429 };
430
431 Ok(events.filter_map(move |event| {
432 let result = match event {
433 Err(error) => Some(Err(error)),
434 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
435 choice.delta.tool_calls?.into_iter().find_map(|call| {
436 if call.index == tool_use_index {
437 let func = call.function?;
438 let mut arguments = func.arguments?;
439 if let Some(mut first_chunk) = first_chunk.take() {
440 first_chunk.push_str(&arguments);
441 arguments = first_chunk
442 }
443 Some(Ok(arguments))
444 } else {
445 None
446 }
447 })
448 }),
449 };
450
451 async move { result }
452 }))
453}
454
455pub fn extract_text_from_events(
456 response: impl Stream<Item = Result<ResponseStreamEvent>>,
457) -> impl Stream<Item = Result<String>> {
458 response.filter_map(|response| async move {
459 match response {
460 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
461 Err(error) => Some(Err(error)),
462 }
463 })
464}