open_ai.rs

  1pub mod batches;
  2pub mod responses;
  3
  4use anyhow::{Context as _, Result, anyhow};
  5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  6use http_client::{
  7    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  8    http::{HeaderMap, HeaderValue},
  9};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 13use std::{convert::TryFrom, future::Future};
 14use strum::EnumIter;
 15use thiserror::Error;
 16
/// Default base URL of the OpenAI REST API.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 18
 19fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 20    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 21}
 22
/// Chat message role, serialized lowercase to match the API wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 31
 32impl TryFrom<String> for Role {
 33    type Error = anyhow::Error;
 34
 35    fn try_from(value: String) -> Result<Self> {
 36        match value.as_str() {
 37            "user" => Ok(Self::User),
 38            "assistant" => Ok(Self::Assistant),
 39            "system" => Ok(Self::System),
 40            "tool" => Ok(Self::Tool),
 41            _ => anyhow::bail!("invalid role '{value}'"),
 42        }
 43    }
 44}
 45
 46impl From<Role> for String {
 47    fn from(val: Role) -> Self {
 48        match val {
 49            Role::User => "user".to_owned(),
 50            Role::Assistant => "assistant".to_owned(),
 51            Role::System => "system".to_owned(),
 52            Role::Tool => "tool".to_owned(),
 53        }
 54    }
 55}
 56
/// The set of OpenAI models known to this crate, plus a `Custom` escape
/// hatch for user-configured models.
///
/// Variants serialize to the model-id strings the API expects (see the
/// `serde(rename)` attributes).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    #[default]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A user-configured model that carries its own metadata.
    #[serde(rename = "custom")]
    Custom {
        /// Model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context window size, in tokens.
        max_tokens: u64,
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether the model supports the `/chat/completions` endpoint;
        /// defaults to `true` when absent from settings.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
104
/// Serde default for `Model::Custom::supports_chat_completions`: assume a
/// custom model speaks the chat completions API unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
108
109impl Model {
110    pub fn default_fast() -> Self {
111        Self::FiveMini
112    }
113
114    pub fn from_id(id: &str) -> Result<Self> {
115        match id {
116            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
117            "gpt-4" => Ok(Self::Four),
118            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
119            "gpt-4o-mini" => Ok(Self::FourOmniMini),
120            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
121            "o1" => Ok(Self::O1),
122            "o3-mini" => Ok(Self::O3Mini),
123            "o3" => Ok(Self::O3),
124            "gpt-5" => Ok(Self::Five),
125            "gpt-5-codex" => Ok(Self::FiveCodex),
126            "gpt-5-mini" => Ok(Self::FiveMini),
127            "gpt-5-nano" => Ok(Self::FiveNano),
128            "gpt-5.1" => Ok(Self::FivePointOne),
129            "gpt-5.2" => Ok(Self::FivePointTwo),
130            "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
131            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
132        }
133    }
134
135    pub fn id(&self) -> &str {
136        match self {
137            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
138            Self::Four => "gpt-4",
139            Self::FourTurbo => "gpt-4-turbo",
140            Self::FourOmniMini => "gpt-4o-mini",
141            Self::FourPointOneNano => "gpt-4.1-nano",
142            Self::O1 => "o1",
143            Self::O3Mini => "o3-mini",
144            Self::O3 => "o3",
145            Self::Five => "gpt-5",
146            Self::FiveCodex => "gpt-5-codex",
147            Self::FiveMini => "gpt-5-mini",
148            Self::FiveNano => "gpt-5-nano",
149            Self::FivePointOne => "gpt-5.1",
150            Self::FivePointTwo => "gpt-5.2",
151            Self::FivePointTwoCodex => "gpt-5.2-codex",
152            Self::Custom { name, .. } => name,
153        }
154    }
155
156    pub fn display_name(&self) -> &str {
157        match self {
158            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
159            Self::Four => "gpt-4",
160            Self::FourTurbo => "gpt-4-turbo",
161            Self::FourOmniMini => "gpt-4o-mini",
162            Self::FourPointOneNano => "gpt-4.1-nano",
163            Self::O1 => "o1",
164            Self::O3Mini => "o3-mini",
165            Self::O3 => "o3",
166            Self::Five => "gpt-5",
167            Self::FiveCodex => "gpt-5-codex",
168            Self::FiveMini => "gpt-5-mini",
169            Self::FiveNano => "gpt-5-nano",
170            Self::FivePointOne => "gpt-5.1",
171            Self::FivePointTwo => "gpt-5.2",
172            Self::FivePointTwoCodex => "gpt-5.2-codex",
173            Self::Custom { display_name, .. } => display_name.as_deref().unwrap_or(&self.id()),
174        }
175    }
176
177    pub fn max_token_count(&self) -> u64 {
178        match self {
179            Self::ThreePointFiveTurbo => 16_385,
180            Self::Four => 8_192,
181            Self::FourTurbo => 128_000,
182            Self::FourOmniMini => 128_000,
183            Self::FourPointOneNano => 1_047_576,
184            Self::O1 => 200_000,
185            Self::O3Mini => 200_000,
186            Self::O3 => 200_000,
187            Self::Five => 272_000,
188            Self::FiveCodex => 272_000,
189            Self::FiveMini => 272_000,
190            Self::FiveNano => 272_000,
191            Self::FivePointOne => 400_000,
192            Self::FivePointTwo => 400_000,
193            Self::FivePointTwoCodex => 400_000,
194            Self::Custom { max_tokens, .. } => *max_tokens,
195        }
196    }
197
198    pub fn max_output_tokens(&self) -> Option<u64> {
199        match self {
200            Self::Custom {
201                max_output_tokens, ..
202            } => *max_output_tokens,
203            Self::ThreePointFiveTurbo => Some(4_096),
204            Self::Four => Some(8_192),
205            Self::FourTurbo => Some(4_096),
206            Self::FourOmniMini => Some(16_384),
207            Self::FourPointOneNano => Some(32_768),
208            Self::O1 => Some(100_000),
209            Self::O3Mini => Some(100_000),
210            Self::O3 => Some(100_000),
211            Self::Five => Some(128_000),
212            Self::FiveCodex => Some(128_000),
213            Self::FiveMini => Some(128_000),
214            Self::FiveNano => Some(128_000),
215            Self::FivePointOne => Some(128_000),
216            Self::FivePointTwo => Some(128_000),
217            Self::FivePointTwoCodex => Some(128_000),
218        }
219    }
220
221    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
222        match self {
223            Self::Custom {
224                reasoning_effort, ..
225            } => reasoning_effort.to_owned(),
226            _ => None,
227        }
228    }
229
230    pub fn supports_chat_completions(&self) -> bool {
231        match self {
232            Self::Custom {
233                supports_chat_completions,
234                ..
235            } => *supports_chat_completions,
236            Self::FiveCodex | Self::FivePointTwoCodex => false,
237            _ => true,
238        }
239    }
240
241    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
242    ///
243    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
244    pub fn supports_parallel_tool_calls(&self) -> bool {
245        match self {
246            Self::ThreePointFiveTurbo
247            | Self::Four
248            | Self::FourTurbo
249            | Self::FourOmniMini
250            | Self::FourPointOneNano
251            | Self::Five
252            | Self::FiveCodex
253            | Self::FiveMini
254            | Self::FivePointOne
255            | Self::FivePointTwo
256            | Self::FivePointTwoCodex
257            | Self::FiveNano => true,
258            Self::O1 | Self::O3 | Self::O3Mini | Model::Custom { .. } => false,
259        }
260    }
261
262    /// Returns whether the given model supports the `prompt_cache_key` parameter.
263    ///
264    /// If the model does not support the parameter, do not pass it up.
265    pub fn supports_prompt_cache_key(&self) -> bool {
266        true
267    }
268}
269
/// Request body for the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, as produced by [`Model::id`].
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// When true, the API responds with server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Key used to improve prompt-cache hit rates; see
    /// [`Model::supports_prompt_cache_key`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
