1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{ReasoningEffort, RequestError, Role, ToolChoice};
8
9#[derive(Serialize, Debug)]
10pub struct Request {
11 pub model: String,
12 #[serde(skip_serializing_if = "Vec::is_empty")]
13 pub input: Vec<ResponseInputItem>,
14 #[serde(default)]
15 pub stream: bool,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub temperature: Option<f32>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub top_p: Option<f32>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub max_output_tokens: Option<u64>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub parallel_tool_calls: Option<bool>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tool_choice: Option<ToolChoice>,
26 #[serde(skip_serializing_if = "Vec::is_empty")]
27 pub tools: Vec<ToolDefinition>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub prompt_cache_key: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub reasoning: Option<ReasoningConfig>,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
35#[serde(tag = "type", rename_all = "snake_case")]
36pub enum ResponseInputItem {
37 Message(ResponseMessageItem),
38 FunctionCall(ResponseFunctionCallItem),
39 FunctionCallOutput(ResponseFunctionCallOutputItem),
40}
41
/// Message entry of the request input: a role and its content parts.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseMessageItem {
    /// Author role of the message.
    pub role: Role,
    /// Ordered content parts (text, images, prior output text, refusals).
    pub content: Vec<ResponseInputContent>,
}
47
/// Function-call entry of the request input.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFunctionCallItem {
    /// Call identifier; presumably matched by the `call_id` of a
    /// `ResponseFunctionCallOutputItem` — confirm against the provider docs.
    pub call_id: String,
    /// Name of the function being called.
    pub name: String,
    /// Arguments as a raw string (likely JSON-encoded; not validated here).
    pub arguments: String,
}
54
/// Function-call-output entry of the request input.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFunctionCallOutputItem {
    /// Identifier of the function call this output belongs to.
    pub call_id: String,
    /// Output of the call, as a plain string.
    pub output: String,
}
60
/// A content part of a [`ResponseMessageItem`], tagged by its `type` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseInputContent {
    /// `"type": "input_text"` — plain text.
    #[serde(rename = "input_text")]
    Text { text: String },
    /// `"type": "input_image"` — an image referenced by URL (format of the URL not checked here).
    #[serde(rename = "input_image")]
    Image { image_url: String },
    /// `"type": "output_text"` — text previously produced by the model.
    #[serde(rename = "output_text")]
    OutputText {
        text: String,
        // Kept as raw JSON; an absent `annotations` field deserializes as an empty list.
        #[serde(default)]
        annotations: Vec<serde_json::Value>,
    },
    /// `"type": "refusal"` — a refusal message.
    #[serde(rename = "refusal")]
    Refusal { refusal: String },
}
77
/// Reasoning options, serialized as the `reasoning` field of a [`Request`].
#[derive(Serialize, Debug)]
pub struct ReasoningConfig {
    /// Requested reasoning effort level.
    pub effort: ReasoningEffort,
}
82
/// A tool offered to the model, serialized with a `type` tag (`"function"` for the only variant).
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A callable function; optional fields are omitted from the JSON when `None`.
    Function {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        // Raw JSON value, passed through unchanged (presumably a JSON Schema — confirm with callers).
        #[serde(skip_serializing_if = "Option::is_none")]
        parameters: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        strict: Option<bool>,
    },
}
96
/// Error payload carried by `StreamEvent::Error` and `StreamEvent::GenericError`.
///
/// Only `message` is read; any other fields in the payload are ignored by serde.
#[derive(Deserialize, Debug)]
pub struct Error {
    /// Human-readable error message.
    pub message: String,
}
101
/// One event of a Responses API reply.
///
/// For streamed replies, each `data:` payload is deserialized into this enum using its
/// `type` field as the tag. For non-streaming replies, [`stream_response`] synthesizes
/// these events itself. Any unrecognized `type` deserializes to [`StreamEvent::Unknown`].
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
pub enum StreamEvent {
    /// `response.created`, carrying a response snapshot.
    #[serde(rename = "response.created")]
    Created { response: ResponseSummary },
    /// `response.in_progress`, carrying a response snapshot.
    #[serde(rename = "response.in_progress")]
    InProgress { response: ResponseSummary },
    /// `response.output_item.added`: an output item at `output_index` appeared.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {
        output_index: usize,
        // May be absent (and is `None` in synthesized non-streaming events).
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },
    /// `response.output_item.done`: the output item at `output_index` is finished.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {
        output_index: usize,
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },
    /// `response.content_part.added`; the part is kept as raw JSON.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded {
        item_id: String,
        output_index: usize,
        content_index: usize,
        part: Value,
    },
    /// `response.content_part.done`; the part is kept as raw JSON.
    #[serde(rename = "response.content_part.done")]
    ContentPartDone {
        item_id: String,
        output_index: usize,
        content_index: usize,
        part: Value,
    },
    /// `response.output_text.delta`: a fragment of text for the message item `item_id`.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        item_id: String,
        output_index: usize,
        // Optional; synthesized non-streaming events leave it `None`.
        #[serde(default)]
        content_index: Option<usize>,
        delta: String,
    },
    /// `response.output_text.done`: the full text of a content part.
    #[serde(rename = "response.output_text.done")]
    OutputTextDone {
        item_id: String,
        output_index: usize,
        #[serde(default)]
        content_index: Option<usize>,
        text: String,
    },
    /// `response.function_call_arguments.delta`: a fragment of a function call's arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgumentsDelta {
        item_id: String,
        output_index: usize,
        delta: String,
        #[serde(default)]
        sequence_number: Option<u64>,
    },
    /// `response.function_call_arguments.done`: the complete arguments string.
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgumentsDone {
        item_id: String,
        output_index: usize,
        arguments: String,
        #[serde(default)]
        sequence_number: Option<u64>,
    },
    /// `response.completed`, carrying the final response snapshot.
    #[serde(rename = "response.completed")]
    Completed { response: ResponseSummary },
    /// `response.incomplete`, carrying the final response snapshot.
    #[serde(rename = "response.incomplete")]
    Incomplete { response: ResponseSummary },
    /// `response.failed`, carrying the final response snapshot.
    #[serde(rename = "response.failed")]
    Failed { response: ResponseSummary },
    /// `response.error`, with the message nested under `error`.
    #[serde(rename = "response.error")]
    Error { error: Error },
    /// Bare `error` event, with the message nested under `error`.
    // NOTE(review): some providers may put `message` at the top level of an `error`
    // event instead of nesting it; such payloads would fail to parse here — confirm.
    #[serde(rename = "error")]
    GenericError { error: Error },
    /// Any event whose `type` is not listed above.
    #[serde(other)]
    Unknown,
}
182
183#[derive(Deserialize, Debug, Default, Clone)]
184pub struct ResponseSummary {
185 #[serde(default)]
186 pub id: Option<String>,
187 #[serde(default)]
188 pub status: Option<String>,
189 #[serde(default)]
190 pub status_details: Option<ResponseStatusDetails>,
191 #[serde(default)]
192 pub usage: Option<ResponseUsage>,
193 #[serde(default)]
194 pub output: Vec<ResponseOutputItem>,
195}
196
197#[derive(Deserialize, Debug, Default, Clone)]
198pub struct ResponseStatusDetails {
199 #[serde(default)]
200 pub reason: Option<String>,
201 #[serde(default)]
202 pub r#type: Option<String>,
203 #[serde(default)]
204 pub error: Option<Value>,
205}
206
207#[derive(Deserialize, Debug, Default, Clone)]
208pub struct ResponseUsage {
209 #[serde(default)]
210 pub input_tokens: Option<u64>,
211 #[serde(default)]
212 pub output_tokens: Option<u64>,
213 #[serde(default)]
214 pub total_tokens: Option<u64>,
215}
216
217#[derive(Deserialize, Debug, Clone)]
218#[serde(tag = "type", rename_all = "snake_case")]
219pub enum ResponseOutputItem {
220 Message(ResponseOutputMessage),
221 FunctionCall(ResponseFunctionToolCall),
222 #[serde(other)]
223 Unknown,
224}
225
/// A message output item. Every field is optional on the wire.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseOutputMessage {
    /// Item identifier; the non-streaming path in `stream_response` only emits text
    /// deltas when this is present.
    #[serde(default)]
    pub id: Option<String>,
    /// Content parts kept as raw JSON; `stream_response` reads each part's string `text` key.
    #[serde(default)]
    pub content: Vec<Value>,
    /// Author role as a plain string.
    #[serde(default)]
    pub role: Option<String>,
    /// Item status as a plain string.
    #[serde(default)]
    pub status: Option<String>,
}
237
/// A function-call output item. Every field is optional on the wire.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseFunctionToolCall {
    /// Item identifier; the non-streaming path in `stream_response` only emits an
    /// arguments event when this is present.
    #[serde(default)]
    pub id: Option<String>,
    /// Raw arguments string; an empty string when the field is absent.
    #[serde(default)]
    pub arguments: String,
    /// Call identifier.
    #[serde(default)]
    pub call_id: Option<String>,
    /// Name of the called function.
    #[serde(default)]
    pub name: Option<String>,
    /// Item status as a plain string.
    #[serde(default)]
    pub status: Option<String>,
}
251
252pub async fn stream_response(
253 client: &dyn HttpClient,
254 provider_name: &str,
255 api_url: &str,
256 api_key: &str,
257 request: Request,
258) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
259 let uri = format!("{api_url}/responses");
260 let request_builder = HttpRequest::builder()
261 .method(Method::POST)
262 .uri(uri)
263 .header("Content-Type", "application/json")
264 .header("Authorization", format!("Bearer {}", api_key.trim()));
265
266 let is_streaming = request.stream;
267 let request = request_builder
268 .body(AsyncBody::from(
269 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
270 ))
271 .map_err(|e| RequestError::Other(e.into()))?;
272
273 let mut response = client.send(request).await?;
274 if response.status().is_success() {
275 if is_streaming {
276 let reader = BufReader::new(response.into_body());
277 Ok(reader
278 .lines()
279 .filter_map(|line| async move {
280 match line {
281 Ok(line) => {
282 let line = line
283 .strip_prefix("data: ")
284 .or_else(|| line.strip_prefix("data:"))?;
285 if line == "[DONE]" || line.is_empty() {
286 None
287 } else {
288 match serde_json::from_str::<StreamEvent>(line) {
289 Ok(event) => Some(Ok(event)),
290 Err(error) => {
291 log::error!(
292 "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
293 error,
294 line,
295 );
296 Some(Err(anyhow!(error)))
297 }
298 }
299 }
300 }
301 Err(error) => Some(Err(anyhow!(error))),
302 }
303 })
304 .boxed())
305 } else {
306 let mut body = String::new();
307 response
308 .body_mut()
309 .read_to_string(&mut body)
310 .await
311 .map_err(|e| RequestError::Other(e.into()))?;
312
313 match serde_json::from_str::<ResponseSummary>(&body) {
314 Ok(response_summary) => {
315 let events = vec![
316 StreamEvent::Created {
317 response: response_summary.clone(),
318 },
319 StreamEvent::InProgress {
320 response: response_summary.clone(),
321 },
322 ];
323
324 let mut all_events = events;
325 for (output_index, item) in response_summary.output.iter().enumerate() {
326 all_events.push(StreamEvent::OutputItemAdded {
327 output_index,
328 sequence_number: None,
329 item: item.clone(),
330 });
331
332 match item {
333 ResponseOutputItem::Message(message) => {
334 for content_item in &message.content {
335 if let Some(text) = content_item.get("text") {
336 if let Some(text_str) = text.as_str() {
337 if let Some(ref item_id) = message.id {
338 all_events.push(StreamEvent::OutputTextDelta {
339 item_id: item_id.clone(),
340 output_index,
341 content_index: None,
342 delta: text_str.to_string(),
343 });
344 }
345 }
346 }
347 }
348 }
349 ResponseOutputItem::FunctionCall(function_call) => {
350 if let Some(ref item_id) = function_call.id {
351 all_events.push(StreamEvent::FunctionCallArgumentsDone {
352 item_id: item_id.clone(),
353 output_index,
354 arguments: function_call.arguments.clone(),
355 sequence_number: None,
356 });
357 }
358 }
359 ResponseOutputItem::Unknown => {}
360 }
361
362 all_events.push(StreamEvent::OutputItemDone {
363 output_index,
364 sequence_number: None,
365 item: item.clone(),
366 });
367 }
368
369 all_events.push(StreamEvent::Completed {
370 response: response_summary,
371 });
372
373 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
374 }
375 Err(error) => {
376 log::error!(
377 "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
378 error,
379 body,
380 );
381 Err(RequestError::Other(anyhow!(error)))
382 }
383 }
384 }
385 } else {
386 let mut body = String::new();
387 response
388 .body_mut()
389 .read_to_string(&mut body)
390 .await
391 .map_err(|e| RequestError::Other(e.into()))?;
392
393 Err(RequestError::HttpResponseError {
394 provider: provider_name.to_owned(),
395 status_code: response.status(),
396 body,
397 headers: response.headers().clone(),
398 })
399 }
400}