open_ai.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
   3use http_client::{
   4    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
   5    http::{HeaderMap, HeaderValue},
   6};
   7use serde::{Deserialize, Serialize};
   8use serde_json::Value;
   9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
  10use std::{convert::TryFrom, future::Future};
  11use strum::EnumIter;
  12use thiserror::Error;
  13
/// Base URL of the OpenAI API; the request helpers in this file append endpoint
/// paths such as `/chat/completions` to whatever `api_url` they are given.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
  15
  16fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
  17    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
  18}
  19
/// The author role attached to a chat message.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
  28
  29impl TryFrom<String> for Role {
  30    type Error = anyhow::Error;
  31
  32    fn try_from(value: String) -> Result<Self> {
  33        match value.as_str() {
  34            "user" => Ok(Self::User),
  35            "assistant" => Ok(Self::Assistant),
  36            "system" => Ok(Self::System),
  37            "tool" => Ok(Self::Tool),
  38            _ => anyhow::bail!("invalid role '{value}'"),
  39        }
  40    }
  41}
  42
  43impl From<Role> for String {
  44    fn from(val: Role) -> Self {
  45        match val {
  46            Role::User => "user".to_owned(),
  47            Role::Assistant => "assistant".to_owned(),
  48            Role::System => "system".to_owned(),
  49            Role::Tool => "tool".to_owned(),
  50        }
  51    }
  52}
  53
