open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{
  5    AsyncBufReadExt, AsyncReadExt, StreamExt,
  6    io::BufReader,
  7    stream::{self, BoxStream},
  8};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12use std::{
 13    convert::TryFrom,
 14    future::{self, Future},
 15};
 16use strum::EnumIter;
 17
 18pub use supported_countries::*;
 19
 20pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 21
 22fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 23    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 24}
 25
/// The role of a chat message, serialized in lowercase ("user", "assistant",
/// "system", "tool") to match the OpenAI chat API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 34
 35impl TryFrom<String> for Role {
 36    type Error = anyhow::Error;
 37
 38    fn try_from(value: String) -> Result<Self> {
 39        match value.as_str() {
 40            "user" => Ok(Self::User),
 41            "assistant" => Ok(Self::Assistant),
 42            "system" => Ok(Self::System),
 43            "tool" => Ok(Self::Tool),
 44            _ => Err(anyhow!("invalid role '{value}'")),
 45        }
 46    }
 47}
 48
 49impl From<Role> for String {
 50    fn from(val: Role) -> Self {
 51        match val {
 52            Role::User => "user".to_owned(),
 53            Role::Assistant => "assistant".to_owned(),
 54            Role::System => "system".to_owned(),
 55            Role::Tool => "tool".to_owned(),
 56        }
 57    }
 58}
 59
/// Known OpenAI chat models, plus a `Custom` escape hatch for arbitrary ids.
///
/// Serialized by model id (e.g. "gpt-4o"); defaults to `FourOmni`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1", alias = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini", alias = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano", alias = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,

    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        /// The id sent to the API as the model name.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context window size in tokens.
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}
 99
100impl Model {
101    pub fn from_id(id: &str) -> Result<Self> {
102        match id {
103            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
104            "gpt-4" => Ok(Self::Four),
105            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
106            "gpt-4o" => Ok(Self::FourOmni),
107            "gpt-4o-mini" => Ok(Self::FourOmniMini),
108            "gpt-4.1" => Ok(Self::FourPointOne),
109            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
110            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
111            "o1" => Ok(Self::O1),
112            "o1-preview" => Ok(Self::O1Preview),
113            "o1-mini" => Ok(Self::O1Mini),
114            "o3-mini" => Ok(Self::O3Mini),
115            _ => Err(anyhow!("invalid model id")),
116        }
117    }
118
119    pub fn id(&self) -> &str {
120        match self {
121            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
122            Self::Four => "gpt-4",
123            Self::FourTurbo => "gpt-4-turbo",
124            Self::FourOmni => "gpt-4o",
125            Self::FourOmniMini => "gpt-4o-mini",
126            Self::FourPointOne => "gpt-4.1",
127            Self::FourPointOneMini => "gpt-4.1-mini",
128            Self::FourPointOneNano => "gpt-4.1-nano",
129            Self::O1 => "o1",
130            Self::O1Preview => "o1-preview",
131            Self::O1Mini => "o1-mini",
132            Self::O3Mini => "o3-mini",
133            Self::Custom { name, .. } => name,
134        }
135    }
136
137    pub fn display_name(&self) -> &str {
138        match self {
139            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
140            Self::Four => "gpt-4",
141            Self::FourTurbo => "gpt-4-turbo",
142            Self::FourOmni => "gpt-4o",
143            Self::FourOmniMini => "gpt-4o-mini",
144            Self::FourPointOne => "gpt-4.1",
145            Self::FourPointOneMini => "gpt-4.1-mini",
146            Self::FourPointOneNano => "gpt-4.1-nano",
147            Self::O1 => "o1",
148            Self::O1Preview => "o1-preview",
149            Self::O1Mini => "o1-mini",
150            Self::O3Mini => "o3-mini",
151            Self::Custom {
152                name, display_name, ..
153            } => display_name.as_ref().unwrap_or(name),
154        }
155    }
156
157    pub fn max_token_count(&self) -> usize {
158        match self {
159            Self::ThreePointFiveTurbo => 16_385,
160            Self::Four => 8_192,
161            Self::FourTurbo => 128_000,
162            Self::FourOmni => 128_000,
163            Self::FourOmniMini => 128_000,
164            Self::FourPointOne => 1_047_576,
165            Self::FourPointOneMini => 1_047_576,
166            Self::FourPointOneNano => 1_047_576,
167            Self::O1 => 200_000,
168            Self::O1Preview => 128_000,
169            Self::O1Mini => 128_000,
170            Self::O3Mini => 200_000,
171            Self::Custom { max_tokens, .. } => *max_tokens,
172        }
173    }
174
175    pub fn max_output_tokens(&self) -> Option<u32> {
176        match self {
177            Self::Custom {
178                max_output_tokens, ..
179            } => *max_output_tokens,
180            _ => None,
181        }
182    }
183
184    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
185    ///
186    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
187    pub fn supports_parallel_tool_calls(&self) -> bool {
188        match self {
189            Self::ThreePointFiveTurbo
190            | Self::Four
191            | Self::FourTurbo
192            | Self::FourOmni
193            | Self::FourOmniMini
194            | Self::FourPointOne
195            | Self::FourPointOneMini
196            | Self::FourPointOneNano
197            | Self::O1
198            | Self::O1Preview
199            | Self::O1Mini => true,
200            _ => false,
201        }
202    }
203}
204
/// Request body for the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to run, e.g. "gpt-4o" (see `Model::id`).
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<RequestMessage>,
    /// Whether the response should be streamed as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
