From ac2eb24d625934e2d15ceb5826125093e30c65ee Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Thu, 19 Jun 2025 14:03:52 -0400 Subject: [PATCH] Revert "OpenAI cleanups (#32597)" This reverts commit 15f044f0a10a1804cc969fb7abd745c8629cf666. --- .../language_models/src/provider/open_ai.rs | 4 +- crates/open_ai/src/open_ai.rs | 245 +++++++++++++++++- 2 files changed, 244 insertions(+), 5 deletions(-) diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index e0998fdacc5b2afc43af1513fd770fa96d46449f..ac6d0b75be5460bbea9a366f37429eb8c925e9e6 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -440,7 +440,7 @@ pub fn into_open_ai( stream, stop: request.stop, temperature: request.temperature.unwrap_or(1.0), - max_completion_tokens: max_output_tokens, + max_tokens: max_output_tokens, parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() { // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn. Some(false) @@ -648,6 +648,8 @@ pub fn count_open_ai_tokens( | Model::FourPointOneMini | Model::FourPointOneNano | Model::O1 + | Model::O1Preview + | Model::O1Mini | Model::O3 | Model::O3Mini | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 3b148d20c4910a49ee5a06dd9b4ce11da850170f..3390dfda75a92c31f4cd52e8b95e4a9529efd40f 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,9 +1,16 @@ use anyhow::{Context as _, Result, anyhow}; -use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; +use futures::{ + AsyncBufReadExt, AsyncReadExt, StreamExt, + io::BufReader, + stream::{self, BoxStream}, +}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{convert::TryFrom, future::Future}; +use std::{ + convert::TryFrom, + future::{self, Future}, +}; use strum::EnumIter; pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; @@ -68,6 +75,10 @@ pub enum Model { FourPointOneNano, #[serde(rename = "o1")] O1, + #[serde(rename = "o1-preview")] + O1Preview, + #[serde(rename = "o1-mini")] + O1Mini, #[serde(rename = "o3-mini")] O3Mini, #[serde(rename = "o3")] @@ -102,6 +113,8 @@ impl Model { "gpt-4.1-mini" => Ok(Self::FourPointOneMini), "gpt-4.1-nano" => Ok(Self::FourPointOneNano), "o1" => Ok(Self::O1), + "o1-preview" => Ok(Self::O1Preview), + "o1-mini" => Ok(Self::O1Mini), "o3-mini" => Ok(Self::O3Mini), "o3" => Ok(Self::O3), "o4-mini" => Ok(Self::O4Mini), @@ -120,6 +133,8 @@ impl Model { Self::FourPointOneMini => "gpt-4.1-mini", Self::FourPointOneNano => "gpt-4.1-nano", Self::O1 => "o1", + Self::O1Preview => "o1-preview", + Self::O1Mini => "o1-mini", Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", @@ -138,6 +153,8 @@ impl Model { Self::FourPointOneMini => "gpt-4.1-mini", Self::FourPointOneNano => "gpt-4.1-nano", Self::O1 => "o1", + Self::O1Preview => "o1-preview", + Self::O1Mini => "o1-mini", Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", @@ -158,6 +175,8 @@ impl Model { Self::FourPointOneMini => 1_047_576, Self::FourPointOneNano => 1_047_576, Self::O1 => 200_000, + Self::O1Preview => 128_000, + Self::O1Mini => 128_000, Self::O3Mini => 200_000, Self::O3 => 200_000, Self::O4Mini => 200_000, @@ -179,6 +198,8 @@ impl Model { Self::FourPointOneMini => Some(32_768), Self::FourPointOneNano => Some(32_768), Self::O1 => Some(100_000), + Self::O1Preview => Some(32_768), + Self::O1Mini => Some(65_536), Self::O3Mini => Some(100_000), Self::O3 => Some(100_000), Self::O4Mini => Some(100_000), @@ -198,7 +219,13 @@ impl Model { | Self::FourPointOne | Self::FourPointOneMini | Self::FourPointOneNano => true, - Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false, + Self::O1 + | Self::O1Preview + | Self::O1Mini + | Self::O3 + | Self::O3Mini + | Self::O4Mini + | Model::Custom { .. } => false, } } } @@ -209,7 +236,7 @@ pub struct Request { pub messages: Vec, pub stream: bool, #[serde(default, skip_serializing_if = "Option::is_none")] - pub max_completion_tokens: Option, + pub max_tokens: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub stop: Vec, pub temperature: f32, @@ -222,6 +249,24 @@ pub struct Request { pub tools: Vec, } +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionRequest { + pub model: String, + pub prompt: String, + pub max_tokens: u32, + pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prediction: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub rewrite_speculation: Option, +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Prediction { + Content { content: String }, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum ToolChoice { @@ -390,12 +435,204 @@ pub struct ResponseStreamEvent { pub usage: Option, } +#[derive(Serialize, Deserialize, Debug)] +pub struct CompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct CompletionChoice { + pub text: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Response { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Choice { + pub index: u32, + pub message: RequestMessage, + pub finish_reason: Option, +} + +pub async fn complete( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result { + let uri = format!("{api_url}/chat/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)); + + let mut request_body = request; + request_body.stream = false; + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: Response = serde_json::from_str(&body)?; + Ok(response) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAiResponse { + error: OpenAiError, + } + + #[derive(Deserialize)] + struct OpenAiError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => anyhow::bail!( + "Failed to connect to OpenAI API: {}", + response.error.message, + ), + _ => anyhow::bail!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + ), + } + } +} + +pub async fn complete_text( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: CompletionRequest, +) -> Result { + let uri = format!("{api_url}/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)); + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response = serde_json::from_str(&body)?; + Ok(response) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAiResponse { + error: OpenAiError, + } + + #[derive(Deserialize)] + struct OpenAiError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => anyhow::bail!( + "Failed to connect to OpenAI API: {}", + response.error.message, + ), + _ => anyhow::bail!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + ), + } + } +} + +fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent { + ResponseStreamEvent { + created: response.created as u32, + model: response.model, + choices: response + .choices + .into_iter() + .map(|choice| { + let content = match &choice.message { + RequestMessage::Assistant { content, .. } => content.as_ref(), + RequestMessage::User { content } => Some(content), + RequestMessage::System { content } => Some(content), + RequestMessage::Tool { content, .. } => Some(content), + }; + + let mut text_content = String::new(); + match content { + Some(MessageContent::Plain(text)) => text_content.push_str(&text), + Some(MessageContent::Multipart(parts)) => { + for part in parts { + match part { + MessagePart::Text { text } => text_content.push_str(&text), + MessagePart::Image { .. } => {} + } + } + } + None => {} + }; + + ChoiceDelta { + index: choice.index, + delta: ResponseMessageDelta { + role: Some(match choice.message { + RequestMessage::Assistant { .. } => Role::Assistant, + RequestMessage::User { .. } => Role::User, + RequestMessage::System { .. } => Role::System, + RequestMessage::Tool { .. } => Role::Tool, + }), + content: if text_content.is_empty() { + None + } else { + Some(text_content) + }, + tool_calls: None, + }, + finish_reason: choice.finish_reason, + } + }) + .collect(), + usage: Some(response.usage), + } +} + pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, api_key: &str, request: Request, ) -> Result>> { + if request.model.starts_with("o1") { + let response = complete(client, api_url, api_key, request).await; + let response_stream_event = response.map(adapt_response_to_stream); + return Ok(stream::once(future::ready(response_stream_event)).boxed()); + } + let uri = format!("{api_url}/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST)