// open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{
  5    AsyncBufReadExt, AsyncReadExt, StreamExt,
  6    io::BufReader,
  7    stream::{self, BoxStream},
  8};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12use std::{
 13    convert::TryFrom,
 14    future::{self, Future},
 15};
 16use strum::EnumIter;
 17
 18pub use supported_countries::*;
 19
 20pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 21
 22fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 23    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 24}
 25
/// A chat message role, serialized in the lowercase wire format
/// ("user", "assistant", "system", "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 34
 35impl TryFrom<String> for Role {
 36    type Error = anyhow::Error;
 37
 38    fn try_from(value: String) -> Result<Self> {
 39        match value.as_str() {
 40            "user" => Ok(Self::User),
 41            "assistant" => Ok(Self::Assistant),
 42            "system" => Ok(Self::System),
 43            "tool" => Ok(Self::Tool),
 44            _ => Err(anyhow!("invalid role '{value}'")),
 45        }
 46    }
 47}
 48
 49impl From<Role> for String {
 50    fn from(val: Role) -> Self {
 51        match val {
 52            Role::User => "user".to_owned(),
 53            Role::Assistant => "assistant".to_owned(),
 54            Role::System => "system".to_owned(),
 55            Role::Tool => "tool".to_owned(),
 56        }
 57    }
 58}
 59
/// An OpenAI model: either one of the known hosted models, or a
/// user-configured `Custom` model with explicit token limits.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1", alias = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini", alias = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano", alias = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3", alias = "o3")]
    O3,
    #[serde(rename = "o4-mini", alias = "o4-mini")]
    O4Mini,

    /// A model not in the list above, identified by `name`.
    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Maximum context length in tokens; surfaced via `max_token_count`.
        max_tokens: usize,
        // Optional output-token cap; surfaced via `max_output_tokens`.
        max_output_tokens: Option<u32>,
        // NOTE(review): not read anywhere in this file — presumably forwarded
        // as the API's `max_completion_tokens` parameter; confirm with callers.
        max_completion_tokens: Option<u32>,
    },
}
103
104impl Model {
105    pub fn default_fast() -> Self {
106        Self::FourPointOneMini
107    }
108
109    pub fn from_id(id: &str) -> Result<Self> {
110        match id {
111            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
112            "gpt-4" => Ok(Self::Four),
113            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
114            "gpt-4o" => Ok(Self::FourOmni),
115            "gpt-4o-mini" => Ok(Self::FourOmniMini),
116            "gpt-4.1" => Ok(Self::FourPointOne),
117            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
118            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
119            "o1" => Ok(Self::O1),
120            "o1-preview" => Ok(Self::O1Preview),
121            "o1-mini" => Ok(Self::O1Mini),
122            "o3-mini" => Ok(Self::O3Mini),
123            "o3" => Ok(Self::O3),
124            "o4-mini" => Ok(Self::O4Mini),
125            _ => Err(anyhow!("invalid model id")),
126        }
127    }
128
129    pub fn id(&self) -> &str {
130        match self {
131            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
132            Self::Four => "gpt-4",
133            Self::FourTurbo => "gpt-4-turbo",
134            Self::FourOmni => "gpt-4o",
135            Self::FourOmniMini => "gpt-4o-mini",
136            Self::FourPointOne => "gpt-4.1",
137            Self::FourPointOneMini => "gpt-4.1-mini",
138            Self::FourPointOneNano => "gpt-4.1-nano",
139            Self::O1 => "o1",
140            Self::O1Preview => "o1-preview",
141            Self::O1Mini => "o1-mini",
142            Self::O3Mini => "o3-mini",
143            Self::O3 => "o3",
144            Self::O4Mini => "o4-mini",
145            Self::Custom { name, .. } => name,
146        }
147    }
148
149    pub fn display_name(&self) -> &str {
150        match self {
151            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
152            Self::Four => "gpt-4",
153            Self::FourTurbo => "gpt-4-turbo",
154            Self::FourOmni => "gpt-4o",
155            Self::FourOmniMini => "gpt-4o-mini",
156            Self::FourPointOne => "gpt-4.1",
157            Self::FourPointOneMini => "gpt-4.1-mini",
158            Self::FourPointOneNano => "gpt-4.1-nano",
159            Self::O1 => "o1",
160            Self::O1Preview => "o1-preview",
161            Self::O1Mini => "o1-mini",
162            Self::O3Mini => "o3-mini",
163            Self::O3 => "o3",
164            Self::O4Mini => "o4-mini",
165            Self::Custom {
166                name, display_name, ..
167            } => display_name.as_ref().unwrap_or(name),
168        }
169    }
170
171    pub fn max_token_count(&self) -> usize {
172        match self {
173            Self::ThreePointFiveTurbo => 16_385,
174            Self::Four => 8_192,
175            Self::FourTurbo => 128_000,
176            Self::FourOmni => 128_000,
177            Self::FourOmniMini => 128_000,
178            Self::FourPointOne => 1_047_576,
179            Self::FourPointOneMini => 1_047_576,
180            Self::FourPointOneNano => 1_047_576,
181            Self::O1 => 200_000,
182            Self::O1Preview => 128_000,
183            Self::O1Mini => 128_000,
184            Self::O3Mini => 200_000,
185            Self::O3 => 200_000,
186            Self::O4Mini => 200_000,
187            Self::Custom { max_tokens, .. } => *max_tokens,
188        }
189    }
190
191    pub fn max_output_tokens(&self) -> Option<u32> {
192        match self {
193            Self::Custom {
194                max_output_tokens, ..
195            } => *max_output_tokens,
196            _ => None,
197        }
198    }
199
200    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
201    ///
202    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
203    pub fn supports_parallel_tool_calls(&self) -> bool {
204        match self {
205            Self::ThreePointFiveTurbo
206            | Self::Four
207            | Self::FourTurbo
208            | Self::FourOmni
209            | Self::FourOmniMini
210            | Self::FourPointOne
211            | Self::FourPointOneMini
212            | Self::FourPointOneNano
213            | Self::O1
214            | Self::O1Preview
215            | Self::O1Mini => true,
216            _ => false,
217        }
218    }
219}
220
/// The body of a chat completions request (`POST {api_url}/chat/completions`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether the response should be streamed as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
239
/// The body of a text completion request (`POST {api_url}/completions`),
/// used by [`complete_text`].
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    /// Predicted output content, if any, to speed up responses.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    // NOTE(review): not a documented standard OpenAI parameter — presumably
    // understood by a compatible backend; confirm with callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
