1mod supported_countries;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{
5 AsyncBufReadExt, AsyncReadExt, StreamExt,
6 io::BufReader,
7 stream::{self, BoxStream},
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15};
16use strum::EnumIter;
17
18pub use supported_countries::*;
19
20pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
21
/// Returns `true` when `opt` is `None` or its contained slice-like value is empty.
///
/// Used as a `skip_serializing_if` predicate so empty tool-call lists are
/// omitted from serialized payloads.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
25
/// The author of a chat message, serialized as the lowercase role string
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`) used by the chat API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
34
35impl TryFrom<String> for Role {
36 type Error = anyhow::Error;
37
38 fn try_from(value: String) -> Result<Self> {
39 match value.as_str() {
40 "user" => Ok(Self::User),
41 "assistant" => Ok(Self::Assistant),
42 "system" => Ok(Self::System),
43 "tool" => Ok(Self::Tool),
44 _ => Err(anyhow!("invalid role '{value}'")),
45 }
46 }
47}
48
49impl From<Role> for String {
50 fn from(val: Role) -> Self {
51 match val {
52 Role::User => "user".to_owned(),
53 Role::Assistant => "assistant".to_owned(),
54 Role::System => "system".to_owned(),
55 Role::Tool => "tool".to_owned(),
56 }
57 }
58}
59
/// A chat model available through the OpenAI API (or an API-compatible
/// endpoint). The serde `rename` of each variant is its wire identifier;
/// `Custom` covers models this enum does not know about.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1", alias = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini", alias = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano", alias = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,

    /// A user-configured model not covered by the variants above.
    #[serde(rename = "custom")]
    Custom {
        // Identifier sent as the `model` field of API requests.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Context-window size in tokens; see `max_token_count`.
        max_tokens: usize,
        // Optional output-token limit; see `max_output_tokens`.
        max_output_tokens: Option<u32>,
        // NOTE(review): not read by any method in this file — presumably maps
        // to the API's `max_completion_tokens` parameter; confirm at call sites.
        max_completion_tokens: Option<u32>,
    },
}
99
100impl Model {
101 pub fn from_id(id: &str) -> Result<Self> {
102 match id {
103 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
104 "gpt-4" => Ok(Self::Four),
105 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
106 "gpt-4o" => Ok(Self::FourOmni),
107 "gpt-4o-mini" => Ok(Self::FourOmniMini),
108 "gpt-4.1" => Ok(Self::FourPointOne),
109 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
110 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
111 "o1" => Ok(Self::O1),
112 "o1-preview" => Ok(Self::O1Preview),
113 "o1-mini" => Ok(Self::O1Mini),
114 "o3-mini" => Ok(Self::O3Mini),
115 _ => Err(anyhow!("invalid model id")),
116 }
117 }
118
119 pub fn id(&self) -> &str {
120 match self {
121 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
122 Self::Four => "gpt-4",
123 Self::FourTurbo => "gpt-4-turbo",
124 Self::FourOmni => "gpt-4o",
125 Self::FourOmniMini => "gpt-4o-mini",
126 Self::FourPointOne => "gpt-4.1",
127 Self::FourPointOneMini => "gpt-4.1-mini",
128 Self::FourPointOneNano => "gpt-4.1-nano",
129 Self::O1 => "o1",
130 Self::O1Preview => "o1-preview",
131 Self::O1Mini => "o1-mini",
132 Self::O3Mini => "o3-mini",
133 Self::Custom { name, .. } => name,
134 }
135 }
136
137 pub fn display_name(&self) -> &str {
138 match self {
139 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
140 Self::Four => "gpt-4",
141 Self::FourTurbo => "gpt-4-turbo",
142 Self::FourOmni => "gpt-4o",
143 Self::FourOmniMini => "gpt-4o-mini",
144 Self::FourPointOne => "gpt-4.1",
145 Self::FourPointOneMini => "gpt-4.1-mini",
146 Self::FourPointOneNano => "gpt-4.1-nano",
147 Self::O1 => "o1",
148 Self::O1Preview => "o1-preview",
149 Self::O1Mini => "o1-mini",
150 Self::O3Mini => "o3-mini",
151 Self::Custom {
152 name, display_name, ..
153 } => display_name.as_ref().unwrap_or(name),
154 }
155 }
156
157 pub fn max_token_count(&self) -> usize {
158 match self {
159 Self::ThreePointFiveTurbo => 16_385,
160 Self::Four => 8_192,
161 Self::FourTurbo => 128_000,
162 Self::FourOmni => 128_000,
163 Self::FourOmniMini => 128_000,
164 Self::FourPointOne => 1_047_576,
165 Self::FourPointOneMini => 1_047_576,
166 Self::FourPointOneNano => 1_047_576,
167 Self::O1 => 200_000,
168 Self::O1Preview => 128_000,
169 Self::O1Mini => 128_000,
170 Self::O3Mini => 200_000,
171 Self::Custom { max_tokens, .. } => *max_tokens,
172 }
173 }
174
175 pub fn max_output_tokens(&self) -> Option<u32> {
176 match self {
177 Self::Custom {
178 max_output_tokens, ..
179 } => *max_output_tokens,
180 _ => None,
181 }
182 }
183
184 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
185 ///
186 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
187 pub fn supports_parallel_tool_calls(&self) -> bool {
188 match self {
189 Self::ThreePointFiveTurbo
190 | Self::Four
191 | Self::FourTurbo
192 | Self::FourOmni
193 | Self::FourOmniMini
194 | Self::FourPointOne
195 | Self::FourPointOneMini
196 | Self::FourPointOneNano
197 | Self::O1
198 | Self::O1Preview
199 | Self::O1Mini => true,
200 _ => false,
201 }
202 }
203}
204
/// Request body for `POST {api_url}/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Wire identifier of the model, e.g. "gpt-4o".
    pub model: String,
    // Conversation history, oldest message first.
    pub messages: Vec<RequestMessage>,
    // When true the response is delivered as server-sent events; `complete`
    // forces this to false, `stream_completion` relies on it being true.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    // Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    // Tool definitions the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
