// open_ai.rs — OpenAI chat-completions and embeddings client.

  1pub mod batches;
  2pub mod responses;
  3
  4use anyhow::{Context as _, Result, anyhow};
  5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  6use http_client::{
  7    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  8    http::{HeaderMap, HeaderValue},
  9};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 13use std::{convert::TryFrom, future::Future};
 14use strum::EnumIter;
 15use thiserror::Error;
 16
/// Base URL of the hosted OpenAI API; endpoint paths are appended to this.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 18
 19fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 20    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 21}
 22
/// Author role of a chat message, serialized as the lowercase strings the
/// Chat Completions API expects ("user", "assistant", "system", "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 31
 32impl TryFrom<String> for Role {
 33    type Error = anyhow::Error;
 34
 35    fn try_from(value: String) -> Result<Self> {
 36        match value.as_str() {
 37            "user" => Ok(Self::User),
 38            "assistant" => Ok(Self::Assistant),
 39            "system" => Ok(Self::System),
 40            "tool" => Ok(Self::Tool),
 41            _ => anyhow::bail!("invalid role '{value}'"),
 42        }
 43    }
 44}
 45
 46impl From<Role> for String {
 47    fn from(val: Role) -> Self {
 48        match val {
 49            Role::User => "user".to_owned(),
 50            Role::Assistant => "assistant".to_owned(),
 51            Role::System => "system".to_owned(),
 52            Role::Tool => "tool".to_owned(),
 53        }
 54    }
 55}
 56
