1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{
4 AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
5 http::{HeaderMap, HeaderValue},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
10use std::{convert::TryFrom, future::Future};
11use strum::EnumIter;
12use thiserror::Error;
13
/// Base URL for OpenAI's hosted v1 API; override via `api_url` for compatible providers.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
15
/// Serde `skip_serializing_if` helper: true when the option is absent or the
/// contained slice-like value has no elements.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(value) => value.as_ref().is_empty(),
    }
}
19
/// The role of a chat message, serialized in OpenAI's lowercase wire format
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
28
29impl TryFrom<String> for Role {
30 type Error = anyhow::Error;
31
32 fn try_from(value: String) -> Result<Self> {
33 match value.as_str() {
34 "user" => Ok(Self::User),
35 "assistant" => Ok(Self::Assistant),
36 "system" => Ok(Self::System),
37 "tool" => Ok(Self::Tool),
38 _ => anyhow::bail!("invalid role '{value}'"),
39 }
40 }
41}
42
43impl From<Role> for String {
44 fn from(val: Role) -> Self {
45 match val {
46 Role::User => "user".to_owned(),
47 Role::Assistant => "assistant".to_owned(),
48 Role::System => "system".to_owned(),
49 Role::Tool => "tool".to_owned(),
50 }
51 }
52}
53
/// The set of OpenAI models known to this client, plus a `Custom` escape
/// hatch for arbitrary (e.g. OpenAI-compatible) model ids.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Maximum context window size, in tokens.
        max_tokens: u64,
        /// Maximum output tokens, if known.
        max_output_tokens: Option<u64>,
        /// Cap passed as `max_completion_tokens`, if configured.
        max_completion_tokens: Option<u64>,
        /// Reasoning effort to request, if configured.
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether the model supports the Chat Completions API (defaults to true).
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
107
/// Serde default for `Model::Custom::supports_chat_completions`: custom models
/// are assumed to support the Chat Completions API unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
111
112impl Model {
113 pub fn default_fast() -> Self {
114 // TODO: Replace with FiveMini since all other models are deprecated
115 Self::FourPointOneMini
116 }
117
118 pub fn from_id(id: &str) -> Result<Self> {
119 match id {
120 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
121 "gpt-4" => Ok(Self::Four),
122 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
123 "gpt-4o" => Ok(Self::FourOmni),
124 "gpt-4o-mini" => Ok(Self::FourOmniMini),
125 "gpt-4.1" => Ok(Self::FourPointOne),
126 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
127 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
128 "o1" => Ok(Self::O1),
129 "o3-mini" => Ok(Self::O3Mini),
130 "o3" => Ok(Self::O3),
131 "o4-mini" => Ok(Self::O4Mini),
132 "gpt-5" => Ok(Self::Five),
133 "gpt-5-codex" => Ok(Self::FiveCodex),
134 "gpt-5-mini" => Ok(Self::FiveMini),
135 "gpt-5-nano" => Ok(Self::FiveNano),
136 "gpt-5.1" => Ok(Self::FivePointOne),
137 "gpt-5.2" => Ok(Self::FivePointTwo),
138 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
139 }
140 }
141
142 pub fn id(&self) -> &str {
143 match self {
144 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
145 Self::Four => "gpt-4",
146 Self::FourTurbo => "gpt-4-turbo",
147 Self::FourOmni => "gpt-4o",
148 Self::FourOmniMini => "gpt-4o-mini",
149 Self::FourPointOne => "gpt-4.1",
150 Self::FourPointOneMini => "gpt-4.1-mini",
151 Self::FourPointOneNano => "gpt-4.1-nano",
152 Self::O1 => "o1",
153 Self::O3Mini => "o3-mini",
154 Self::O3 => "o3",
155 Self::O4Mini => "o4-mini",
156 Self::Five => "gpt-5",
157 Self::FiveCodex => "gpt-5-codex",
158 Self::FiveMini => "gpt-5-mini",
159 Self::FiveNano => "gpt-5-nano",
160 Self::FivePointOne => "gpt-5.1",
161 Self::FivePointTwo => "gpt-5.2",
162 Self::Custom { name, .. } => name,
163 }
164 }
165
166 pub fn display_name(&self) -> &str {
167 match self {
168 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
169 Self::Four => "gpt-4",
170 Self::FourTurbo => "gpt-4-turbo",
171 Self::FourOmni => "gpt-4o",
172 Self::FourOmniMini => "gpt-4o-mini",
173 Self::FourPointOne => "gpt-4.1",
174 Self::FourPointOneMini => "gpt-4.1-mini",
175 Self::FourPointOneNano => "gpt-4.1-nano",
176 Self::O1 => "o1",
177 Self::O3Mini => "o3-mini",
178 Self::O3 => "o3",
179 Self::O4Mini => "o4-mini",
180 Self::Five => "gpt-5",
181 Self::FiveCodex => "gpt-5-codex",
182 Self::FiveMini => "gpt-5-mini",
183 Self::FiveNano => "gpt-5-nano",
184 Self::FivePointOne => "gpt-5.1",
185 Self::FivePointTwo => "gpt-5.2",
186 Self::Custom {
187 name, display_name, ..
188 } => display_name.as_ref().unwrap_or(name),
189 }
190 }
191
192 pub fn max_token_count(&self) -> u64 {
193 match self {
194 Self::ThreePointFiveTurbo => 16_385,
195 Self::Four => 8_192,
196 Self::FourTurbo => 128_000,
197 Self::FourOmni => 128_000,
198 Self::FourOmniMini => 128_000,
199 Self::FourPointOne => 1_047_576,
200 Self::FourPointOneMini => 1_047_576,
201 Self::FourPointOneNano => 1_047_576,
202 Self::O1 => 200_000,
203 Self::O3Mini => 200_000,
204 Self::O3 => 200_000,
205 Self::O4Mini => 200_000,
206 Self::Five => 272_000,
207 Self::FiveCodex => 272_000,
208 Self::FiveMini => 272_000,
209 Self::FiveNano => 272_000,
210 Self::FivePointOne => 400_000,
211 Self::FivePointTwo => 400_000,
212 Self::Custom { max_tokens, .. } => *max_tokens,
213 }
214 }
215
216 pub fn max_output_tokens(&self) -> Option<u64> {
217 match self {
218 Self::Custom {
219 max_output_tokens, ..
220 } => *max_output_tokens,
221 Self::ThreePointFiveTurbo => Some(4_096),
222 Self::Four => Some(8_192),
223 Self::FourTurbo => Some(4_096),
224 Self::FourOmni => Some(16_384),
225 Self::FourOmniMini => Some(16_384),
226 Self::FourPointOne => Some(32_768),
227 Self::FourPointOneMini => Some(32_768),
228 Self::FourPointOneNano => Some(32_768),
229 Self::O1 => Some(100_000),
230 Self::O3Mini => Some(100_000),
231 Self::O3 => Some(100_000),
232 Self::O4Mini => Some(100_000),
233 Self::Five => Some(128_000),
234 Self::FiveCodex => Some(128_000),
235 Self::FiveMini => Some(128_000),
236 Self::FiveNano => Some(128_000),
237 Self::FivePointOne => Some(128_000),
238 Self::FivePointTwo => Some(128_000),
239 }
240 }
241
242 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
243 match self {
244 Self::Custom {
245 reasoning_effort, ..
246 } => reasoning_effort.to_owned(),
247 _ => None,
248 }
249 }
250
251 pub fn supports_chat_completions(&self) -> bool {
252 match self {
253 Self::Custom {
254 supports_chat_completions,
255 ..
256 } => *supports_chat_completions,
257 Self::FiveCodex => false,
258 _ => true,
259 }
260 }
261
262 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
263 ///
264 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
265 pub fn supports_parallel_tool_calls(&self) -> bool {
266 match self {
267 Self::ThreePointFiveTurbo
268 | Self::Four
269 | Self::FourTurbo
270 | Self::FourOmni
271 | Self::FourOmniMini
272 | Self::FourPointOne
273 | Self::FourPointOneMini
274 | Self::FourPointOneNano
275 | Self::Five
276 | Self::FiveCodex
277 | Self::FiveMini
278 | Self::FivePointOne
279 | Self::FivePointTwo
280 | Self::FiveNano => true,
281 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
282 }
283 }
284
285 /// Returns whether the given model supports the `prompt_cache_key` parameter.
286 ///
287 /// If the model does not support the parameter, do not pass it up.
288 pub fn supports_prompt_cache_key(&self) -> bool {
289 true
290 }
291}
292
/// Request body for OpenAI's Chat Completions endpoint (`/chat/completions`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to run the completion with (see [`Model::id`]).
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<RequestMessage>,
    /// Whether to stream the response as server-sent events.
    pub stream: bool,
    /// Upper bound on generated tokens, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Sequences at which the model stops generating further tokens.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; omitted to use the API default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Constraint on which tool (if any) the model must call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools (functions) the model may call.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Key used by the API for prompt caching; only pass for models that
    /// support it (see `Model::supports_prompt_cache_key`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Requested reasoning effort, for models that accept it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