/// A model selectable for OpenAI-compatible requests.
///
/// Built-in variants (de)serialize as the model id given in their `serde(rename)`
/// attribute. `Custom` carries caller-supplied settings for a model that is not
/// built in. `FourOmni` is the `Default` variant.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A user-defined model; every property is taken from these fields.
    #[serde(rename = "custom")]
    Custom {
        /// The model id; returned by `Model::id`.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Returned by `Model::max_token_count`.
        max_tokens: u64,
        /// Returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        // NOTE(review): not read by any accessor visible in this file; presumably
        // consumed by callers elsewhere.
        max_completion_tokens: Option<u64>,
        /// Returned by `Model::reasoning_effort`.
        reasoning_effort: Option<ReasoningEffort>,
        /// Returned by `Model::supports_chat_completions`; defaults to `true`
        /// when the field is absent during deserialization.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
 109
 110const fn default_supports_chat_completions() -> bool {
 111    true
 112}
 113
 114impl Model {
 115    pub fn default_fast() -> Self {
 116        // TODO: Replace with FiveMini since all other models are deprecated
 117        Self::FourPointOneMini
 118    }
 119
 120    pub fn from_id(id: &str) -> Result<Self> {
 121        match id {
 122            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 123            "gpt-4" => Ok(Self::Four),
 124            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 125            "gpt-4o" => Ok(Self::FourOmni),
 126            "gpt-4o-mini" => Ok(Self::FourOmniMini),
 127            "gpt-4.1" => Ok(Self::FourPointOne),
 128            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
 129            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
 130            "o1" => Ok(Self::O1),
 131            "o3-mini" => Ok(Self::O3Mini),
 132            "o3" => Ok(Self::O3),
 133            "o4-mini" => Ok(Self::O4Mini),
 134            "gpt-5" => Ok(Self::Five),
 135            "gpt-5-codex" => Ok(Self::FiveCodex),
 136            "gpt-5-mini" => Ok(Self::FiveMini),
 137            "gpt-5-nano" => Ok(Self::FiveNano),
 138            "gpt-5.1" => Ok(Self::FivePointOne),
 139            "gpt-5.2" => Ok(Self::FivePointTwo),
 140            "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
 141            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
 142        }
 143    }
 144
 145    pub fn id(&self) -> &str {
 146        match self {
 147            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 148            Self::Four => "gpt-4",
 149            Self::FourTurbo => "gpt-4-turbo",
 150            Self::FourOmni => "gpt-4o",
 151            Self::FourOmniMini => "gpt-4o-mini",
 152            Self::FourPointOne => "gpt-4.1",
 153            Self::FourPointOneMini => "gpt-4.1-mini",
 154            Self::FourPointOneNano => "gpt-4.1-nano",
 155            Self::O1 => "o1",
 156            Self::O3Mini => "o3-mini",
 157            Self::O3 => "o3",
 158            Self::O4Mini => "o4-mini",
 159            Self::Five => "gpt-5",
 160            Self::FiveCodex => "gpt-5-codex",
 161            Self::FiveMini => "gpt-5-mini",
 162            Self::FiveNano => "gpt-5-nano",
 163            Self::FivePointOne => "gpt-5.1",
 164            Self::FivePointTwo => "gpt-5.2",
 165            Self::FivePointTwoCodex => "gpt-5.2-codex",
 166            Self::Custom { name, .. } => name,
 167        }
 168    }
 169
 170    pub fn display_name(&self) -> &str {
 171        match self {
 172            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 173            Self::Four => "gpt-4",
 174            Self::FourTurbo => "gpt-4-turbo",
 175            Self::FourOmni => "gpt-4o",
 176            Self::FourOmniMini => "gpt-4o-mini",
 177            Self::FourPointOne => "gpt-4.1",
 178            Self::FourPointOneMini => "gpt-4.1-mini",
 179            Self::FourPointOneNano => "gpt-4.1-nano",
 180            Self::O1 => "o1",
 181            Self::O3Mini => "o3-mini",
 182            Self::O3 => "o3",
 183            Self::O4Mini => "o4-mini",
 184            Self::Five => "gpt-5",
 185            Self::FiveCodex => "gpt-5-codex",
 186            Self::FiveMini => "gpt-5-mini",
 187            Self::FiveNano => "gpt-5-nano",
 188            Self::FivePointOne => "gpt-5.1",
 189            Self::FivePointTwo => "gpt-5.2",
 190            Self::FivePointTwoCodex => "gpt-5.2-codex",
 191            Self::Custom {
 192                name, display_name, ..
 193            } => display_name.as_ref().unwrap_or(name),
 194        }
 195    }
 196
 197    pub fn max_token_count(&self) -> u64 {
 198        match self {
 199            Self::ThreePointFiveTurbo => 16_385,
 200            Self::Four => 8_192,
 201            Self::FourTurbo => 128_000,
 202            Self::FourOmni => 128_000,
 203            Self::FourOmniMini => 128_000,
 204            Self::FourPointOne => 1_047_576,
 205            Self::FourPointOneMini => 1_047_576,
 206            Self::FourPointOneNano => 1_047_576,
 207            Self::O1 => 200_000,
 208            Self::O3Mini => 200_000,
 209            Self::O3 => 200_000,
 210            Self::O4Mini => 200_000,
 211            Self::Five => 272_000,
 212            Self::FiveCodex => 272_000,
 213            Self::FiveMini => 272_000,
 214            Self::FiveNano => 272_000,
 215            Self::FivePointOne => 400_000,
 216            Self::FivePointTwo => 400_000,
 217            Self::FivePointTwoCodex => 400_000,
 218            Self::Custom { max_tokens, .. } => *max_tokens,
 219        }
 220    }
 221
 222    pub fn max_output_tokens(&self) -> Option<u64> {
 223        match self {
 224            Self::Custom {
 225                max_output_tokens, ..
 226            } => *max_output_tokens,
 227            Self::ThreePointFiveTurbo => Some(4_096),
 228            Self::Four => Some(8_192),
 229            Self::FourTurbo => Some(4_096),
 230            Self::FourOmni => Some(16_384),
 231            Self::FourOmniMini => Some(16_384),
 232            Self::FourPointOne => Some(32_768),
 233            Self::FourPointOneMini => Some(32_768),
 234            Self::FourPointOneNano => Some(32_768),
 235            Self::O1 => Some(100_000),
 236            Self::O3Mini => Some(100_000),
 237            Self::O3 => Some(100_000),
 238            Self::O4Mini => Some(100_000),
 239            Self::Five => Some(128_000),
 240            Self::FiveCodex => Some(128_000),
 241            Self::FiveMini => Some(128_000),
 242            Self::FiveNano => Some(128_000),
 243            Self::FivePointOne => Some(128_000),
 244            Self::FivePointTwo => Some(128_000),
 245            Self::FivePointTwoCodex => Some(128_000),
 246        }
 247    }
 248
 249    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
 250        match self {
 251            Self::Custom {
 252                reasoning_effort, ..
 253            } => reasoning_effort.to_owned(),
 254            _ => None,
 255        }
 256    }
 257
 258    pub fn supports_chat_completions(&self) -> bool {
 259        match self {
 260            Self::Custom {
 261                supports_chat_completions,
 262                ..
 263            } => *supports_chat_completions,
 264            Self::FiveCodex | Self::FivePointTwoCodex => false,
 265            _ => true,
 266        }
 267    }
 268
 269    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
 270    ///
 271    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
 272    pub fn supports_parallel_tool_calls(&self) -> bool {
 273        match self {
 274            Self::ThreePointFiveTurbo
 275            | Self::Four
 276            | Self::FourTurbo
 277            | Self::FourOmni
 278            | Self::FourOmniMini
 279            | Self::FourPointOne
 280            | Self::FourPointOneMini
 281            | Self::FourPointOneNano
 282            | Self::Five
 283            | Self::FiveCodex
 284            | Self::FiveMini
 285            | Self::FivePointOne
 286            | Self::FivePointTwo
 287            | Self::FivePointTwoCodex
 288            | Self::FiveNano => true,
 289            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
 290        }
 291    }
 292
 293    /// Returns whether the given model supports the `prompt_cache_key` parameter.
 294    ///
 295    /// If the model does not support the parameter, do not pass it up.
 296    pub fn supports_prompt_cache_key(&self) -> bool {
 297        true
 298    }
 299}
 300
