1mod supported_countries;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{
5 AsyncBufReadExt, AsyncReadExt, StreamExt,
6 io::BufReader,
7 stream::{self, BoxStream},
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15};
16use strum::EnumIter;
17
18pub use supported_countries::*;
19
/// Default base URL for the OpenAI REST API.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
21
/// Serde `skip_serializing_if` helper: true when `opt` is `None` or holds an
/// empty slice-like value (e.g. an empty `Vec` of tool-call chunks).
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
25
/// The role of a chat message participant.
///
/// Serialized in lowercase ("user", "assistant", "system", "tool") to match
/// the OpenAI chat completions wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
34
35impl TryFrom<String> for Role {
36 type Error = anyhow::Error;
37
38 fn try_from(value: String) -> Result<Self> {
39 match value.as_str() {
40 "user" => Ok(Self::User),
41 "assistant" => Ok(Self::Assistant),
42 "system" => Ok(Self::System),
43 "tool" => Ok(Self::Tool),
44 _ => Err(anyhow!("invalid role '{value}'")),
45 }
46 }
47}
48
49impl From<Role> for String {
50 fn from(val: Role) -> Self {
51 match val {
52 Role::User => "user".to_owned(),
53 Role::Assistant => "assistant".to_owned(),
54 Role::System => "system".to_owned(),
55 Role::Tool => "tool".to_owned(),
56 }
57 }
58}
59
/// The set of OpenAI models this client knows about, plus an escape hatch
/// for arbitrary model ids via [`Model::Custom`].
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,

    /// A model not in the built-in list; identified by its raw API id.
    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}