293
/// The `tool_choice` request field: one of the API keywords (`"auto"`,
/// `"required"`, `"none"`) or a specific tool to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Forces a particular tool; serialized untagged, i.e. as the tool
    /// definition object itself.
    #[serde(untagged)]
    Other(ToolDefinition),
}
303
/// A tool the model may call. Currently only function tools exist, tagged
/// `{"type": "function", ...}` on the wire.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
310
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function arguments, when provided.
    pub parameters: Option<Value>,
}
317
/// A chat message, discriminated by the `role` field on the wire.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the turn consists only of
        /// tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// The result of a tool call, correlated by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
337
/// Message content: either a bare string or a list of typed parts
/// (text/images). Untagged so both API shapes (de)serialize naturally.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
344
345impl MessageContent {
346    pub fn empty() -> Self {
347        MessageContent::Multipart(vec![])
348    }
349
350    pub fn push_part(&mut self, part: MessagePart) {
351        match self {
352            MessageContent::Plain(text) => {
353                *self =
354                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
355            }
356            MessageContent::Multipart(parts) if parts.is_empty() => match part {
357                MessagePart::Text { text } => *self = MessageContent::Plain(text),
358                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
359            },
360            MessageContent::Multipart(parts) => parts.push(part),
361        }
362    }
363}
364
365impl From<Vec<MessagePart>> for MessageContent {
366    fn from(mut parts: Vec<MessagePart>) -> Self {
367        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
368            MessageContent::Plain(std::mem::take(text))
369        } else {
370            MessageContent::Multipart(parts)
371        }
372    }
373}
374
/// One element of multipart message content, tagged by `type`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
383
/// An image reference inside a message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// URL (or data URI) of the image.
    pub url: String,
    /// Optional detail-level hint, passed through verbatim; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