223
/// Request body for the legacy `/completions` (plain text) endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}

/// Predicted output content, serialized as
/// `{"type": "content", "content": "..."}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
241
/// How the model should choose between responding and calling tools.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants `Auto`,
/// `Required`, and `None` serialize as JSON `null`, not as the strings
/// "auto"/"required"/"none" — confirm callers never serialize these
/// variants, or add explicit serialization for them.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}
250
/// A tool exposed to the model; currently only function tools exist.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Description of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema for the function's arguments, if any.
    pub parameters: Option<Value>,
}
264
/// A single chat message, tagged by its `role` field on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        /// Tool invocations requested by the assistant; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        /// The tool's output, paired with the call it answers.
        content: String,
        tool_call_id: String,
    },
}
284
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened into the same JSON object as `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The payload of a tool call; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A function invocation: name plus its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Raw JSON string of arguments, as produced by the model.
    pub arguments: String,
}
303
/// Incremental message content carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted when absent or empty (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A partial tool call as it streams in; fields arrive across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Identifies which tool call this chunk belongs to.
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// Partial function-call data within a `ToolCallChunk`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
328
/// Token accounting reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice's delta within a streaming event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// A streamed payload: either an event or an API error object.
///
/// Untagged, so serde tries `ResponseStreamEvent` first, then the error shape.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

/// One server-sent event from the streaming chat completions endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    // presumably populated only on the final chunk — confirm against the API
    pub usage: Option<Usage>,
}
357
/// Response body from the legacy `/completions` endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

/// One completed text alternative.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}

/// Non-streaming response body from `/chat/completions`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One completed chat alternative.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
389
390pub async fn complete(
391    client: &dyn HttpClient,
392    api_url: &str,
393    api_key: &str,
394    request: Request,
395) -> Result<Response> {
396    let uri = format!("{api_url}/chat/completions");
397    let request_builder = HttpRequest::builder()
398        .method(Method::POST)
399        .uri(uri)
400        .header("Content-Type", "application/json")
401        .header("Authorization", format!("Bearer {}", api_key));
402
403    let mut request_body = request;
404    request_body.stream = false;
405
406    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
407    let mut response = client.send(request).await?;
408
409    if response.status().is_success() {
410        let mut body = String::new();
411        response.body_mut().read_to_string(&mut body).await?;
412        let response: Response = serde_json::from_str(&body)?;
413        Ok(response)
414    } else {
415        let mut body = String::new();
416        response.body_mut().read_to_string(&mut body).await?;
417
418        #[derive(Deserialize)]
419        struct OpenAiResponse {
420            error: OpenAiError,
421        }
422
423        #[derive(Deserialize)]
424        struct OpenAiError {
425            message: String,
426        }
427
428        match serde_json::from_str::<OpenAiResponse>(&body) {
429            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
430                "Failed to connect to OpenAI API: {}",
431                response.error.message,
432            )),
433
434            _ => Err(anyhow!(
435                "Failed to connect to OpenAI API: {} {}",
436                response.status(),
437                body,
438            )),
439        }
440    }
441}
442
443pub async fn complete_text(
444    client: &dyn HttpClient,
445    api_url: &str,
446    api_key: &str,
447    request: CompletionRequest,
448) -> Result<CompletionResponse> {
449    let uri = format!("{api_url}/completions");
450    let request_builder = HttpRequest::builder()
451        .method(Method::POST)
452        .uri(uri)
453        .header("Content-Type", "application/json")
454        .header("Authorization", format!("Bearer {}", api_key));
455
456    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
457    let mut response = client.send(request).await?;
458
459    if response.status().is_success() {
460        let mut body = String::new();
461        response.body_mut().read_to_string(&mut body).await?;
462        let response = serde_json::from_str(&body)?;
463        Ok(response)
464    } else {
465        let mut body = String::new();
466        response.body_mut().read_to_string(&mut body).await?;
467
468        #[derive(Deserialize)]
469        struct OpenAiResponse {
470            error: OpenAiError,
471        }
472
473        #[derive(Deserialize)]
474        struct OpenAiError {
475            message: String,
476        }
477
478        match serde_json::from_str::<OpenAiResponse>(&body) {
479            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
480                "Failed to connect to OpenAI API: {}",
481                response.error.message,
482            )),
483
484            _ => Err(anyhow!(
485                "Failed to connect to OpenAI API: {} {}",
486                response.status(),
487                body,
488            )),
489        }
490    }
491}
492
493fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
494    ResponseStreamEvent {
495        created: response.created as u32,
496        model: response.model,
497        choices: response
498            .choices
499            .into_iter()
500            .map(|choice| ChoiceDelta {
501                index: choice.index,
502                delta: ResponseMessageDelta {
503                    role: Some(match choice.message {
504                        RequestMessage::Assistant { .. } => Role::Assistant,
505                        RequestMessage::User { .. } => Role::User,
506                        RequestMessage::System { .. } => Role::System,
507                        RequestMessage::Tool { .. } => Role::Tool,
508                    }),
509                    content: match choice.message {
510                        RequestMessage::Assistant { content, .. } => content,
511                        RequestMessage::User { content } => Some(content),
512                        RequestMessage::System { content } => Some(content),
513                        RequestMessage::Tool { content, .. } => Some(content),
514                    },
515                    tool_calls: None,
516                },
517                finish_reason: choice.finish_reason,
518            })
519            .collect(),
520        usage: Some(response.usage),
521    }
522}
523
524pub async fn stream_completion(
525    client: &dyn HttpClient,
526    api_url: &str,
527    api_key: &str,
528    request: Request,
529) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
530    if request.model.starts_with("o1") {
531        let response = complete(client, api_url, api_key, request).await;
532        let response_stream_event = response.map(adapt_response_to_stream);
533        return Ok(stream::once(future::ready(response_stream_event)).boxed());
534    }
535
536    let uri = format!("{api_url}/chat/completions");
537    let request_builder = HttpRequest::builder()
538        .method(Method::POST)
539        .uri(uri)
540        .header("Content-Type", "application/json")
541        .header("Authorization", format!("Bearer {}", api_key));
542
543    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
544    let mut response = client.send(request).await?;
545    if response.status().is_success() {
546        let reader = BufReader::new(response.into_body());
547        Ok(reader
548            .lines()
549            .filter_map(|line| async move {
550                match line {
551                    Ok(line) => {
552                        let line = line.strip_prefix("data: ")?;
553                        if line == "[DONE]" {
554                            None
555                        } else {
556                            match serde_json::from_str(line) {
557                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
558                                Ok(ResponseStreamResult::Err { error }) => {
559                                    Some(Err(anyhow!(error)))
560                                }
561                                Err(error) => Some(Err(anyhow!(error))),
562                            }
563                        }
564                    }
565                    Err(error) => Some(Err(anyhow!(error))),
566                }
567            })
568            .boxed())
569    } else {
570        let mut body = String::new();
571        response.body_mut().read_to_string(&mut body).await?;
572
573        #[derive(Deserialize)]
574        struct OpenAiResponse {
575            error: OpenAiError,
576        }
577
578        #[derive(Deserialize)]
579        struct OpenAiError {
580            message: String,
581        }
582
583        match serde_json::from_str::<OpenAiResponse>(&body) {
584            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
585                "Failed to connect to OpenAI API: {}",
586                response.error.message,
587            )),
588
589            _ => Err(anyhow!(
590                "Failed to connect to OpenAI API: {} {}",
591                response.status(),
592                body,
593            )),
594        }
595    }
596}
597
/// Embedding models supported by this client.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// Request body for the `/embeddings` endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

/// Response body from the `/embeddings` endpoint.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
621
622pub fn embed<'a>(
623    client: &dyn HttpClient,
624    api_url: &str,
625    api_key: &str,
626    model: OpenAiEmbeddingModel,
627    texts: impl IntoIterator<Item = &'a str>,
628) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
629    let uri = format!("{api_url}/embeddings");
630
631    let request = OpenAiEmbeddingRequest {
632        model,
633        input: texts.into_iter().collect(),
634    };
635    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
636    let request = HttpRequest::builder()
637        .method(Method::POST)
638        .uri(uri)
639        .header("Content-Type", "application/json")
640        .header("Authorization", format!("Bearer {}", api_key))
641        .body(body)
642        .map(|request| client.send(request));
643
644    async move {
645        let mut response = request?.await?;
646        let mut body = String::new();
647        response.body_mut().read_to_string(&mut body).await?;
648
649        if response.status().is_success() {
650            let response: OpenAiEmbeddingResponse =
651                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
652            Ok(response)
653        } else {
654            Err(anyhow!(
655                "error during embedding, status: {:?}, body: {:?}",
656                response.status(),
657                body
658            ))
659        }
660    }
661}