/// Known OpenAI models, plus a `Custom` escape hatch for user-configured
/// model ids. Serialized names match the API's model identifiers.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    #[default]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    #[serde(rename = "gpt-5.3-codex")]
    FivePointThreeCodex,
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API (returned by `Model::id`).
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context window size in tokens (returned by `Model::max_token_count`).
        max_tokens: u64,
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether the Chat Completions endpoint may be used; defaults to true.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
106
/// Serde default for `Model::Custom::supports_chat_completions`: custom
/// models are assumed to support the Chat Completions endpoint.
const fn default_supports_chat_completions() -> bool {
    true
}
110
111impl Model {
112    pub fn default_fast() -> Self {
113        Self::FiveMini
114    }
115
116    pub fn from_id(id: &str) -> Result<Self> {
117        match id {
118            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
119            "gpt-4" => Ok(Self::Four),
120            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
121            "gpt-4o-mini" => Ok(Self::FourOmniMini),
122            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
123            "o1" => Ok(Self::O1),
124            "o3-mini" => Ok(Self::O3Mini),
125            "o3" => Ok(Self::O3),
126            "gpt-5" => Ok(Self::Five),
127            "gpt-5-codex" => Ok(Self::FiveCodex),
128            "gpt-5-mini" => Ok(Self::FiveMini),
129            "gpt-5-nano" => Ok(Self::FiveNano),
130            "gpt-5.1" => Ok(Self::FivePointOne),
131            "gpt-5.2" => Ok(Self::FivePointTwo),
132            "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
133            "gpt-5.3-codex" => Ok(Self::FivePointThreeCodex),
134            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
135        }
136    }
137
138    pub fn id(&self) -> &str {
139        match self {
140            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
141            Self::Four => "gpt-4",
142            Self::FourTurbo => "gpt-4-turbo",
143            Self::FourOmniMini => "gpt-4o-mini",
144            Self::FourPointOneNano => "gpt-4.1-nano",
145            Self::O1 => "o1",
146            Self::O3Mini => "o3-mini",
147            Self::O3 => "o3",
148            Self::Five => "gpt-5",
149            Self::FiveCodex => "gpt-5-codex",
150            Self::FiveMini => "gpt-5-mini",
151            Self::FiveNano => "gpt-5-nano",
152            Self::FivePointOne => "gpt-5.1",
153            Self::FivePointTwo => "gpt-5.2",
154            Self::FivePointTwoCodex => "gpt-5.2-codex",
155            Self::FivePointThreeCodex => "gpt-5.3-codex",
156            Self::Custom { name, .. } => name,
157        }
158    }
159
160    pub fn display_name(&self) -> &str {
161        match self {
162            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
163            Self::Four => "gpt-4",
164            Self::FourTurbo => "gpt-4-turbo",
165            Self::FourOmniMini => "gpt-4o-mini",
166            Self::FourPointOneNano => "gpt-4.1-nano",
167            Self::O1 => "o1",
168            Self::O3Mini => "o3-mini",
169            Self::O3 => "o3",
170            Self::Five => "gpt-5",
171            Self::FiveCodex => "gpt-5-codex",
172            Self::FiveMini => "gpt-5-mini",
173            Self::FiveNano => "gpt-5-nano",
174            Self::FivePointOne => "gpt-5.1",
175            Self::FivePointTwo => "gpt-5.2",
176            Self::FivePointTwoCodex => "gpt-5.2-codex",
177            Self::FivePointThreeCodex => "gpt-5.3-codex",
178            Self::Custom { display_name, .. } => display_name.as_deref().unwrap_or(&self.id()),
179        }
180    }
181
182    pub fn max_token_count(&self) -> u64 {
183        match self {
184            Self::ThreePointFiveTurbo => 16_385,
185            Self::Four => 8_192,
186            Self::FourTurbo => 128_000,
187            Self::FourOmniMini => 128_000,
188            Self::FourPointOneNano => 1_047_576,
189            Self::O1 => 200_000,
190            Self::O3Mini => 200_000,
191            Self::O3 => 200_000,
192            Self::Five => 272_000,
193            Self::FiveCodex => 272_000,
194            Self::FiveMini => 272_000,
195            Self::FiveNano => 272_000,
196            Self::FivePointOne => 400_000,
197            Self::FivePointTwo => 400_000,
198            Self::FivePointTwoCodex => 400_000,
199            Self::FivePointThreeCodex => 400_000,
200            Self::Custom { max_tokens, .. } => *max_tokens,
201        }
202    }
203
204    pub fn max_output_tokens(&self) -> Option<u64> {
205        match self {
206            Self::Custom {
207                max_output_tokens, ..
208            } => *max_output_tokens,
209            Self::ThreePointFiveTurbo => Some(4_096),
210            Self::Four => Some(8_192),
211            Self::FourTurbo => Some(4_096),
212            Self::FourOmniMini => Some(16_384),
213            Self::FourPointOneNano => Some(32_768),
214            Self::O1 => Some(100_000),
215            Self::O3Mini => Some(100_000),
216            Self::O3 => Some(100_000),
217            Self::Five => Some(128_000),
218            Self::FiveCodex => Some(128_000),
219            Self::FiveMini => Some(128_000),
220            Self::FiveNano => Some(128_000),
221            Self::FivePointOne => Some(128_000),
222            Self::FivePointTwo => Some(128_000),
223            Self::FivePointTwoCodex => Some(128_000),
224            Self::FivePointThreeCodex => Some(128_000),
225        }
226    }
227
228    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
229        match self {
230            Self::Custom {
231                reasoning_effort, ..
232            } => reasoning_effort.to_owned(),
233            Self::FivePointThreeCodex => Some(ReasoningEffort::Medium),
234            _ => None,
235        }
236    }
237
238    pub fn supports_chat_completions(&self) -> bool {
239        match self {
240            Self::Custom {
241                supports_chat_completions,
242                ..
243            } => *supports_chat_completions,
244            Self::FiveCodex | Self::FivePointTwoCodex | Self::FivePointThreeCodex => false,
245            _ => true,
246        }
247    }
248
249    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
250    ///
251    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
252    pub fn supports_parallel_tool_calls(&self) -> bool {
253        match self {
254            Self::ThreePointFiveTurbo
255            | Self::Four
256            | Self::FourTurbo
257            | Self::FourOmniMini
258            | Self::FourPointOneNano
259            | Self::Five
260            | Self::FiveCodex
261            | Self::FiveMini
262            | Self::FivePointOne
263            | Self::FivePointTwo
264            | Self::FivePointTwoCodex
265            | Self::FivePointThreeCodex
266            | Self::FiveNano => true,
267            Self::O1 | Self::O3 | Self::O3Mini | Model::Custom { .. } => false,
268        }
269    }
270
271    /// Returns whether the given model supports the `prompt_cache_key` parameter.
272    ///
273    /// If the model does not support the parameter, do not pass it up.
274    pub fn supports_prompt_cache_key(&self) -> bool {
275        true
276    }
277}
278
/// Request body for `POST {api_url}/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier, e.g. `"gpt-4o-mini"`.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// When true the server responds with server-sent events instead of a
    /// single JSON body.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted from the wire format when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