251
/// Predicted output for a completion request, serialized as
/// `{"type": "content", "content": "..."}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
257
/// The `tool_choice` request parameter.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants (`Auto`,
/// `Required`, `None`) serialize as JSON `null`, not as the strings
/// "auto"/"required"/"none" — verify this matches what the API expects.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Forces a specific tool by passing its full definition.
    Other(ToolDefinition),
}
266
/// A tool made available to the model, serialized as
/// `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// The declaration of a callable function: its name, an optional description,
/// and an optional JSON schema for its parameters.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}
280
/// A chat message, internally tagged on the `role` field
/// ("assistant", "user", "system", "tool").
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant messages may have no text when they only carry tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        /// The tool's output, paired with the id of the call it answers.
        content: String,
        tool_call_id: String,
    },
}
300
/// A tool call made by the assistant, referenced later by `tool_call_id`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // `flatten` merges the content's fields ("type", "function") into the
    // same JSON object as `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The payload of a tool call, tagged on `type` (currently only "function").
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A concrete function invocation: the function name and its arguments as a
/// JSON-encoded string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
319
/// An incremental message fragment from a streamed response; every field may
/// be absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a tool call from a streamed response. `index` identifies
/// which in-progress tool call the fragment belongs to.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// A fragment of a function invocation; `arguments` accumulates across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
344
/// Token accounting reported by the API for a request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice's delta within a streamed response event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// A streamed line's payload: either a normal event or an error envelope.
/// Untagged, so deserialization tries `Ok` first, then `Err`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    // NOTE(review): the API's error field is typically an object, not a bare
    // string — confirm this variant ever matches real payloads.
    Err { error: String },
}