93
94impl Model {
95 pub fn from_id(id: &str) -> Result<Self> {
96 match id {
97 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
98 "gpt-4" => Ok(Self::Four),
99 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
100 "gpt-4o" => Ok(Self::FourOmni),
101 "gpt-4o-mini" => Ok(Self::FourOmniMini),
102 "o1" => Ok(Self::O1),
103 "o1-preview" => Ok(Self::O1Preview),
104 "o1-mini" => Ok(Self::O1Mini),
105 "o3-mini" => Ok(Self::O3Mini),
106 _ => Err(anyhow!("invalid model id")),
107 }
108 }
109
110 pub fn id(&self) -> &str {
111 match self {
112 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
113 Self::Four => "gpt-4",
114 Self::FourTurbo => "gpt-4-turbo",
115 Self::FourOmni => "gpt-4o",
116 Self::FourOmniMini => "gpt-4o-mini",
117 Self::O1 => "o1",
118 Self::O1Preview => "o1-preview",
119 Self::O1Mini => "o1-mini",
120 Self::O3Mini => "o3-mini",
121 Self::Custom { name, .. } => name,
122 }
123 }
124
125 pub fn display_name(&self) -> &str {
126 match self {
127 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
128 Self::Four => "gpt-4",
129 Self::FourTurbo => "gpt-4-turbo",
130 Self::FourOmni => "gpt-4o",
131 Self::FourOmniMini => "gpt-4o-mini",
132 Self::O1 => "o1",
133 Self::O1Preview => "o1-preview",
134 Self::O1Mini => "o1-mini",
135 Self::O3Mini => "o3-mini",
136 Self::Custom {
137 name, display_name, ..
138 } => display_name.as_ref().unwrap_or(name),
139 }
140 }
141
142 pub fn max_token_count(&self) -> usize {
143 match self {
144 Self::ThreePointFiveTurbo => 16_385,
145 Self::Four => 8_192,
146 Self::FourTurbo => 128_000,
147 Self::FourOmni => 128_000,
148 Self::FourOmniMini => 128_000,
149 Self::O1 => 200_000,
150 Self::O1Preview => 128_000,
151 Self::O1Mini => 128_000,
152 Self::O3Mini => 200_000,
153 Self::Custom { max_tokens, .. } => *max_tokens,
154 }
155 }
156
157 pub fn max_output_tokens(&self) -> Option<u32> {
158 match self {
159 Self::Custom {
160 max_output_tokens, ..
161 } => *max_output_tokens,
162 _ => None,
163 }
164 }
165
166 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
167 ///
168 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
169 pub fn supports_parallel_tool_calls(&self) -> bool {
170 match self {
171 Self::ThreePointFiveTurbo
172 | Self::Four
173 | Self::FourTurbo
174 | Self::FourOmni
175 | Self::FourOmniMini
176 | Self::O1
177 | Self::O1Preview
178 | Self::O1Mini => true,
179 _ => false,
180 }
181 }
182}
183
/// Body of a `/chat/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true, the response is delivered as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
202
/// Body of a plain-text `/completions` request (non-chat endpoint).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
214
/// Predicted output supplied with a completion request so the API can reuse
/// matching content (serialized as `{"type": "content", "content": ...}`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
220
221#[derive(Debug, Serialize, Deserialize)]
222#[serde(untagged)]
223pub enum ToolChoice {
224 Auto,
225 Required,
226 None,
227 Other(ToolDefinition),
228}
229
/// A tool the model may call; currently only function tools exist
/// (serialized as `{"type": "function", "function": {...}}`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
236
/// Declaration of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's arguments, when provided.
    pub parameters: Option<Value>,
}
243
/// A single chat message, tagged by its `role` field on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Assistant turns may carry tool calls instead of (or alongside) text.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        // The result of a tool invocation, linked back to the originating
        // call by `tool_call_id`.
        content: String,
        tool_call_id: String,
    },
}
263
/// A completed tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so the `type`/`function` fields appear at this level on the wire.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
270
/// Payload of a tool call; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
276
/// Name and raw JSON-encoded arguments of an invoked function.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments arrive as a JSON string, not parsed structure.
    pub arguments: String,
}
282
/// Incremental message content delivered by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    // Typically present only on the first chunk of a message.
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
290
/// A fragment of a tool call streamed across multiple chunks; fragments with
/// the same `index` belong to the same call and must be concatenated.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
301
/// Partial function-call data inside a [`ToolCallChunk`]; `arguments` pieces
/// concatenate across chunks to form the full JSON argument string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
307
/// Token accounting reported by the API for one request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
314
/// One streamed choice: the delta for the message at `index`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // e.g. "stop" or "tool_calls" on the final chunk; `None` while streaming.
    pub finish_reason: Option<String>,
}
321
/// A single SSE payload: either a normal event or an error envelope.
/// Untagged, so deserialization tries the event shape first, then the error.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
328
/// One chunk of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    // Present on the final chunk when usage reporting is enabled.
    pub usage: Option<Usage>,
}
336
/// Response body of the plain-text `/completions` endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
346
/// One generated text alternative from the `/completions` endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
351
/// Full (non-streaming) response body of `/chat/completions`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
361
/// One complete message alternative in a non-streaming chat response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
368
369pub async fn complete(
370 client: &dyn HttpClient,
371 api_url: &str,
372 api_key: &str,
373 request: Request,
374) -> Result<Response> {
375 let uri = format!("{api_url}/chat/completions");
376 let request_builder = HttpRequest::builder()
377 .method(Method::POST)
378 .uri(uri)
379 .header("Content-Type", "application/json")
380 .header("Authorization", format!("Bearer {}", api_key));
381
382 let mut request_body = request;
383 request_body.stream = false;
384
385 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
386 let mut response = client.send(request).await?;
387
388 if response.status().is_success() {
389 let mut body = String::new();
390 response.body_mut().read_to_string(&mut body).await?;
391 let response: Response = serde_json::from_str(&body)?;
392 Ok(response)
393 } else {
394 let mut body = String::new();
395 response.body_mut().read_to_string(&mut body).await?;
396
397 #[derive(Deserialize)]
398 struct OpenAiResponse {
399 error: OpenAiError,
400 }
401
402 #[derive(Deserialize)]
403 struct OpenAiError {
404 message: String,
405 }
406
407 match serde_json::from_str::<OpenAiResponse>(&body) {
408 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
409 "Failed to connect to OpenAI API: {}",
410 response.error.message,
411 )),
412
413 _ => Err(anyhow!(
414 "Failed to connect to OpenAI API: {} {}",
415 response.status(),
416 body,
417 )),
418 }
419 }
420}
421
422pub async fn complete_text(
423 client: &dyn HttpClient,
424 api_url: &str,
425 api_key: &str,
426 request: CompletionRequest,
427) -> Result<CompletionResponse> {
428 let uri = format!("{api_url}/completions");
429 let request_builder = HttpRequest::builder()
430 .method(Method::POST)
431 .uri(uri)
432 .header("Content-Type", "application/json")
433 .header("Authorization", format!("Bearer {}", api_key));
434
435 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
436 let mut response = client.send(request).await?;
437
438 if response.status().is_success() {
439 let mut body = String::new();
440 response.body_mut().read_to_string(&mut body).await?;
441 let response = serde_json::from_str(&body)?;
442 Ok(response)
443 } else {
444 let mut body = String::new();
445 response.body_mut().read_to_string(&mut body).await?;
446
447 #[derive(Deserialize)]
448 struct OpenAiResponse {
449 error: OpenAiError,
450 }
451
452 #[derive(Deserialize)]
453 struct OpenAiError {
454 message: String,
455 }
456
457 match serde_json::from_str::<OpenAiResponse>(&body) {
458 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
459 "Failed to connect to OpenAI API: {}",
460 response.error.message,
461 )),
462
463 _ => Err(anyhow!(
464 "Failed to connect to OpenAI API: {} {}",
465 response.status(),
466 body,
467 )),
468 }
469 }
470}
471
472fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
473 ResponseStreamEvent {
474 created: response.created as u32,
475 model: response.model,
476 choices: response
477 .choices
478 .into_iter()
479 .map(|choice| ChoiceDelta {
480 index: choice.index,
481 delta: ResponseMessageDelta {
482 role: Some(match choice.message {
483 RequestMessage::Assistant { .. } => Role::Assistant,
484 RequestMessage::User { .. } => Role::User,
485 RequestMessage::System { .. } => Role::System,
486 RequestMessage::Tool { .. } => Role::Tool,
487 }),
488 content: match choice.message {
489 RequestMessage::Assistant { content, .. } => content,
490 RequestMessage::User { content } => Some(content),
491 RequestMessage::System { content } => Some(content),
492 RequestMessage::Tool { content, .. } => Some(content),
493 },
494 tool_calls: None,
495 },
496 finish_reason: choice.finish_reason,
497 })
498 .collect(),
499 usage: Some(response.usage),
500 }
501}
502
/// Sends a chat completion request and returns a stream of response events.
///
/// Models whose id starts with "o1" do not get a streamed request; instead a
/// blocking [`complete`] call is made and the single response is adapted into
/// a one-item stream.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model.starts_with("o1") {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the body as a server-sent-events stream: each event is a
        // line of the form `data: <json>`, terminated by `data: [DONE]`.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Lines without the SSE `data: ` prefix (e.g. blank
                        // keep-alive lines) are skipped entirely via `?`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-2xx: read the whole body and surface the API's structured
        // error message when it parses as the standard OpenAI error envelope.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