316
/// Controls which (if any) tool the model is required to call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// The model decides whether to call a tool.
    Auto,
    /// The model must call at least one tool.
    Required,
    /// The model must not call any tool.
    None,
    /// Force a call to one specific tool (serialized as the tool itself).
    #[serde(untagged)]
    Other(ToolDefinition),
}
326
/// A tool exposed to the model; currently only function tools exist.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
333
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// The function's name, echoed back in tool calls.
    pub name: String,
    /// Natural-language description shown to the model.
    pub description: Option<String>,
    /// JSON Schema describing the function's arguments.
    pub parameters: Option<Value>,
}
340
/// A single chat message, tagged by role in OpenAI's wire format.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the turn is only tool calls.
        content: Option<MessageContent>,
        /// Tool calls the assistant made in this turn.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// Result of a tool call, correlated to the call via `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
360
/// Message content: either a plain string or a list of multipart parts
/// (text and images). Serialized untagged to match OpenAI's schema.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
367
368impl MessageContent {
369 pub fn empty() -> Self {
370 MessageContent::Multipart(vec![])
371 }
372
373 pub fn push_part(&mut self, part: MessagePart) {
374 match self {
375 MessageContent::Plain(text) => {
376 *self =
377 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
378 }
379 MessageContent::Multipart(parts) if parts.is_empty() => match part {
380 MessagePart::Text { text } => *self = MessageContent::Plain(text),
381 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
382 },
383 MessageContent::Multipart(parts) => parts.push(part),
384 }
385 }
386}
387
388impl From<Vec<MessagePart>> for MessageContent {
389 fn from(mut parts: Vec<MessagePart>) -> Self {
390 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
391 MessageContent::Plain(std::mem::take(text))
392 } else {
393 MessageContent::Multipart(parts)
394 }
395 }
396}
397
/// One element of multipart message content.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
406
/// An image reference in multipart content (URL or data URI).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Requested processing detail level, per the API; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
413
/// A tool call made by the assistant; `id` correlates the eventual
/// `RequestMessage::Tool` result back to this call.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
420
/// The payload of a tool call; currently only function calls exist.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
426
/// A concrete function invocation: name plus raw JSON-encoded arguments.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// JSON-encoded argument object, as produced by the model.
    pub arguments: String,
}
432
/// A complete (non-streaming) chat completion response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Unix timestamp of creation, per the API.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
442
/// One generated completion candidate.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    /// Why generation stopped (e.g. "stop", "length"), when reported.
    pub finish_reason: Option<String>,
}
449
/// Incremental message fragment delivered in a streaming chunk; every field
/// may be absent because the API streams pieces independently.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
457
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    /// Call id; typically only present on the first chunk of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
468
/// A streamed fragment of a function call's name and/or arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    /// Partial JSON of the arguments; concatenate chunks to reconstruct.
    pub arguments: Option<String>,
}
474
/// Token accounting reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
481
/// One choice's incremental update within a streaming chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    /// Why generation stopped; only present on the final chunk of a choice.
    pub finish_reason: Option<String>,
}
488
/// Errors produced when issuing a completion/response request.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The API answered with a non-success HTTP status; the raw body and
    /// response headers are preserved so callers can inspect rate-limit info.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, transport, request building).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
501
/// An error payload delivered inline in the event stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
506
/// A streamed SSE payload: either a normal event or an inline error object;
/// untagged so serde tries the event shape first.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
513
/// One parsed streaming chunk of a chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    /// Usage totals; typically only present on the final chunk.
    pub usage: Option<Usage>,
}
519
/// Sends a Chat Completions request and returns a stream of parsed
/// server-sent events.
///
/// Non-`data:` lines and the terminal `[DONE]` marker are filtered out of the
/// stream; individual events that fail to parse are surfaced as `Err` items
/// within the stream rather than terminating it.
///
/// # Errors
///
/// Returns [`RequestError::HttpResponseError`] when the API responds with a
/// non-success status, or [`RequestError::Other`] for serialization, request
/// building, and transport failures.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE payloads may arrive as "data: {...}" or "data:{...}";
                        // lines with neither prefix are dropped.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the raw payload so parse failures are debuggable.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Error path: read the full body so the caller gets the API's message.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
589
/// Embedding models supported by OpenAI's `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
597
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
603
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
608
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
613
/// Requests embeddings for `texts` from OpenAI's `/embeddings` endpoint.
///
/// The HTTP request (including the serialized input texts) is fully built
/// before the future is created, which is what lets the returned future be
/// `'static` and not borrow any of the arguments.
///
/// # Errors
///
/// Fails if the request cannot be built or sent, if the API returns a
/// non-success status, or if the response body cannot be parsed.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serializing this plain struct can only fail on a bug, hence `unwrap`.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()))
        .body(body)
        .map(|request| client.send(request));

    async move {
        // `request` is a Result: surface any request-building error first.
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}
652
653pub mod responses {
654 use anyhow::{Result, anyhow};
655 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
656 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
657 use serde::{Deserialize, Serialize};
658 use serde_json::Value;
659
660 use crate::RequestError;
661
    /// Request body for OpenAI's Responses API endpoint (`/responses`).
    #[derive(Serialize, Debug)]
    pub struct Request {
        /// Model id to run the response with.
        pub model: String,
        /// Input items (messages, tool results) as raw JSON values.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub input: Vec<Value>,
        /// Whether to stream the response as server-sent events.
        #[serde(default)]
        pub stream: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Whether to enable parallel function calling during tool use.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parallel_tool_calls: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tool_choice: Option<super::ToolChoice>,
        /// Tools the model may call (Responses API flattened shape).
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub tools: Vec<ToolDefinition>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub prompt_cache_key: Option<String>,
        /// Reasoning configuration, for reasoning-capable models.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub reasoning: Option<ReasoningConfig>,
    }
686
    /// Reasoning options for the Responses API.
    #[derive(Serialize, Debug)]
    pub struct ReasoningConfig {
        pub effort: super::ReasoningEffort,
    }
691
    /// A tool definition in the Responses API shape, where function fields
    /// are flattened onto the tool object (unlike the Chat Completions shape).
    #[derive(Serialize, Debug)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ToolDefinition {
        Function {
            name: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            description: Option<String>,
            /// JSON Schema for the function's arguments.
            #[serde(skip_serializing_if = "Option::is_none")]
            parameters: Option<Value>,
            /// Whether the API should enforce the schema strictly.
            #[serde(skip_serializing_if = "Option::is_none")]
            strict: Option<bool>,
        },
    }