/// One server-sent event from a streamed chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Present only on the final event, if usage reporting is enabled.
    pub usage: Option<Usage>,
}
373
/// A non-streaming text completion response (from `POST /completions`).
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp (seconds) when the completion was created.
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

/// One text completion alternative.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}

/// A non-streaming chat completion response (from `POST /chat/completions`).
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Unix timestamp (seconds) when the completion was created.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One chat completion alternative. Reuses `RequestMessage` as the message
/// shape, since responses carry the same role-tagged structure.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
405
406pub async fn complete(
407    client: &dyn HttpClient,
408    api_url: &str,
409    api_key: &str,
410    request: Request,
411) -> Result<Response> {
412    let uri = format!("{api_url}/chat/completions");
413    let request_builder = HttpRequest::builder()
414        .method(Method::POST)
415        .uri(uri)
416        .header("Content-Type", "application/json")
417        .header("Authorization", format!("Bearer {}", api_key));
418
419    let mut request_body = request;
420    request_body.stream = false;
421
422    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
423    let mut response = client.send(request).await?;
424
425    if response.status().is_success() {
426        let mut body = String::new();
427        response.body_mut().read_to_string(&mut body).await?;
428        let response: Response = serde_json::from_str(&body)?;
429        Ok(response)
430    } else {
431        let mut body = String::new();
432        response.body_mut().read_to_string(&mut body).await?;
433
434        #[derive(Deserialize)]
435        struct OpenAiResponse {
436            error: OpenAiError,
437        }
438
439        #[derive(Deserialize)]
440        struct OpenAiError {
441            message: String,
442        }
443
444        match serde_json::from_str::<OpenAiResponse>(&body) {
445            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
446                "Failed to connect to OpenAI API: {}",
447                response.error.message,
448            )),
449
450            _ => Err(anyhow!(
451                "Failed to connect to OpenAI API: {} {}",
452                response.status(),
453                body,
454            )),
455        }
456    }
457}
458
459pub async fn complete_text(
460    client: &dyn HttpClient,
461    api_url: &str,
462    api_key: &str,
463    request: CompletionRequest,
464) -> Result<CompletionResponse> {
465    let uri = format!("{api_url}/completions");
466    let request_builder = HttpRequest::builder()
467        .method(Method::POST)
468        .uri(uri)
469        .header("Content-Type", "application/json")
470        .header("Authorization", format!("Bearer {}", api_key));
471
472    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
473    let mut response = client.send(request).await?;
474
475    if response.status().is_success() {
476        let mut body = String::new();
477        response.body_mut().read_to_string(&mut body).await?;
478        let response = serde_json::from_str(&body)?;
479        Ok(response)
480    } else {
481        let mut body = String::new();
482        response.body_mut().read_to_string(&mut body).await?;
483
484        #[derive(Deserialize)]
485        struct OpenAiResponse {
486            error: OpenAiError,
487        }
488
489        #[derive(Deserialize)]
490        struct OpenAiError {
491            message: String,
492        }
493
494        match serde_json::from_str::<OpenAiResponse>(&body) {
495            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
496                "Failed to connect to OpenAI API: {}",
497                response.error.message,
498            )),
499
500            _ => Err(anyhow!(
501                "Failed to connect to OpenAI API: {} {}",
502                response.status(),
503                body,
504            )),
505        }
506    }
507}
508
509fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
510    ResponseStreamEvent {
511        created: response.created as u32,
512        model: response.model,
513        choices: response
514            .choices
515            .into_iter()
516            .map(|choice| ChoiceDelta {
517                index: choice.index,
518                delta: ResponseMessageDelta {
519                    role: Some(match choice.message {
520                        RequestMessage::Assistant { .. } => Role::Assistant,
521                        RequestMessage::User { .. } => Role::User,
522                        RequestMessage::System { .. } => Role::System,
523                        RequestMessage::Tool { .. } => Role::Tool,
524                    }),
525                    content: match choice.message {
526                        RequestMessage::Assistant { content, .. } => content,
527                        RequestMessage::User { content } => Some(content),
528                        RequestMessage::System { content } => Some(content),
529                        RequestMessage::Tool { content, .. } => Some(content),
530                    },
531                    tool_calls: None,
532                },
533                finish_reason: choice.finish_reason,
534            })
535            .collect(),
536        usage: Some(response.usage),
537    }
538}
539
540pub async fn stream_completion(
541    client: &dyn HttpClient,
542    api_url: &str,
543    api_key: &str,
544    request: Request,
545) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
546    if request.model.starts_with("o1") {
547        let response = complete(client, api_url, api_key, request).await;
548        let response_stream_event = response.map(adapt_response_to_stream);
549        return Ok(stream::once(future::ready(response_stream_event)).boxed());
550    }
551
552    let uri = format!("{api_url}/chat/completions");
553    let request_builder = HttpRequest::builder()
554        .method(Method::POST)
555        .uri(uri)
556        .header("Content-Type", "application/json")
557        .header("Authorization", format!("Bearer {}", api_key));
558
559    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
560    let mut response = client.send(request).await?;
561    if response.status().is_success() {
562        let reader = BufReader::new(response.into_body());
563        Ok(reader
564            .lines()
565            .filter_map(|line| async move {
566                match line {
567                    Ok(line) => {
568                        let line = line.strip_prefix("data: ")?;
569                        if line == "[DONE]" {
570                            None
571                        } else {
572                            match serde_json::from_str(line) {
573                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
574                                Ok(ResponseStreamResult::Err { error }) => {
575                                    Some(Err(anyhow!(error)))
576                                }
577                                Err(error) => Some(Err(anyhow!(error))),
578                            }
579                        }
580                    }
581                    Err(error) => Some(Err(anyhow!(error))),
582                }
583            })
584            .boxed())
585    } else {
586        let mut body = String::new();
587        response.body_mut().read_to_string(&mut body).await?;
588
589        #[derive(Deserialize)]
590        struct OpenAiResponse {
591            error: OpenAiError,
592        }
593
594        #[derive(Deserialize)]
595        struct OpenAiError {
596            message: String,
597        }
598
599        match serde_json::from_str::<OpenAiResponse>(&body) {
600            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
601                "Failed to connect to OpenAI API: {}",
602                response.error.message,
603            )),
604
605            _ => Err(anyhow!(
606                "Failed to connect to OpenAI API: {} {}",
607                response.status(),
608                body,
609            )),
610        }
611    }
612}
613
/// The embedding models accepted by `POST {api_url}/embeddings`.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// The body of an embeddings request; borrows the input texts from the caller.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
627
/// The response to an embeddings request: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
637
638pub fn embed<'a>(
639    client: &dyn HttpClient,
640    api_url: &str,
641    api_key: &str,
642    model: OpenAiEmbeddingModel,
643    texts: impl IntoIterator<Item = &'a str>,
644) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
645    let uri = format!("{api_url}/embeddings");
646
647    let request = OpenAiEmbeddingRequest {
648        model,
649        input: texts.into_iter().collect(),
650    };
651    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
652    let request = HttpRequest::builder()
653        .method(Method::POST)
654        .uri(uri)
655        .header("Content-Type", "application/json")
656        .header("Authorization", format!("Bearer {}", api_key))
657        .body(body)
658        .map(|request| client.send(request));
659
660    async move {
661        let mut response = request?.await?;
662        let mut body = String::new();
663        response.body_mut().read_to_string(&mut body).await?;
664
665        if response.status().is_success() {
666            let response: OpenAiEmbeddingResponse =
667                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
668            Ok(response)
669        } else {
670            Err(anyhow!(
671                "error during embedding, status: {:?}, body: {:?}",
672                response.status(),
673                body
674            ))
675        }
676    }
677}