/// JSON body sent by `stream_completion` to the `/chat/completions` endpoint.
///
/// Optional fields and empty vectors are omitted from the serialized output.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id.
    pub model: String,
    /// The conversation to send.
    pub messages: Vec<RequestMessage>,
    /// Serialized as the request's `stream` flag.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
 324
/// Value of a request's `tool_choice` field.
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"required"`
/// and `"none"`; `Other` is untagged and serializes as the wrapped tool definition.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
 334
/// A tool offered to the model in a chat-completions request.
///
/// Internally tagged: serialized with a `"type"` field holding the snake_case
/// variant name (`"function"`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
 341
/// Describes a callable function inside a [`ToolDefinition`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // Arbitrary JSON; presumably a JSON Schema for the arguments — verify against callers.
    pub parameters: Option<Value>,
}
 348
/// One message of a chat conversation.
///
/// Internally tagged by a lowercase `"role"` field (`"assistant"`, `"user"`,
/// `"system"`, `"tool"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<MessageContent>,
        /// Omitted from the serialized message when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
 368
/// Message body: either a plain string or a list of typed parts.
///
/// Untagged: `Plain` serializes as a JSON string and `Multipart` as a JSON array.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
 375
 376impl MessageContent {
 377    pub fn empty() -> Self {
 378        MessageContent::Multipart(vec![])
 379    }
 380
 381    pub fn push_part(&mut self, part: MessagePart) {
 382        match self {
 383            MessageContent::Plain(text) => {
 384                *self =
 385                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
 386            }
 387            MessageContent::Multipart(parts) if parts.is_empty() => match part {
 388                MessagePart::Text { text } => *self = MessageContent::Plain(text),
 389                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
 390            },
 391            MessageContent::Multipart(parts) => parts.push(part),
 392        }
 393    }
 394}
 395
 396impl From<Vec<MessagePart>> for MessageContent {
 397    fn from(mut parts: Vec<MessagePart>) -> Self {
 398        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
 399            MessageContent::Plain(std::mem::take(text))
 400        } else {
 401            MessageContent::Multipart(parts)
 402        }
 403    }
 404}
 405
