1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{ReasoningEffort, RequestError, ToolChoice};
8
/// JSON body for a `POST {api_url}/responses` request, serialized by [`stream_response`].
///
/// `Option` fields are omitted from the serialized JSON when they are `None`, and the
/// `input` / `tools` lists are omitted when empty, so only explicitly-set options are sent.
#[derive(Serialize, Debug)]
pub struct Request {
    /// Model name, sent as `model`.
    pub model: String,
    /// Input items, passed through as raw JSON values; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub input: Vec<Value>,
    /// Selects how [`stream_response`] reads the reply: `true` parses the body as
    /// server-sent-event `data:` lines, `false` parses it as a single JSON response.
    // NOTE(review): `#[serde(default)]` only affects deserialization, and this type derives
    // only `Serialize`, so the attribute is a no-op here; `stream` is always serialized.
    #[serde(default)]
    pub stream: bool,
    /// Sent as `temperature`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Sent as `top_p`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Sent as `max_output_tokens`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u64>,
    /// Sent as `parallel_tool_calls`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool selection policy; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Tools offered to the model; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Sent as `prompt_cache_key`; omitted when `None`.
    // NOTE(review): presumably a provider-side prompt-caching hint — confirm against the API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Reasoning configuration (see [`ReasoningConfig`]); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ReasoningConfig>,
}
33
/// Value of the request's `reasoning` field, serialized as `{"effort": ...}`.
#[derive(Serialize, Debug)]
pub struct ReasoningConfig {
    /// Requested reasoning effort level.
    pub effort: ReasoningEffort,
}
38
/// A tool made available to the model.
///
/// Serialized with an internal `type` tag in snake_case, so `Function { .. }` becomes
/// `{"type": "function", "name": ..., ...}`.
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A callable function tool.
    Function {
        /// Function name the model uses to call the tool.
        name: String,
        /// Optional description; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        /// Raw JSON describing the arguments; omitted when `None`.
        // NOTE(review): presumably a JSON Schema object — confirm with callers.
        #[serde(skip_serializing_if = "Option::is_none")]
        parameters: Option<Value>,
        /// Sent as `strict`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        strict: Option<bool>,
    },
}
52
/// Error payload carried by [`StreamEvent::Error`] and [`StreamEvent::GenericError`].
///
/// Only `message` is read; any other fields in the JSON object are ignored.
#[derive(Deserialize, Debug)]
pub struct Error {
    /// Human-readable error message.
    pub message: String,
}
57
/// One event from a Responses API stream, selected by the JSON `type` field.
///
/// [`stream_response`] yields these in streaming mode. In non-streaming mode it synthesizes
/// an equivalent sequence from the full response. Any `type` value not listed here
/// deserializes to [`StreamEvent::Unknown`] instead of failing.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
pub enum StreamEvent {
    /// `response.created`: carries a snapshot of the response.
    #[serde(rename = "response.created")]
    Created { response: ResponseSummary },
    /// `response.in_progress`: carries a snapshot of the response.
    #[serde(rename = "response.in_progress")]
    InProgress { response: ResponseSummary },
    /// `response.output_item.added`: a new output item at `output_index`.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {
        output_index: usize,
        /// Absent in the JSON → `None`.
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },
    /// `response.output_item.done`: the final state of the item at `output_index`.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {
        output_index: usize,
        /// Absent in the JSON → `None`.
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },
    /// `response.content_part.added`; `part` is kept as raw JSON.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded {
        item_id: String,
        output_index: usize,
        content_index: usize,
        part: Value,
    },
    /// `response.content_part.done`; `part` is kept as raw JSON.
    #[serde(rename = "response.content_part.done")]
    ContentPartDone {
        item_id: String,
        output_index: usize,
        content_index: usize,
        part: Value,
    },
    /// `response.output_text.delta`: an incremental chunk of text for item `item_id`.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        item_id: String,
        output_index: usize,
        /// Absent in the JSON → `None`.
        #[serde(default)]
        content_index: Option<usize>,
        delta: String,
    },
    /// `response.output_text.done`: the complete text for item `item_id`.
    #[serde(rename = "response.output_text.done")]
    OutputTextDone {
        item_id: String,
        output_index: usize,
        /// Absent in the JSON → `None`.
        #[serde(default)]
        content_index: Option<usize>,
        text: String,
    },
    /// `response.function_call_arguments.delta`: an incremental chunk of call arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgumentsDelta {
        item_id: String,
        output_index: usize,
        delta: String,
        /// Absent in the JSON → `None`.
        #[serde(default)]
        sequence_number: Option<u64>,
    },
    /// `response.function_call_arguments.done`: the complete call arguments string.
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgumentsDone {
        item_id: String,
        output_index: usize,
        arguments: String,
        /// Absent in the JSON → `None`.
        #[serde(default)]
        sequence_number: Option<u64>,
    },
    /// `response.completed`: terminal event with the final response.
    #[serde(rename = "response.completed")]
    Completed { response: ResponseSummary },
    /// `response.incomplete`: terminal event for a response that stopped early.
    #[serde(rename = "response.incomplete")]
    Incomplete { response: ResponseSummary },
    /// `response.failed`: terminal event for a failed response.
    #[serde(rename = "response.failed")]
    Failed { response: ResponseSummary },
    /// `response.error`: an error reported inside the stream.
    #[serde(rename = "response.error")]
    Error { error: Error },
    /// A bare `error` event type.
    #[serde(rename = "error")]
    GenericError { error: Error },
    /// Any other `type` value (catch-all via `#[serde(other)]`); its fields are discarded.
    #[serde(other)]
    Unknown,
}
138
/// Snapshot of a response object.
///
/// It is embedded in lifecycle events, and it is also the full body of a non-streaming
/// reply. Every field is optional or defaults to empty, so partial objects still deserialize.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct ResponseSummary {
    /// Response id, if present.
    #[serde(default)]
    pub id: Option<String>,
    /// Raw status string. [`stream_response`]'s non-streaming path recognizes
    /// `"incomplete"` and `"failed"`.
    #[serde(default)]
    pub status: Option<String>,
    /// Extra detail about the status, if present.
    #[serde(default)]
    pub status_details: Option<ResponseStatusDetails>,
    /// Token usage, if present.
    #[serde(default)]
    pub usage: Option<ResponseUsage>,
    /// Output items in order; empty when absent.
    #[serde(default)]
    pub output: Vec<ResponseOutputItem>,
}
152
/// Optional detail attached to [`ResponseSummary::status`]. All fields default to `None`.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct ResponseStatusDetails {
    /// Reason string, if present.
    #[serde(default)]
    pub reason: Option<String>,
    /// The JSON `type` key (raw identifier because `type` is a Rust keyword).
    #[serde(default)]
    pub r#type: Option<String>,
    /// Error payload kept as raw JSON, if present.
    #[serde(default)]
    pub error: Option<Value>,
}
162
/// Token counts reported for a response. Each count is `None` when absent from the JSON.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct ResponseUsage {
    #[serde(default)]
    pub input_tokens: Option<u64>,
    #[serde(default)]
    pub output_tokens: Option<u64>,
    #[serde(default)]
    pub total_tokens: Option<u64>,
}
172
/// One item in a response's output, selected by the JSON `type` field (snake_case).
///
/// `"message"` maps to [`ResponseOutputItem::Message`] and `"function_call"` to
/// [`ResponseOutputItem::FunctionCall`]. Any other type becomes
/// [`ResponseOutputItem::Unknown`], and its fields are discarded.
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseOutputItem {
    Message(ResponseOutputMessage),
    FunctionCall(ResponseFunctionToolCall),
    #[serde(other)]
    Unknown,
}
181
/// A `message` output item. Missing fields default to `None` / empty.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseOutputMessage {
    /// Item id; [`stream_response`] needs it to attribute synthesized text deltas.
    #[serde(default)]
    pub id: Option<String>,
    /// Content parts as raw JSON. [`stream_response`] reads the string `text` field of each
    /// part, when present.
    #[serde(default)]
    pub content: Vec<Value>,
    #[serde(default)]
    pub role: Option<String>,
    #[serde(default)]
    pub status: Option<String>,
}
193
/// A `function_call` output item. Missing fields default to `None`.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseFunctionToolCall {
    /// Item id (distinct field from `call_id`).
    #[serde(default)]
    pub id: Option<String>,
    /// Call arguments as a string; empty when absent.
    // NOTE(review): presumably JSON-encoded arguments — confirm against the API docs.
    #[serde(default)]
    pub arguments: String,
    #[serde(default)]
    pub call_id: Option<String>,
    /// Name of the function being called.
    #[serde(default)]
    pub name: Option<String>,
    #[serde(default)]
    pub status: Option<String>,
}
207
208pub async fn stream_response(
209 client: &dyn HttpClient,
210 provider_name: &str,
211 api_url: &str,
212 api_key: &str,
213 request: Request,
214) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
215 let uri = format!("{api_url}/responses");
216 let request_builder = HttpRequest::builder()
217 .method(Method::POST)
218 .uri(uri)
219 .header("Content-Type", "application/json")
220 .header("Authorization", format!("Bearer {}", api_key.trim()));
221
222 let is_streaming = request.stream;
223 let request = request_builder
224 .body(AsyncBody::from(
225 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
226 ))
227 .map_err(|e| RequestError::Other(e.into()))?;
228
229 let mut response = client.send(request).await?;
230 if response.status().is_success() {
231 if is_streaming {
232 let reader = BufReader::new(response.into_body());
233 Ok(reader
234 .lines()
235 .filter_map(|line| async move {
236 match line {
237 Ok(line) => {
238 let line = line
239 .strip_prefix("data: ")
240 .or_else(|| line.strip_prefix("data:"))?;
241 if line == "[DONE]" || line.is_empty() {
242 None
243 } else {
244 match serde_json::from_str::<StreamEvent>(line) {
245 Ok(event) => Some(Ok(event)),
246 Err(error) => {
247 log::error!(
248 "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
249 error,
250 line,
251 );
252 Some(Err(anyhow!(error)))
253 }
254 }
255 }
256 }
257 Err(error) => Some(Err(anyhow!(error))),
258 }
259 })
260 .boxed())
261 } else {
262 let mut body = String::new();
263 response
264 .body_mut()
265 .read_to_string(&mut body)
266 .await
267 .map_err(|e| RequestError::Other(e.into()))?;
268
269 match serde_json::from_str::<ResponseSummary>(&body) {
270 Ok(response_summary) => {
271 let events = vec![
272 StreamEvent::Created {
273 response: response_summary.clone(),
274 },
275 StreamEvent::InProgress {
276 response: response_summary.clone(),
277 },
278 ];
279
280 let mut all_events = events;
281 for (output_index, item) in response_summary.output.iter().enumerate() {
282 all_events.push(StreamEvent::OutputItemAdded {
283 output_index,
284 sequence_number: None,
285 item: item.clone(),
286 });
287
288 match item {
289 ResponseOutputItem::Message(message) => {
290 for content_item in &message.content {
291 if let Some(text) = content_item.get("text") {
292 if let Some(text_str) = text.as_str() {
293 if let Some(ref item_id) = message.id {
294 all_events.push(StreamEvent::OutputTextDelta {
295 item_id: item_id.clone(),
296 output_index,
297 content_index: None,
298 delta: text_str.to_string(),
299 });
300 }
301 }
302 }
303 }
304 }
305 ResponseOutputItem::FunctionCall(function_call) => {
306 if let Some(ref item_id) = function_call.id {
307 all_events.push(StreamEvent::FunctionCallArgumentsDone {
308 item_id: item_id.clone(),
309 output_index,
310 arguments: function_call.arguments.clone(),
311 sequence_number: None,
312 });
313 }
314 }
315 ResponseOutputItem::Unknown => {}
316 }
317
318 all_events.push(StreamEvent::OutputItemDone {
319 output_index,
320 sequence_number: None,
321 item: item.clone(),
322 });
323 }
324
325 all_events.push(StreamEvent::Completed {
326 response: response_summary,
327 });
328
329 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
330 }
331 Err(error) => {
332 log::error!(
333 "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
334 error,
335 body,
336 );
337 Err(RequestError::Other(anyhow!(error)))
338 }
339 }
340 }
341 } else {
342 let mut body = String::new();
343 response
344 .body_mut()
345 .read_to_string(&mut body)
346 .await
347 .map_err(|e| RequestError::Other(e.into()))?;
348
349 Err(RequestError::HttpResponseError {
350 provider: provider_name.to_owned(),
351 status_code: response.status(),
352 body,
353 headers: response.headers().clone(),
354 })
355 }
356}