705
    /// An error payload carried inside a stream event.
    #[derive(Deserialize, Debug)]
    pub struct Error {
        pub message: String,
    }
710
    /// Server-sent events emitted by the Responses API stream, tagged by the
    /// event `type`. Unrecognized event types deserialize to
    /// [`StreamEvent::Unknown`] so new API events don't break consumers.
    #[derive(Deserialize, Debug)]
    #[serde(tag = "type")]
    pub enum StreamEvent {
        /// The response has been created.
        #[serde(rename = "response.created")]
        Created { response: ResponseSummary },
        /// The response is being generated.
        #[serde(rename = "response.in_progress")]
        InProgress { response: ResponseSummary },
        /// A new output item (message or function call) was added.
        #[serde(rename = "response.output_item.added")]
        OutputItemAdded {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        /// An output item is complete.
        #[serde(rename = "response.output_item.done")]
        OutputItemDone {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        #[serde(rename = "response.content_part.added")]
        ContentPartAdded {
            item_id: String,
            output_index: usize,
            content_index: usize,
            part: Value,
        },
        #[serde(rename = "response.content_part.done")]
        ContentPartDone {
            item_id: String,
            output_index: usize,
            content_index: usize,
            part: Value,
        },
        /// An incremental chunk of output text.
        #[serde(rename = "response.output_text.delta")]
        OutputTextDelta {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            delta: String,
        },
        /// The complete text of one output item.
        #[serde(rename = "response.output_text.done")]
        OutputTextDone {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            text: String,
        },
        /// An incremental chunk of a function call's JSON arguments.
        #[serde(rename = "response.function_call_arguments.delta")]
        FunctionCallArgumentsDelta {
            item_id: String,
            output_index: usize,
            delta: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        /// The complete JSON arguments of one function call.
        #[serde(rename = "response.function_call_arguments.done")]
        FunctionCallArgumentsDone {
            item_id: String,
            output_index: usize,
            arguments: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        #[serde(rename = "response.completed")]
        Completed { response: ResponseSummary },
        #[serde(rename = "response.incomplete")]
        Incomplete { response: ResponseSummary },
        #[serde(rename = "response.failed")]
        Failed { response: ResponseSummary },
        #[serde(rename = "response.error")]
        Error { error: Error },
        /// Top-level `error` events (not namespaced under `response.`).
        #[serde(rename = "error")]
        GenericError { error: Error },
        /// Any event type this client doesn't recognize.
        #[serde(other)]
        Unknown,
    }
791
    /// Summary of a response's state, carried by lifecycle events; every
    /// field is optional/defaulted because events include varying subsets.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseSummary {
        #[serde(default)]
        pub id: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
        #[serde(default)]
        pub status_details: Option<ResponseStatusDetails>,
        #[serde(default)]
        pub usage: Option<ResponseUsage>,
        /// Output items; populated on completed/non-streaming responses.
        #[serde(default)]
        pub output: Vec<ResponseOutputItem>,
    }
805
    /// Extra detail about why a response ended in its status.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseStatusDetails {
        #[serde(default)]
        pub reason: Option<String>,
        #[serde(default)]
        pub r#type: Option<String>,
        /// Raw error object, when the response failed.
        #[serde(default)]
        pub error: Option<Value>,
    }
815
    /// Token accounting for a Responses API response.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseUsage {
        #[serde(default)]
        pub input_tokens: Option<u64>,
        #[serde(default)]
        pub output_tokens: Option<u64>,
        #[serde(default)]
        pub total_tokens: Option<u64>,
    }
825
    /// One item of a response's output; unknown types are tolerated so new
    /// output kinds don't break deserialization.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ResponseOutputItem {
        Message(ResponseOutputMessage),
        FunctionCall(ResponseFunctionToolCall),
        #[serde(other)]
        Unknown,
    }