576
/// Embedding models supported by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
584
/// Request body for the `/embeddings` endpoint; borrows the input strings.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
590
/// Response body of the `/embeddings` endpoint: one embedding per input.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
595
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
600
601pub fn embed<'a>(
602 client: &dyn HttpClient,
603 api_url: &str,
604 api_key: &str,
605 model: OpenAiEmbeddingModel,
606 texts: impl IntoIterator<Item = &'a str>,
607) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
608 let uri = format!("{api_url}/embeddings");
609
610 let request = OpenAiEmbeddingRequest {
611 model,
612 input: texts.into_iter().collect(),
613 };
614 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
615 let request = HttpRequest::builder()
616 .method(Method::POST)
617 .uri(uri)
618 .header("Content-Type", "application/json")
619 .header("Authorization", format!("Bearer {}", api_key))
620 .body(body)
621 .map(|request| client.send(request));
622
623 async move {
624 let mut response = request?.await?;
625 let mut body = String::new();
626 response.body_mut().read_to_string(&mut body).await?;
627
628 if response.status().is_success() {
629 let response: OpenAiEmbeddingResponse =
630 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
631 Ok(response)
632 } else {
633 Err(anyhow!(
634 "error during embedding, status: {:?}, body: {:?}",
635 response.status(),
636 body
637 ))
638 }
639 }
640}