302
/// The `tool_choice` request field: one of the lowercase keywords, or a
/// specific tool definition (serialized untagged) the model must call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
312
/// A tool made available to the model; currently only function tools,
/// serialized as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
319
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema for the function's arguments, as free-form JSON.
    pub parameters: Option<Value>,
}
326
/// A chat message, discriminated on the wire by its `role` field.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<MessageContent>,
        /// Tool invocations requested by the assistant; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        /// The tool's output, paired with the id of the call it answers.
        content: MessageContent,
        tool_call_id: String,
    },
}
346
/// Message content: either a bare string or a list of typed parts
/// (text/image). Untagged so both wire shapes deserialize.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
353
354impl MessageContent {
355    pub fn empty() -> Self {
356        MessageContent::Multipart(vec![])
357    }
358
359    pub fn push_part(&mut self, part: MessagePart) {
360        match self {
361            MessageContent::Plain(text) => {
362                *self =
363                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
364            }
365            MessageContent::Multipart(parts) if parts.is_empty() => match part {
366                MessagePart::Text { text } => *self = MessageContent::Plain(text),
367                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
368            },
369            MessageContent::Multipart(parts) => parts.push(part),
370        }
371    }
372}
373
374impl From<Vec<MessagePart>> for MessageContent {
375    fn from(mut parts: Vec<MessagePart>) -> Self {
376        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
377            MessageContent::Plain(std::mem::take(text))
378        } else {
379            MessageContent::Multipart(parts)
380        }
381    }
382}
383
/// One element of a multipart message, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
392
/// Image reference in a message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail level hint; omitted from the wire format when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
399
/// A completed tool invocation attached to an assistant message. The
/// flattened `content` contributes the `type`/`function` fields inline.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
406
/// Payload of a tool call, tagged by `type`; currently only function calls.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
412
/// A function call: the target name and its arguments as a raw JSON string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// JSON-encoded arguments, kept as a string as the API delivers them.
    pub arguments: String,
}
418
/// Full (non-streaming) chat completion response body.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Unix timestamp of creation, per the API's `created` field.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
428
/// One generated completion within a non-streaming response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    // Reuses the request-side message type; assistant replies deserialize
    // into its `Assistant` variant.
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
435
/// Incremental message fragment inside a streaming chunk; every field may be
/// absent because deltas carry only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    // NOTE(review): presumably an extension emitted by OpenAI-compatible
    // providers rather than the official API — verify per provider.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub reasoning_content: Option<String>,
}
445
/// Streaming fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
456
/// Streaming fragment of a function call; `arguments` arrives as partial
/// JSON text that the consumer concatenates across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
462
/// Token accounting reported by the API for a completion.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
469
/// One choice within a streaming event, carrying an optional message delta.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
476
/// Errors produced by the completion functions in this module.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The API answered with a non-success HTTP status; `body` holds the raw
    /// error payload and `headers` the response headers.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure: serialization, transport, or body-read errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
489
/// Error envelope that can appear in place of a streaming event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
494
/// A streamed payload: untagged, so deserialization first tries a normal
/// event and falls back to the `{"error": ...}` envelope.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
501
/// One server-sent event of a streaming completion; `usage` is typically
/// populated only on the final event, when present at all.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
507
508pub async fn non_streaming_completion(
509    client: &dyn HttpClient,
510    api_url: &str,
511    api_key: &str,
512    request: Request,
513) -> Result<Response, RequestError> {
514    let uri = format!("{api_url}/chat/completions");
515    let request_builder = HttpRequest::builder()
516        .method(Method::POST)
517        .uri(uri)
518        .header("Content-Type", "application/json")
519        .header("Authorization", format!("Bearer {}", api_key.trim()));
520
521    let request = request_builder
522        .body(AsyncBody::from(
523            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
524        ))
525        .map_err(|e| RequestError::Other(e.into()))?;
526
527    let mut response = client.send(request).await?;
528    if response.status().is_success() {
529        let mut body = String::new();
530        response
531            .body_mut()
532            .read_to_string(&mut body)
533            .await
534            .map_err(|e| RequestError::Other(e.into()))?;
535
536        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
537    } else {
538        let mut body = String::new();
539        response
540            .body_mut()
541            .read_to_string(&mut body)
542            .await
543            .map_err(|e| RequestError::Other(e.into()))?;
544
545        Err(RequestError::HttpResponseError {
546            provider: "openai".to_owned(),
547            status_code: response.status(),
548            body,
549            headers: response.headers().clone(),
550        })
551    }
552}
553
/// Sends a streaming chat completion request to `{api_url}/chat/completions`
/// and returns the parsed server-sent events as a stream.
///
/// # Errors
/// Returns `RequestError::HttpResponseError` for non-2xx statuses (with the
/// raw error body), or `RequestError::Other` for serialization/transport
/// failures. Per-event parse failures are yielded as `Err` items inside the
/// stream rather than failing the whole call.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines look like "data: {...}"; some servers
                        // omit the space after the colon. Lines without the
                        // prefix yield `None` here and are skipped.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel — skip it, don't emit it.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the unparseable payload, then surface
                                    // the parse error as a stream item.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
623
/// Embedding models accepted by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
631
/// Request body for `POST {api_url}/embeddings`; `input` borrows the caller's
/// texts, so this is serialize-only.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
637
/// Response body of the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
642
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
647
648pub fn embed<'a>(
649    client: &dyn HttpClient,
650    api_url: &str,
651    api_key: &str,
652    model: OpenAiEmbeddingModel,
653    texts: impl IntoIterator<Item = &'a str>,
654) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
655    let uri = format!("{api_url}/embeddings");
656
657    let request = OpenAiEmbeddingRequest {
658        model,
659        input: texts.into_iter().collect(),
660    };
661    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
662    let request = HttpRequest::builder()
663        .method(Method::POST)
664        .uri(uri)
665        .header("Content-Type", "application/json")
666        .header("Authorization", format!("Bearer {}", api_key.trim()))
667        .body(body)
668        .map(|request| client.send(request));
669
670    async move {
671        let mut response = request?.await?;
672        let mut body = String::new();
673        response.body_mut().read_to_string(&mut body).await?;
674
675        anyhow::ensure!(
676            response.status().is_success(),
677            "error during embedding, status: {:?}, body: {:?}",
678            response.status(),
679            body
680        );
681        let response: OpenAiEmbeddingResponse =
682            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
683        Ok(response)
684    }
685}