mod supported_countries;

use anyhow::{anyhow, Context, Result};
use futures::{
    io::BufReader,
    stream::{self, BoxStream},
    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    convert::TryFrom,
    future::{self, Future},
    pin::Pin,
    time::Duration,
};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

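// Serde helper for `skip_serializing_if`: a missing or empty collection is treated as "nothing to send".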
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

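/// Chat models accepted by this crate. `Custom` covers models that are not
/// listed here, such as those served by OpenAI-compatible endpoints.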
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}

impl Model {
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            // Accept the legacy "-preview" id as well, so `from_id(model.id())` round-trips.
            "gpt-4-turbo" | "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            "gpt-4o" => Ok(Self::FourOmni),
            "gpt-4o-mini" => Ok(Self::FourOmniMini),
            "o1-preview" => Ok(Self::O1Preview),
            "o1-mini" => Ok(Self::O1Mini),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::Custom {
                name, display_name, ..
            } => display_name.as_ref().unwrap_or(name),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::ThreePointFiveTurbo => 16385,
            Self::Four => 8192,
            Self::FourTurbo => 128000,
            Self::FourOmni => 128000,
            Self::FourOmniMini => 128000,
            Self::O1Preview => 128000,
            Self::O1Mini => 128000,
            Self::Custom { max_tokens, .. } => *max_tokens,
        }
    }

    pub fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Self::Custom {
                max_output_tokens, ..
            } => *max_output_tokens,
            _ => None,
        }
    }
}

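/// Body of a request to the `/chat/completions` endpoint.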
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}

/// `Auto`, `Required`, and `None` serialize as the plain strings the API
/// expects; `Other` passes a specific tool definition through untagged.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether this chunk
    // carries a function call. The `function` payload sometimes streams in
    // before its `type` does, so we don't rely on it here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}

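/// Sends a non-streaming chat completion request and parses the JSON response.
///
/// `request.stream` is forced to `false` before the body is serialized.
///
/// A minimal usage sketch (the `client` and `api_key` values are assumed to be
/// supplied by the caller):
///
/// ```ignore
/// let request = Request {
///     model: Model::FourOmni.id().to_string(),
///     messages: vec![RequestMessage::User { content: "Hello!".into() }],
///     stream: false,
///     max_tokens: None,
///     stop: Vec::new(),
///     temperature: 1.0,
///     tool_choice: None,
///     tools: Vec::new(),
/// };
/// let response = complete(client, OPEN_AI_API_URL, &api_key, request, None).await?;
/// ```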
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<Response> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));
    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let mut request_body = request;
    request_body.stream = false;

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: Response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

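/// Converts a non-streaming [`Response`] into a single [`ResponseStreamEvent`],
/// so that models without streaming support can share the streaming code path.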
fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
    ResponseStreamEvent {
        created: response.created as u32,
        model: response.model,
        choices: response
            .choices
            .into_iter()
            .map(|choice| ChoiceDelta {
                index: choice.index,
                delta: ResponseMessageDelta {
                    role: Some(match choice.message {
                        RequestMessage::Assistant { .. } => Role::Assistant,
                        RequestMessage::User { .. } => Role::User,
                        RequestMessage::System { .. } => Role::System,
                        RequestMessage::Tool { .. } => Role::Tool,
                    }),
                    content: match choice.message {
                        RequestMessage::Assistant { content, .. } => content,
                        RequestMessage::User { content } => Some(content),
                        RequestMessage::System { content } => Some(content),
                        RequestMessage::Tool { content, .. } => Some(content),
                    },
                    tool_calls: None,
                },
                finish_reason: choice.finish_reason,
            })
            .collect(),
        usage: Some(response.usage),
    }
}

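/// Sends a streaming chat completion request and yields one
/// [`ResponseStreamEvent`] per server-sent `data:` line, stopping at `[DONE]`.
///
/// `o1-preview` and `o1-mini` do not support streaming, so requests for those
/// models fall back to [`complete`] and the single response is adapted into a
/// one-event stream.
///
/// A minimal consumption sketch (the `client`, `api_key`, and `request` values
/// are assumed to be supplied by the caller):
///
/// ```ignore
/// let mut events = stream_completion(client, OPEN_AI_API_URL, &api_key, request, None).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(text) = choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```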
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

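/// Requests embeddings for `texts` from the `/embeddings` endpoint.
///
/// The HTTP request is built eagerly (which is why the function itself is not
/// `async`); the returned future sends it and parses the response when awaited.
///
/// A minimal usage sketch (the `client` and `api_key` values are assumed to be
/// supplied by the caller):
///
/// ```ignore
/// let response = embed(
///     client,
///     OPEN_AI_API_URL,
///     &api_key,
///     OpenAiEmbeddingModel::TextEmbedding3Small,
///     ["hello world"],
/// )
/// .await?;
/// let vector = &response.data[0].embedding;
/// ```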
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

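/// Filters a completion event stream down to the argument fragments of the
/// first tool call whose function name matches `tool_name`.
///
/// Events are consumed until the matching call appears; if the stream ends
/// without it, an error is returned. The resulting stream then yields that
/// call's argument chunks, with any arguments that arrived alongside the name
/// prepended to the first yielded fragment.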
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}

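/// Maps a completion event stream to just its text deltas, yielding the
/// `content` of each event's last choice and skipping events without one.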
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}