/// One element of multipart message content.
///
/// Internally tagged by `"type"`: `"text"` or `"image_url"`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
 414
/// Image reference carried by [`MessagePart::Image`].
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
 421
/// A tool invocation attached to an assistant message.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened: the content's fields appear alongside `id` in the same JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
 428
/// Payload of a [`ToolCall`], internally tagged by a lowercase `"type"` field
/// (`"function"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
 434
/// Name and arguments of a function tool call.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Kept as a raw string; presumably JSON-encoded arguments — verify against callers.
    pub arguments: String,
}
 440
/// A complete (non-streamed) chat-completion response object.
///
/// All fields are required when deserializing.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
 450
/// One entry of [`Response::choices`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
 457
/// Incremental message update carried by a streamed [`ChoiceDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Defaults to `None` when absent; skipped on serialization when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
 465
/// A fragment of a tool call received while streaming.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
 476
/// Function portion of a [`ToolCallChunk`]; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
 482
/// Token counts reported with a response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
 489
/// One entry of [`ResponseStreamEvent::choices`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
 496
/// Error returned by the request helpers in this file.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server answered with a non-success status; carries the response
    /// status, body text and headers.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure; displayed transparently as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 509
/// Error object found in a streamed payload; `stream_completion` turns its
/// `message` into the stream item's error.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
 514
/// A parsed stream payload: either a regular event or an `{ "error": … }` object.
///
/// Untagged, so serde tries the variants in declaration order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
 521
/// One event yielded by the stream returned from `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
 527
 528pub async fn stream_completion(
 529    client: &dyn HttpClient,
 530    provider_name: &str,
 531    api_url: &str,
 532    api_key: &str,
 533    request: Request,
 534) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
 535    let uri = format!("{api_url}/chat/completions");
 536    let request_builder = HttpRequest::builder()
 537        .method(Method::POST)
 538        .uri(uri)
 539        .header("Content-Type", "application/json")
 540        .header("Authorization", format!("Bearer {}", api_key.trim()));
 541
 542    let request = request_builder
 543        .body(AsyncBody::from(
 544            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
 545        ))
 546        .map_err(|e| RequestError::Other(e.into()))?;
 547
 548    let mut response = client.send(request).await?;
 549    if response.status().is_success() {
 550        let reader = BufReader::new(response.into_body());
 551        Ok(reader
 552            .lines()
 553            .filter_map(|line| async move {
 554                match line {
 555                    Ok(line) => {
 556                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
 557                        if line == "[DONE]" {
 558                            None
 559                        } else {
 560                            match serde_json::from_str(line) {
 561                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
 562                                Ok(ResponseStreamResult::Err { error }) => {
 563                                    Some(Err(anyhow!(error.message)))
 564                                }
 565                                Err(error) => {
 566                                    log::error!(
 567                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
 568                                        Response: `{}`",
 569                                        error,
 570                                        line,
 571                                    );
 572                                    Some(Err(anyhow!(error)))
 573                                }
 574                            }
 575                        }
 576                    }
 577                    Err(error) => Some(Err(anyhow!(error))),
 578                }
 579            })
 580            .boxed())
 581    } else {
 582        let mut body = String::new();
 583        response
 584            .body_mut()
 585            .read_to_string(&mut body)
 586            .await
 587            .map_err(|e| RequestError::Other(e.into()))?;
 588
 589        Err(RequestError::HttpResponseError {
 590            provider: provider_name.to_owned(),
 591            status_code: response.status(),
 592            body,
 593            headers: response.headers().clone(),
 594        })
 595    }
 596}
 597
/// Embedding model selector; each variant (de)serializes as the model id in its
/// `serde(rename)` attribute.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
 605
/// JSON body sent by `embed`; borrows the input texts instead of copying them.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
 611
/// Parsed body of a successful embeddings response.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    // Presumably one entry per input text, in request order — verify against the API.
    pub data: Vec<OpenAiEmbedding>,
}
 616