834
    /// A message output item (assistant text content).
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseOutputMessage {
        #[serde(default)]
        pub id: Option<String>,
        /// Content parts as raw JSON values.
        #[serde(default)]
        pub content: Vec<Value>,
        #[serde(default)]
        pub role: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
846
    /// A function-call output item.
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseFunctionToolCall {
        #[serde(default)]
        pub id: Option<String>,
        /// JSON-encoded arguments; defaults to empty when absent.
        #[serde(default)]
        pub arguments: String,
        /// Id used to correlate the tool result back to this call.
        #[serde(default)]
        pub call_id: Option<String>,
        #[serde(default)]
        pub name: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
860
    /// Sends a Responses API request and returns a stream of [`StreamEvent`]s.
    ///
    /// When `request.stream` is false, the single JSON response is converted
    /// into the same event sequence a streaming response would produce
    /// (Created, InProgress, per-item Added/Delta/Done, Completed), so callers
    /// can consume both modes through one interface.
    ///
    /// # Errors
    ///
    /// Returns [`RequestError::HttpResponseError`] for non-success statuses,
    /// or [`RequestError::Other`] for serialization, transport, and
    /// non-streaming parse failures. Streaming parse failures of individual
    /// events are surfaced as `Err` items within the stream.
    pub async fn stream_response(
        client: &dyn HttpClient,
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        request: Request,
    ) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
        let uri = format!("{api_url}/responses");
        let request_builder = HttpRequest::builder()
            .method(Method::POST)
            .uri(uri)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key.trim()));

        // Remember the mode before `request` is consumed by serialization.
        let is_streaming = request.stream;
        let request = request_builder
            .body(AsyncBody::from(
                serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
            ))
            .map_err(|e| RequestError::Other(e.into()))?;

        let mut response = client.send(request).await?;
        if response.status().is_success() {
            if is_streaming {
                let reader = BufReader::new(response.into_body());
                Ok(reader
                    .lines()
                    .filter_map(|line| async move {
                        match line {
                            Ok(line) => {
                                // SSE payloads may arrive as "data: {...}" or "data:{...}".
                                let line = line
                                    .strip_prefix("data: ")
                                    .or_else(|| line.strip_prefix("data:"))?;
                                if line == "[DONE]" || line.is_empty() {
                                    None
                                } else {
                                    match serde_json::from_str::<StreamEvent>(line) {
                                        Ok(event) => Some(Ok(event)),
                                        Err(error) => {
                                            // Log the raw payload so parse failures are debuggable.
                                            log::error!(
                                                "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
                                                error,
                                                line,
                                            );
                                            Some(Err(anyhow!(error)))
                                        }
                                    }
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    })
                    .boxed())
            } else {
                // Non-streaming: read the whole body, then synthesize the
                // event sequence a streaming response would have produced.
                let mut body = String::new();
                response
                    .body_mut()
                    .read_to_string(&mut body)
                    .await
                    .map_err(|e| RequestError::Other(e.into()))?;

                match serde_json::from_str::<ResponseSummary>(&body) {
                    Ok(response_summary) => {
                        let events = vec![
                            StreamEvent::Created {
                                response: response_summary.clone(),
                            },
                            StreamEvent::InProgress {
                                response: response_summary.clone(),
                            },
                        ];

                        let mut all_events = events;
                        for (output_index, item) in response_summary.output.iter().enumerate() {
                            all_events.push(StreamEvent::OutputItemAdded {
                                output_index,
                                sequence_number: None,
                                item: item.clone(),
                            });

                            match item {
                                ResponseOutputItem::Message(message) => {
                                    // Emit each text content part as a single
                                    // whole-text delta; parts without an item
                                    // id or text are skipped.
                                    for content_item in &message.content {
                                        if let Some(text) = content_item.get("text") {
                                            if let Some(text_str) = text.as_str() {
                                                if let Some(ref item_id) = message.id {
                                                    all_events.push(StreamEvent::OutputTextDelta {
                                                        item_id: item_id.clone(),
                                                        output_index,
                                                        content_index: None,
                                                        delta: text_str.to_string(),
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                                ResponseOutputItem::FunctionCall(function_call) => {
                                    // Arguments are already complete, so emit
                                    // only the terminal "done" event.
                                    if let Some(ref item_id) = function_call.id {
                                        all_events.push(StreamEvent::FunctionCallArgumentsDone {
                                            item_id: item_id.clone(),
                                            output_index,
                                            arguments: function_call.arguments.clone(),
                                            sequence_number: None,
                                        });
                                    }
                                }
                                ResponseOutputItem::Unknown => {}
                            }

                            all_events.push(StreamEvent::OutputItemDone {
                                output_index,
                                sequence_number: None,
                                item: item.clone(),
                            });
                        }

                        all_events.push(StreamEvent::Completed {
                            response: response_summary,
                        });

                        Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
                    }
                    Err(error) => {
                        log::error!(
                            "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
                            error,
                            body,
                        );
                        Err(RequestError::Other(anyhow!(error)))
                    }
                }
            }
        } else {
            // Error path: read the full body so the caller gets the API's message.
            let mut body = String::new();
            response
                .body_mut()
                .read_to_string(&mut body)
                .await
                .map_err(|e| RequestError::Other(e.into()))?;

            Err(RequestError::HttpResponseError {
                provider: provider_name.to_owned(),
                status_code: response.status(),
                body,
                headers: response.headers().clone(),
            })
        }
    }
1010}