223
/// Request body for `POST {api_url}/completions` (plain text completion).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    // Predicted output content, used to speed up regeneration of known text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    // NOTE(review): not a standard OpenAI parameter — presumably consumed by
    // an API-compatible endpoint; confirm against the serving backend.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
235
/// Predicted output, serialized as `{"type": "content", "content": …}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
241
242#[derive(Debug, Serialize, Deserialize)]
243#[serde(untagged)]
244pub enum ToolChoice {
245 Auto,
246 Required,
247 None,
248 Other(ToolDefinition),
249}
250
/// A tool the model may call, serialized as
/// `{"type": "function", "function": {…}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
257
/// Schema of a callable function advertised to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
264
/// A single chat message; the variant is encoded in the `role` field
/// (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Content may be absent when the assistant only issued tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        // Identifier of the assistant tool call this message responds to.
        tool_call_id: String,
    },
}
284
/// A tool call issued by the assistant; `content` is flattened so the
/// payload's `type`/`function` fields sit alongside `id`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
291
/// Payload of a tool call, tagged by its `type` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
297
/// A concrete function invocation: the function name plus its arguments as a
/// raw JSON string (as delivered by the API, not parsed here).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
303
/// Incremental message fragment within one streamed chunk; all fields are
/// optional because each chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
311
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    // Only present on the first fragment of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
322
/// Streamed fragment of a function invocation; `arguments` fragments must be
/// concatenated across chunks to recover the full JSON argument string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
328
/// Token accounting reported by the API for one request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
335
/// One choice within a streamed chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // e.g. "stop" or "tool_calls"; `None` while the choice is still streaming.
    pub finish_reason: Option<String>,
}
342
/// A single SSE payload: either a well-formed stream event or an error object
/// embedded in the stream. Untagged, so whichever shape parses wins.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
349
/// One parsed chunk of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    // Unix timestamp of creation (seconds).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    // Usually only present on the final chunk, if at all.
    pub usage: Option<Usage>,
}
357
/// Response body of a non-streaming `/completions` request.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
367
/// One generated text alternative from `/completions`.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
372
/// Response body of a non-streaming `/chat/completions` request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
382
/// One completed chat alternative. The message reuses [`RequestMessage`],
/// which matches the response wire shape as well.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
389
390pub async fn complete(
391 client: &dyn HttpClient,
392 api_url: &str,
393 api_key: &str,
394 request: Request,
395) -> Result<Response> {
396 let uri = format!("{api_url}/chat/completions");
397 let request_builder = HttpRequest::builder()
398 .method(Method::POST)
399 .uri(uri)
400 .header("Content-Type", "application/json")
401 .header("Authorization", format!("Bearer {}", api_key));
402
403 let mut request_body = request;
404 request_body.stream = false;
405
406 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
407 let mut response = client.send(request).await?;
408
409 if response.status().is_success() {
410 let mut body = String::new();
411 response.body_mut().read_to_string(&mut body).await?;
412 let response: Response = serde_json::from_str(&body)?;
413 Ok(response)
414 } else {
415 let mut body = String::new();
416 response.body_mut().read_to_string(&mut body).await?;
417
418 #[derive(Deserialize)]
419 struct OpenAiResponse {
420 error: OpenAiError,
421 }
422
423 #[derive(Deserialize)]
424 struct OpenAiError {
425 message: String,
426 }
427
428 match serde_json::from_str::<OpenAiResponse>(&body) {
429 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
430 "Failed to connect to OpenAI API: {}",
431 response.error.message,
432 )),
433
434 _ => Err(anyhow!(
435 "Failed to connect to OpenAI API: {} {}",
436 response.status(),
437 body,
438 )),
439 }
440 }
441}
442
443pub async fn complete_text(
444 client: &dyn HttpClient,
445 api_url: &str,
446 api_key: &str,
447 request: CompletionRequest,
448) -> Result<CompletionResponse> {
449 let uri = format!("{api_url}/completions");
450 let request_builder = HttpRequest::builder()
451 .method(Method::POST)
452 .uri(uri)
453 .header("Content-Type", "application/json")
454 .header("Authorization", format!("Bearer {}", api_key));
455
456 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
457 let mut response = client.send(request).await?;
458
459 if response.status().is_success() {
460 let mut body = String::new();
461 response.body_mut().read_to_string(&mut body).await?;
462 let response = serde_json::from_str(&body)?;
463 Ok(response)
464 } else {
465 let mut body = String::new();
466 response.body_mut().read_to_string(&mut body).await?;
467
468 #[derive(Deserialize)]
469 struct OpenAiResponse {
470 error: OpenAiError,
471 }
472
473 #[derive(Deserialize)]
474 struct OpenAiError {
475 message: String,
476 }
477
478 match serde_json::from_str::<OpenAiResponse>(&body) {
479 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
480 "Failed to connect to OpenAI API: {}",
481 response.error.message,
482 )),
483
484 _ => Err(anyhow!(
485 "Failed to connect to OpenAI API: {} {}",
486 response.status(),
487 body,
488 )),
489 }
490 }
491}
492
493fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
494 ResponseStreamEvent {
495 created: response.created as u32,
496 model: response.model,
497 choices: response
498 .choices
499 .into_iter()
500 .map(|choice| ChoiceDelta {
501 index: choice.index,
502 delta: ResponseMessageDelta {
503 role: Some(match choice.message {
504 RequestMessage::Assistant { .. } => Role::Assistant,
505 RequestMessage::User { .. } => Role::User,
506 RequestMessage::System { .. } => Role::System,
507 RequestMessage::Tool { .. } => Role::Tool,
508 }),
509 content: match choice.message {
510 RequestMessage::Assistant { content, .. } => content,
511 RequestMessage::User { content } => Some(content),
512 RequestMessage::System { content } => Some(content),
513 RequestMessage::Tool { content, .. } => Some(content),
514 },
515 tool_calls: None,
516 },
517 finish_reason: choice.finish_reason,
518 })
519 .collect(),
520 usage: Some(response.usage),
521 }
522}
523
/// Streams a chat completion as a sequence of parsed server-sent events.
///
/// o1-family models are routed through the blocking [`complete`] call and the
/// single response is adapted into a one-event stream, since those models do
/// not take the streaming path here.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    // NOTE(review): prefix test also matches custom models whose name starts
    // with "o1"; confirm that is intended.
    if request.model.starts_with("o1") {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines are prefixed "data: "; anything else
                        // (blank keep-alives, comments) is dropped via `?`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel; swallow it.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                // An error object embedded in the stream body.
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                // Malformed JSON payload.
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    // Transport-level read failure.
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Minimal mirror of OpenAI's error envelope: {"error": {"message": …}}.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
597
/// Embedding models, serialized to their wire identifiers.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
605
/// Request body for `POST {api_url}/embeddings`; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
611
/// Response body of an embeddings request: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
616
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
621
622pub fn embed<'a>(
623 client: &dyn HttpClient,
624 api_url: &str,
625 api_key: &str,
626 model: OpenAiEmbeddingModel,
627 texts: impl IntoIterator<Item = &'a str>,
628) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
629 let uri = format!("{api_url}/embeddings");
630
631 let request = OpenAiEmbeddingRequest {
632 model,
633 input: texts.into_iter().collect(),
634 };
635 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
636 let request = HttpRequest::builder()
637 .method(Method::POST)
638 .uri(uri)
639 .header("Content-Type", "application/json")
640 .header("Authorization", format!("Bearer {}", api_key))
641 .body(body)
642 .map(|request| client.send(request));
643
644 async move {
645 let mut response = request?.await?;
646 let mut body = String::new();
647 response.body_mut().read_to_string(&mut body).await?;
648
649 if response.status().is_success() {
650 let response: OpenAiEmbeddingResponse =
651 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
652 Ok(response)
653 } else {
654 Err(anyhow!(
655 "error during embedding, status: {:?}, body: {:?}",
656 response.status(),
657 body
658 ))
659 }
660 }
661}