/// A single embedding vector from an [`OpenAiEmbeddingResponse`].
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
 621
 622pub fn embed<'a>(
 623    client: &dyn HttpClient,
 624    api_url: &str,
 625    api_key: &str,
 626    model: OpenAiEmbeddingModel,
 627    texts: impl IntoIterator<Item = &'a str>,
 628) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
 629    let uri = format!("{api_url}/embeddings");
 630
 631    let request = OpenAiEmbeddingRequest {
 632        model,
 633        input: texts.into_iter().collect(),
 634    };
 635    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
 636    let request = HttpRequest::builder()
 637        .method(Method::POST)
 638        .uri(uri)
 639        .header("Content-Type", "application/json")
 640        .header("Authorization", format!("Bearer {}", api_key.trim()))
 641        .body(body)
 642        .map(|request| client.send(request));
 643
 644    async move {
 645        let mut response = request?.await?;
 646        let mut body = String::new();
 647        response.body_mut().read_to_string(&mut body).await?;
 648
 649        anyhow::ensure!(
 650            response.status().is_success(),
 651            "error during embedding, status: {:?}, body: {:?}",
 652            response.status(),
 653            body
 654        );
 655        let response: OpenAiEmbeddingResponse =
 656            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
 657        Ok(response)
 658    }
 659}
 660
 661pub mod responses {
 662    use anyhow::{Result, anyhow};
 663    use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 664    use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 665    use serde::{Deserialize, Serialize};
 666    use serde_json::Value;
 667
 668    use crate::RequestError;
 669
    /// JSON body sent by `stream_response` to the `/responses` endpoint.
    ///
    /// Optional fields and empty vectors are omitted from the serialized output.
    #[derive(Serialize, Debug)]
    pub struct Request {
        /// Model id.
        pub model: String,
        /// Input items as raw JSON values.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub input: Vec<Value>,
        /// Also read by `stream_response` to decide how to consume the response.
        #[serde(default)]
        pub stream: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parallel_tool_calls: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tool_choice: Option<super::ToolChoice>,
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub tools: Vec<ToolDefinition>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub prompt_cache_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub reasoning: Option<ReasoningConfig>,
    }
 694
    /// Value of the `reasoning` field of a [`Request`].
    #[derive(Serialize, Debug)]
    pub struct ReasoningConfig {
        pub effort: super::ReasoningEffort,
    }
 699
    /// A tool offered to the model in a responses request.
    ///
    /// Internally tagged by `"type"` (`"function"`); unlike the chat-completions
    /// `ToolDefinition`, the function fields sit directly in the tagged object.
    /// Optional fields are omitted when `None`.
    #[derive(Serialize, Debug)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ToolDefinition {
        Function {
            name: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            description: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            parameters: Option<Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            strict: Option<bool>,
        },
    }
 713
    /// Error payload carried by the `Error` and `GenericError` stream events.
    #[derive(Deserialize, Debug)]
    pub struct Error {
        pub message: String,
    }
 718
    /// A server-sent event from the responses stream.
    ///
    /// Internally tagged by `"type"`; each variant's `serde(rename)` is the event
    /// name it matches. Any unrecognized event type deserializes as `Unknown`.
    #[derive(Deserialize, Debug)]
    #[serde(tag = "type")]
    pub enum StreamEvent {
        #[serde(rename = "response.created")]
        Created { response: ResponseSummary },
        #[serde(rename = "response.in_progress")]
        InProgress { response: ResponseSummary },
        #[serde(rename = "response.output_item.added")]
        OutputItemAdded {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        #[serde(rename = "response.output_item.done")]
        OutputItemDone {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        #[serde(rename = "response.content_part.added")]
        ContentPartAdded {
            item_id: String,
            output_index: usize,
            content_index: usize,
            // Kept as raw JSON; the part is not modeled as a type here.
            part: Value,
        },
        #[serde(rename = "response.content_part.done")]
        ContentPartDone {
            item_id: String,
            output_index: usize,
            content_index: usize,
            // Kept as raw JSON; the part is not modeled as a type here.
            part: Value,
        },
        #[serde(rename = "response.output_text.delta")]
        OutputTextDelta {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            delta: String,
        },
        #[serde(rename = "response.output_text.done")]
        OutputTextDone {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            text: String,
        },
        #[serde(rename = "response.function_call_arguments.delta")]
        FunctionCallArgumentsDelta {
            item_id: String,
            output_index: usize,
            delta: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        #[serde(rename = "response.function_call_arguments.done")]
        FunctionCallArgumentsDone {
            item_id: String,
            output_index: usize,
            arguments: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        #[serde(rename = "response.completed")]
        Completed { response: ResponseSummary },
        #[serde(rename = "response.incomplete")]
        Incomplete { response: ResponseSummary },
        #[serde(rename = "response.failed")]
        Failed { response: ResponseSummary },
        #[serde(rename = "response.error")]
        Error { error: Error },
        #[serde(rename = "error")]
        GenericError { error: Error },
        /// Fallback for event types not listed above.
        #[serde(other)]
        Unknown,
    }
 799
    /// Response snapshot attached to lifecycle events such as `Created` and
    /// `Completed`. Every field falls back to its default when absent.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseSummary {
        #[serde(default)]
        pub id: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
        #[serde(default)]
        pub status_details: Option<ResponseStatusDetails>,
        #[serde(default)]
        pub usage: Option<ResponseUsage>,
        #[serde(default)]
        pub output: Vec<ResponseOutputItem>,
    }
 813
    /// Optional detail attached to a [`ResponseSummary`]'s status; all fields
    /// default to `None` when absent.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseStatusDetails {
        #[serde(default)]
        pub reason: Option<String>,
        #[serde(default)]
        pub r#type: Option<String>,
        /// Raw JSON; the error shape is not modeled here.
        #[serde(default)]
        pub error: Option<Value>,
    }
 823
    /// Token counts attached to a [`ResponseSummary`]; each count is optional.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseUsage {
        #[serde(default)]
        pub input_tokens: Option<u64>,
        #[serde(default)]
        pub output_tokens: Option<u64>,
        #[serde(default)]
        pub total_tokens: Option<u64>,
    }
 833
    /// An output item of a response.
    ///
    /// Internally tagged by a snake_case `"type"` (`"message"`, `"function_call"`);
    /// any other type deserializes as `Unknown`.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ResponseOutputItem {
        Message(ResponseOutputMessage),
        FunctionCall(ResponseFunctionToolCall),
        #[serde(other)]
        Unknown,
    }
 842
    /// Payload of [`ResponseOutputItem::Message`]; every field falls back to its
    /// default when absent.
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseOutputMessage {
        #[serde(default)]
        pub id: Option<String>,
        /// Content parts as raw JSON values.
        #[serde(default)]
        pub content: Vec<Value>,
        #[serde(default)]
        pub role: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
 854
    /// Payload of [`ResponseOutputItem::FunctionCall`]; every field falls back to
    /// its default when absent.
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseFunctionToolCall {
        #[serde(default)]
        pub id: Option<String>,
        /// Defaults to the empty string when absent.
        #[serde(default)]
        pub arguments: String,
        #[serde(default)]
        pub call_id: Option<String>,
        #[serde(default)]
        pub name: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
 868
 869    pub async fn stream_response(
 870        client: &dyn HttpClient,
 871        provider_name: &str,
 872        api_url: &str,
 873        api_key: &str,
 874        request: Request,
 875    ) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
 876        let uri = format!("{api_url}/responses");
 877        let request_builder = HttpRequest::builder()
 878            .method(Method::POST)
 879            .uri(uri)
 880            .header("Content-Type", "application/json")
 881            .header("Authorization", format!("Bearer {}", api_key.trim()));
 882
 883        let is_streaming = request.stream;
 884        let request = request_builder
 885            .body(AsyncBody::from(
 886                serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
 887            ))
 888            .map_err(|e| RequestError::Other(e.into()))?;
 889
 890        let mut response = client.send(request).await?;
 891        if response.status().is_success() {
 892            if is_streaming {
 893                let reader = BufReader::new(response.into_body());
 894                Ok(reader
 895                    .lines()
 896                    .filter_map(|line| async move {
 897                        match line {
 898                            Ok(line) => {
 899                                let line = line
 900                                    .strip_prefix("data: ")
 901                                    .or_else(|| line.strip_prefix("data:"))?;
 902                                if line == "[DONE]" || line.is_empty() {
 903                                    None
 904                                } else {
 905                                    match serde_json::from_str::<StreamEvent>(line) {
 906                                        Ok(event) => Some(Ok(event)),
 907                                        Err(error) => {
 908                                            log::error!(
 909                                                "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
 910                                                error,
 911                                                line,
 912                                            );
 913                                            Some(Err(anyhow!(error)))
 914                                        }
 915                                    }
 916                                }
 917                            }
 918                            Err(error) => Some(Err(anyhow!(error))),
 919                        }
 920                    })
 921                    .boxed())
 922            } else {
 923                let mut body = String::new();
 924                response
 925                    .body_mut()
 926                    .read_to_string(&mut body)
 927                    .await
 928                    .map_err(|e| RequestError::Other(e.into()))?;
 929
 930                match serde_json::from_str::<ResponseSummary>(&body) {
 931                    Ok(response_summary) => {
 932                        let events = vec![
 933                            StreamEvent::Created {
 934                                response: response_summary.clone(),
 935                            },
 936                            StreamEvent::InProgress {
 937                                response: response_summary.clone(),
 938                            },
 939                        ];
 940
 941                        let mut all_events = events;
 942                        for (output_index, item) in response_summary.output.iter().enumerate() {
 943                            all_events.push(StreamEvent::OutputItemAdded {
 944                                output_index,
 945                                sequence_number: None,
 946                                item: item.clone(),
 947                            });
 948
 949                            match item {
 950                                ResponseOutputItem::Message(message) => {
 951                                    for content_item in &message.content {
 952                                        if let Some(text) = content_item.get("text") {
 953                                            if let Some(text_str) = text.as_str() {
 954                                                if let Some(ref item_id) = message.id {
 955                                                    all_events.push(StreamEvent::OutputTextDelta {
 956                                                        item_id: item_id.clone(),
 957                                                        output_index,
 958                                                        content_index: None,
 959                                                        delta: text_str.to_string(),
 960                                                    });
 961                                                }
 962                                            }
 963                                        }
 964                                    }
 965                                }
 966                                ResponseOutputItem::FunctionCall(function_call) => {
 967                                    if let Some(ref item_id) = function_call.id {
 968                                        all_events.push(StreamEvent::FunctionCallArgumentsDone {
 969                                            item_id: item_id.clone(),
 970                                            output_index,
 971                                            arguments: function_call.arguments.clone(),
 972                                            sequence_number: None,
 973                                        });
 974                                    }
 975                                }
 976                                ResponseOutputItem::Unknown => {}
 977                            }
 978
 979                            all_events.push(StreamEvent::OutputItemDone {
 980                                output_index,
 981                                sequence_number: None,
 982                                item: item.clone(),
 983                            });
 984                        }
 985
 986                        all_events.push(StreamEvent::Completed {
 987                            response: response_summary,
 988                        });
 989
 990                        Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
 991                    }
 992                    Err(error) => {
 993                        log::error!(
 994                            "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
 995                            error,
 996                            body,
 997                        );
 998                        Err(RequestError::Other(anyhow!(error)))
 999                    }
1000                }
1001            }
1002        } else {
1003            let mut body = String::new();
1004            response
1005                .body_mut()
1006                .read_to_string(&mut body)
1007                .await
1008                .map_err(|e| RequestError::Other(e.into()))?;
1009
1010            Err(RequestError::HttpResponseError {
1011                provider: provider_name.to_owned(),
1012                status_code: response.status(),
1013                body,
1014                headers: response.headers().clone(),
1015            })
1016        }
1017    }
1018}