390
/// A completed tool call from the assistant. `content` is flattened, so
/// the wire shape is `{"id": ..., "type": ..., "function": ...}`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
397
/// Payload of a tool call, tagged by `type`; only function calls exist today.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
403
/// A function invocation: the function name plus its arguments as the raw
/// JSON string the API delivered.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
409
/// Full (non-streaming) response from `/chat/completions`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Creation time as a Unix timestamp.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
419
/// One completion choice in a non-streaming response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    /// Why generation stopped, when the API reports it.
    pub finish_reason: Option<String>,
}
426
/// Incremental message delta from a streaming response; every field is
/// optional because each chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Reasoning text streamed by some OpenAI-compatible providers;
    /// absent or empty otherwise.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub reasoning_content: Option<String>,
}
436
/// A streamed fragment of a tool call; fragments belonging to the same
/// call share an `index`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of the tool call this fragment belongs to.
    pub index: usize,
    /// Call id, when this fragment carries one.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
447
/// Streamed fragment of a function call: the name and/or a piece of the
/// arguments JSON string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
453
/// Token accounting as reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
460
/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    /// Why generation stopped, when the API reports it.
    pub finish_reason: Option<String>,
}
467
/// Errors surfaced by the completion helpers in this module.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The provider returned a non-success HTTP status; carries the raw
    /// body and response headers for diagnostics.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, transport, parsing).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
480
/// Error object embedded in a streamed event payload.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
485
/// A streamed payload: either a normal event or an `{"error": ...}`
/// object. Untagged so serde tries each shape in declaration order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
492
/// One server-sent event of a streaming completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the provider includes it in this event.
    pub usage: Option<Usage>,
}
498
499pub async fn non_streaming_completion(
500    client: &dyn HttpClient,
501    api_url: &str,
502    api_key: &str,
503    request: Request,
504) -> Result<Response, RequestError> {
505    let uri = format!("{api_url}/chat/completions");
506    let request_builder = HttpRequest::builder()
507        .method(Method::POST)
508        .uri(uri)
509        .header("Content-Type", "application/json")
510        .header("Authorization", format!("Bearer {}", api_key.trim()));
511
512    let request = request_builder
513        .body(AsyncBody::from(
514            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
515        ))
516        .map_err(|e| RequestError::Other(e.into()))?;
517
518    let mut response = client.send(request).await?;
519    if response.status().is_success() {
520        let mut body = String::new();
521        response
522            .body_mut()
523            .read_to_string(&mut body)
524            .await
525            .map_err(|e| RequestError::Other(e.into()))?;
526
527        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
528    } else {
529        let mut body = String::new();
530        response
531            .body_mut()
532            .read_to_string(&mut body)
533            .await
534            .map_err(|e| RequestError::Other(e.into()))?;
535
536        Err(RequestError::HttpResponseError {
537            provider: "openai".to_owned(),
538            status_code: response.status(),
539            body,
540            headers: response.headers().clone(),
541        })
542    }
543}
544
/// Sends a streaming request to `/chat/completions` and returns a stream
/// of parsed server-sent events.
///
/// SSE data lines are expected with a `data:` prefix; the `[DONE]`
/// sentinel terminates the stream. A non-success HTTP status is reported
/// as [`RequestError::HttpResponseError`] with the full error body.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Trim the key: stray whitespace would produce an invalid header value.
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Accept both `data: ...` and `data:...`; lines with
                        // neither prefix (keep-alives, blanks) are skipped via
                        // the `?` returning None from the closure.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: drop it rather than
                            // surface it as an event.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the unparseable payload for debugging,
                                    // then propagate the parse error downstream.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Error path: drain the body so it can be included in the error.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
614
/// Supported OpenAI embedding models.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
622
/// Request body for the `/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
628
/// Response body of the `/embeddings` endpoint.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
633
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
638
639pub fn embed<'a>(
640    client: &dyn HttpClient,
641    api_url: &str,
642    api_key: &str,
643    model: OpenAiEmbeddingModel,
644    texts: impl IntoIterator<Item = &'a str>,
645) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
646    let uri = format!("{api_url}/embeddings");
647
648    let request = OpenAiEmbeddingRequest {
649        model,
650        input: texts.into_iter().collect(),
651    };
652    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
653    let request = HttpRequest::builder()
654        .method(Method::POST)
655        .uri(uri)
656        .header("Content-Type", "application/json")
657        .header("Authorization", format!("Bearer {}", api_key.trim()))
658        .body(body)
659        .map(|request| client.send(request));
660
661    async move {
662        let mut response = request?.await?;
663        let mut body = String::new();
664        response.body_mut().read_to_string(&mut body).await?;
665
666        anyhow::ensure!(
667            response.status().is_success(),
668            "error during embedding, status: {:?}, body: {:?}",
669            response.status(),
670            body
671        );
672        let response: OpenAiEmbeddingResponse =
673            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
674        Ok(response)
675    }
676}