1mod supported_countries;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{
5 AsyncBufReadExt, AsyncReadExt, StreamExt,
6 io::BufReader,
7 stream::{self, BoxStream},
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15};
16use strum::EnumIter;
17
18pub use supported_countries::*;
19
/// Default base URL for the OpenAI REST API (no trailing slash).
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
21
/// Returns `true` when `opt` is `None` or contains an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate so empty tool-call lists are
/// omitted from serialized payloads.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
25
/// The role of a chat message participant, matching the OpenAI chat API's
/// `role` field (serialized in lowercase, e.g. `"user"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
34
35impl TryFrom<String> for Role {
36 type Error = anyhow::Error;
37
38 fn try_from(value: String) -> Result<Self> {
39 match value.as_str() {
40 "user" => Ok(Self::User),
41 "assistant" => Ok(Self::Assistant),
42 "system" => Ok(Self::System),
43 "tool" => Ok(Self::Tool),
44 _ => Err(anyhow!("invalid role '{value}'")),
45 }
46 }
47}
48
49impl From<Role> for String {
50 fn from(val: Role) -> Self {
51 match val {
52 Role::User => "user".to_owned(),
53 Role::Assistant => "assistant".to_owned(),
54 Role::System => "system".to_owned(),
55 Role::Tool => "tool".to_owned(),
56 }
57 }
58}
59
/// The set of OpenAI models this client knows about, plus a `Custom` escape
/// hatch for models configured by the user.
///
/// Each built-in variant serializes to its API model id (e.g. `"gpt-4o"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1", alias = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini", alias = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano", alias = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3", alias = "o3")]
    O3,
    #[serde(rename = "o4-mini", alias = "o4-mini")]
    O4Mini,

    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Total context window size, in tokens.
        max_tokens: usize,
        /// Optional cap on output tokens (`max_tokens` request parameter).
        max_output_tokens: Option<u32>,
        /// Optional cap on completion tokens (`max_completion_tokens` request parameter).
        max_completion_tokens: Option<u32>,
    },
}
103
104impl Model {
105 pub fn default_fast() -> Self {
106 Self::FourPointOneMini
107 }
108
109 pub fn from_id(id: &str) -> Result<Self> {
110 match id {
111 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
112 "gpt-4" => Ok(Self::Four),
113 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
114 "gpt-4o" => Ok(Self::FourOmni),
115 "gpt-4o-mini" => Ok(Self::FourOmniMini),
116 "gpt-4.1" => Ok(Self::FourPointOne),
117 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
118 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
119 "o1" => Ok(Self::O1),
120 "o1-preview" => Ok(Self::O1Preview),
121 "o1-mini" => Ok(Self::O1Mini),
122 "o3-mini" => Ok(Self::O3Mini),
123 "o3" => Ok(Self::O3),
124 "o4-mini" => Ok(Self::O4Mini),
125 _ => Err(anyhow!("invalid model id")),
126 }
127 }
128
129 pub fn id(&self) -> &str {
130 match self {
131 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
132 Self::Four => "gpt-4",
133 Self::FourTurbo => "gpt-4-turbo",
134 Self::FourOmni => "gpt-4o",
135 Self::FourOmniMini => "gpt-4o-mini",
136 Self::FourPointOne => "gpt-4.1",
137 Self::FourPointOneMini => "gpt-4.1-mini",
138 Self::FourPointOneNano => "gpt-4.1-nano",
139 Self::O1 => "o1",
140 Self::O1Preview => "o1-preview",
141 Self::O1Mini => "o1-mini",
142 Self::O3Mini => "o3-mini",
143 Self::O3 => "o3",
144 Self::O4Mini => "o4-mini",
145 Self::Custom { name, .. } => name,
146 }
147 }
148
149 pub fn display_name(&self) -> &str {
150 match self {
151 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
152 Self::Four => "gpt-4",
153 Self::FourTurbo => "gpt-4-turbo",
154 Self::FourOmni => "gpt-4o",
155 Self::FourOmniMini => "gpt-4o-mini",
156 Self::FourPointOne => "gpt-4.1",
157 Self::FourPointOneMini => "gpt-4.1-mini",
158 Self::FourPointOneNano => "gpt-4.1-nano",
159 Self::O1 => "o1",
160 Self::O1Preview => "o1-preview",
161 Self::O1Mini => "o1-mini",
162 Self::O3Mini => "o3-mini",
163 Self::O3 => "o3",
164 Self::O4Mini => "o4-mini",
165 Self::Custom {
166 name, display_name, ..
167 } => display_name.as_ref().unwrap_or(name),
168 }
169 }
170
171 pub fn max_token_count(&self) -> usize {
172 match self {
173 Self::ThreePointFiveTurbo => 16_385,
174 Self::Four => 8_192,
175 Self::FourTurbo => 128_000,
176 Self::FourOmni => 128_000,
177 Self::FourOmniMini => 128_000,
178 Self::FourPointOne => 1_047_576,
179 Self::FourPointOneMini => 1_047_576,
180 Self::FourPointOneNano => 1_047_576,
181 Self::O1 => 200_000,
182 Self::O1Preview => 128_000,
183 Self::O1Mini => 128_000,
184 Self::O3Mini => 200_000,
185 Self::O3 => 200_000,
186 Self::O4Mini => 200_000,
187 Self::Custom { max_tokens, .. } => *max_tokens,
188 }
189 }
190
191 pub fn max_output_tokens(&self) -> Option<u32> {
192 match self {
193 Self::Custom {
194 max_output_tokens, ..
195 } => *max_output_tokens,
196 _ => None,
197 }
198 }
199
200 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
201 ///
202 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
203 pub fn supports_parallel_tool_calls(&self) -> bool {
204 match self {
205 Self::ThreePointFiveTurbo
206 | Self::Four
207 | Self::FourTurbo
208 | Self::FourOmni
209 | Self::FourOmniMini
210 | Self::FourPointOne
211 | Self::FourPointOneMini
212 | Self::FourPointOneNano
213 | Self::O1
214 | Self::O1Preview
215 | Self::O1Mini => true,
216 _ => false,
217 }
218 }
219}
220
/// Request body for the chat completions endpoint (`/chat/completions`).
///
/// Optional/empty fields are skipped during serialization so the API only
/// receives parameters that were explicitly set.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to use, e.g. "gpt-4o" (see `Model::id`).
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<RequestMessage>,
    /// When `true`, the response is delivered as a server-sent event stream.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
239
/// Request body for the legacy text completions endpoint (`/completions`).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    /// The raw prompt text (no chat message structure).
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    /// Predicted-output hint; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    // NOTE(review): not a standard OpenAI parameter — presumably understood
    // by a compatible backend; confirm against the target API.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
251
/// Predicted-output content, serialized as `{"type": "content", "content": …}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
257
258#[derive(Debug, Serialize, Deserialize)]
259#[serde(untagged)]
260pub enum ToolChoice {
261 Auto,
262 Required,
263 None,
264 Other(ToolDefinition),
265}
266
/// A tool the model may call, serialized as
/// `{"type": "function", "function": {…}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
273
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
280
/// A single chat message, tagged by its `role` field on the wire
/// (`"assistant"`, `"user"`, `"system"`, or `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the message only carries tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        /// The tool's output, paired with the id of the call it answers.
        content: String,
        tool_call_id: String,
    },
}
300
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Unique id echoed back in the corresponding `Tool` message.
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` on the wire.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
307
/// The payload of a tool call, tagged by its `type` field
/// (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
313
/// A concrete function invocation: name plus JSON-encoded arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a raw JSON string, exactly as produced by the model.
    pub arguments: String,
}
319
/// Incremental message content from a streaming response; every field may be
/// absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
327
/// A fragment of a tool call from a streaming response; fragments with the
/// same `index` belong to the same call and must be accumulated by the caller.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Which tool call (within the choice) this fragment extends.
    pub index: usize,
    /// Present only on the first fragment of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`.
    pub function: Option<FunctionChunk>,
}
338
/// A fragment of a function invocation; `arguments` pieces concatenate across
/// chunks to form the full JSON argument string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
344
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
351
/// One choice's incremental update within a streaming event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Set on the final chunk of a choice (e.g. "stop", "tool_calls").
    pub finish_reason: Option<String>,
}
358
/// A single parsed SSE payload: either a normal event, or an error object of
/// the shape `{"error": "…"}`. Untagged, so serde tries `Ok` first and falls
/// back to `Err`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
365
/// One event from a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Creation timestamp (Unix seconds, truncated to u32 here).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Usually only present on the final event, when usage reporting is enabled.
    pub usage: Option<Usage>,
}
373
/// Response body from the legacy text completions endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation timestamp in Unix seconds.
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
383
/// A single generated text alternative from a legacy completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
388
/// Response body from a non-streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Creation timestamp in Unix seconds.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
398
/// A single generated alternative from a non-streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    // Responses reuse RequestMessage since the wire shape (role-tagged
    // message) is identical in both directions.
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
405
406pub async fn complete(
407 client: &dyn HttpClient,
408 api_url: &str,
409 api_key: &str,
410 request: Request,
411) -> Result<Response> {
412 let uri = format!("{api_url}/chat/completions");
413 let request_builder = HttpRequest::builder()
414 .method(Method::POST)
415 .uri(uri)
416 .header("Content-Type", "application/json")
417 .header("Authorization", format!("Bearer {}", api_key));
418
419 let mut request_body = request;
420 request_body.stream = false;
421
422 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
423 let mut response = client.send(request).await?;
424
425 if response.status().is_success() {
426 let mut body = String::new();
427 response.body_mut().read_to_string(&mut body).await?;
428 let response: Response = serde_json::from_str(&body)?;
429 Ok(response)
430 } else {
431 let mut body = String::new();
432 response.body_mut().read_to_string(&mut body).await?;
433
434 #[derive(Deserialize)]
435 struct OpenAiResponse {
436 error: OpenAiError,
437 }
438
439 #[derive(Deserialize)]
440 struct OpenAiError {
441 message: String,
442 }
443
444 match serde_json::from_str::<OpenAiResponse>(&body) {
445 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
446 "Failed to connect to OpenAI API: {}",
447 response.error.message,
448 )),
449
450 _ => Err(anyhow!(
451 "Failed to connect to OpenAI API: {} {}",
452 response.status(),
453 body,
454 )),
455 }
456 }
457}
458
459pub async fn complete_text(
460 client: &dyn HttpClient,
461 api_url: &str,
462 api_key: &str,
463 request: CompletionRequest,
464) -> Result<CompletionResponse> {
465 let uri = format!("{api_url}/completions");
466 let request_builder = HttpRequest::builder()
467 .method(Method::POST)
468 .uri(uri)
469 .header("Content-Type", "application/json")
470 .header("Authorization", format!("Bearer {}", api_key));
471
472 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
473 let mut response = client.send(request).await?;
474
475 if response.status().is_success() {
476 let mut body = String::new();
477 response.body_mut().read_to_string(&mut body).await?;
478 let response = serde_json::from_str(&body)?;
479 Ok(response)
480 } else {
481 let mut body = String::new();
482 response.body_mut().read_to_string(&mut body).await?;
483
484 #[derive(Deserialize)]
485 struct OpenAiResponse {
486 error: OpenAiError,
487 }
488
489 #[derive(Deserialize)]
490 struct OpenAiError {
491 message: String,
492 }
493
494 match serde_json::from_str::<OpenAiResponse>(&body) {
495 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
496 "Failed to connect to OpenAI API: {}",
497 response.error.message,
498 )),
499
500 _ => Err(anyhow!(
501 "Failed to connect to OpenAI API: {} {}",
502 response.status(),
503 body,
504 )),
505 }
506 }
507}
508
509fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
510 ResponseStreamEvent {
511 created: response.created as u32,
512 model: response.model,
513 choices: response
514 .choices
515 .into_iter()
516 .map(|choice| ChoiceDelta {
517 index: choice.index,
518 delta: ResponseMessageDelta {
519 role: Some(match choice.message {
520 RequestMessage::Assistant { .. } => Role::Assistant,
521 RequestMessage::User { .. } => Role::User,
522 RequestMessage::System { .. } => Role::System,
523 RequestMessage::Tool { .. } => Role::Tool,
524 }),
525 content: match choice.message {
526 RequestMessage::Assistant { content, .. } => content,
527 RequestMessage::User { content } => Some(content),
528 RequestMessage::System { content } => Some(content),
529 RequestMessage::Tool { content, .. } => Some(content),
530 },
531 tool_calls: None,
532 },
533 finish_reason: choice.finish_reason,
534 })
535 .collect(),
536 usage: Some(response.usage),
537 }
538}
539
/// Initiates a streaming chat completion, yielding one event per server-sent
/// `data:` line.
///
/// Models whose id starts with "o1" are routed through the non-streaming
/// [`complete`] call and adapted into a one-event stream — presumably because
/// those models reject streaming requests.
/// NOTE(review): the prefix check also matches "o1-preview", "o1-mini", and
/// any future id beginning with "o1" — confirm that is intended.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model.starts_with("o1") {
        // Blocking fallback: run the request to completion, then wrap the
        // single result as a one-item stream.
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    // `request.stream` is serialized as-is here; callers are expected to have
    // set it to true for this path.
    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the SSE body line by line. Each event line looks like
        // `data: <json>`; the stream ends with the sentinel `data: [DONE]`.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Lines without the data prefix (blank keep-alives,
                        // other SSE fields) are skipped via `?` -> None.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-2xx: surface the structured OpenAI error message if the body
        // parses, otherwise report the raw status and body.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
613
/// Embedding models supported by this client, serialized to their API ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
621
/// Request body for the embeddings endpoint; borrows the input texts to
/// avoid copying them into the payload struct.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
627
/// Response body from the embeddings endpoint: one entry per input text,
/// in input order.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
632
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
637
638pub fn embed<'a>(
639 client: &dyn HttpClient,
640 api_url: &str,
641 api_key: &str,
642 model: OpenAiEmbeddingModel,
643 texts: impl IntoIterator<Item = &'a str>,
644) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
645 let uri = format!("{api_url}/embeddings");
646
647 let request = OpenAiEmbeddingRequest {
648 model,
649 input: texts.into_iter().collect(),
650 };
651 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
652 let request = HttpRequest::builder()
653 .method(Method::POST)
654 .uri(uri)
655 .header("Content-Type", "application/json")
656 .header("Authorization", format!("Bearer {}", api_key))
657 .body(body)
658 .map(|request| client.send(request));
659
660 async move {
661 let mut response = request?.await?;
662 let mut body = String::new();
663 response.body_mut().read_to_string(&mut body).await?;
664
665 if response.status().is_success() {
666 let response: OpenAiEmbeddingResponse =
667 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
668 Ok(response)
669 } else {
670 Err(anyhow!(
671 "error during embedding, status: {:?}, body: {:?}",
672 response.status(),
673 body
674 ))
675 }
676 }
677}