responses.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6
  7use crate::{ReasoningEffort, RequestError, ToolChoice};
  8
  9#[derive(Serialize, Debug)]
 10pub struct Request {
 11    pub model: String,
 12    #[serde(skip_serializing_if = "Vec::is_empty")]
 13    pub input: Vec<Value>,
 14    #[serde(default)]
 15    pub stream: bool,
 16    #[serde(skip_serializing_if = "Option::is_none")]
 17    pub temperature: Option<f32>,
 18    #[serde(skip_serializing_if = "Option::is_none")]
 19    pub top_p: Option<f32>,
 20    #[serde(skip_serializing_if = "Option::is_none")]
 21    pub max_output_tokens: Option<u64>,
 22    #[serde(skip_serializing_if = "Option::is_none")]
 23    pub parallel_tool_calls: Option<bool>,
 24    #[serde(skip_serializing_if = "Option::is_none")]
 25    pub tool_choice: Option<ToolChoice>,
 26    #[serde(skip_serializing_if = "Vec::is_empty")]
 27    pub tools: Vec<ToolDefinition>,
 28    #[serde(skip_serializing_if = "Option::is_none")]
 29    pub prompt_cache_key: Option<String>,
 30    #[serde(skip_serializing_if = "Option::is_none")]
 31    pub reasoning: Option<ReasoningConfig>,
 32}
 33
/// Reasoning settings sent under the `reasoning` key of a [`Request`].
///
/// Serializes as an object with a single `effort` field, whose JSON form is
/// whatever `ReasoningEffort` (defined elsewhere in the crate) serializes to.
#[derive(Serialize, Debug)]
pub struct ReasoningConfig {
    pub effort: ReasoningEffort,
}
 38
 39#[derive(Serialize, Debug)]
 40#[serde(tag = "type", rename_all = "snake_case")]
 41pub enum ToolDefinition {
 42    Function {
 43        name: String,
 44        #[serde(skip_serializing_if = "Option::is_none")]
 45        description: Option<String>,
 46        #[serde(skip_serializing_if = "Option::is_none")]
 47        parameters: Option<Value>,
 48        #[serde(skip_serializing_if = "Option::is_none")]
 49        strict: Option<bool>,
 50    },
 51}
 52
/// Error payload carried by [`StreamEvent::Error`] and [`StreamEvent::GenericError`].
///
/// Only `message` is captured; because `deny_unknown_fields` is not set, any
/// other keys in the error object are ignored during deserialization.
#[derive(Deserialize, Debug)]
pub struct Error {
    pub message: String,
}
 57
/// One event of a Responses API reply, discriminated by its `"type"` field.
///
/// `stream_response` produces these either by parsing each SSE `data:` line
/// (streaming requests) or by synthesizing them from a single
/// [`ResponseSummary`] (non-streaming requests). Any `type` not listed here
/// deserializes to [`StreamEvent::Unknown`] instead of failing.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
pub enum StreamEvent {
    /// `response.created`: carries the response object as created.
    #[serde(rename = "response.created")]
    Created { response: ResponseSummary },
    /// `response.in_progress`.
    #[serde(rename = "response.in_progress")]
    InProgress { response: ResponseSummary },
    /// `response.output_item.added`: a new output item at `output_index`.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {
        output_index: usize,
        // Absent on the wire -> None.
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },
    /// `response.output_item.done`: the finished form of the item at `output_index`.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {
        output_index: usize,
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },
    /// `response.content_part.added`; `part` is kept as raw JSON.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded {
        item_id: String,
        output_index: usize,
        content_index: usize,
        part: Value,
    },
    /// `response.content_part.done`; `part` is kept as raw JSON.
    #[serde(rename = "response.content_part.done")]
    ContentPartDone {
        item_id: String,
        output_index: usize,
        content_index: usize,
        part: Value,
    },
    /// `response.output_text.delta`: an incremental chunk of output text.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        item_id: String,
        output_index: usize,
        // Optional here (unlike the content_part events) so payloads without it still parse.
        #[serde(default)]
        content_index: Option<usize>,
        delta: String,
    },
    /// `response.output_text.done`: the full text of a content part.
    #[serde(rename = "response.output_text.done")]
    OutputTextDone {
        item_id: String,
        output_index: usize,
        #[serde(default)]
        content_index: Option<usize>,
        text: String,
    },
    /// `response.function_call_arguments.delta`: an incremental chunk of call arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgumentsDelta {
        item_id: String,
        output_index: usize,
        delta: String,
        #[serde(default)]
        sequence_number: Option<u64>,
    },
    /// `response.function_call_arguments.done`: the complete argument string.
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgumentsDone {
        item_id: String,
        output_index: usize,
        arguments: String,
        #[serde(default)]
        sequence_number: Option<u64>,
    },
    /// `response.completed`: terminal event for a finished response.
    #[serde(rename = "response.completed")]
    Completed { response: ResponseSummary },
    /// `response.incomplete`: terminal event for a response that stopped early.
    #[serde(rename = "response.incomplete")]
    Incomplete { response: ResponseSummary },
    /// `response.failed`: terminal event for a failed response.
    #[serde(rename = "response.failed")]
    Failed { response: ResponseSummary },
    /// `response.error`: expects the details nested under an `error` key.
    #[serde(rename = "response.error")]
    Error { error: Error },
    /// Bare `error` event; likewise expects the details nested under `error`.
    #[serde(rename = "error")]
    GenericError { error: Error },
    /// Any unrecognised `type`, so new event kinds do not break parsing.
    #[serde(other)]
    Unknown,
}
138
139#[derive(Deserialize, Debug, Default, Clone)]
140pub struct ResponseSummary {
141    #[serde(default)]
142    pub id: Option<String>,
143    #[serde(default)]
144    pub status: Option<String>,
145    #[serde(default)]
146    pub status_details: Option<ResponseStatusDetails>,
147    #[serde(default)]
148    pub usage: Option<ResponseUsage>,
149    #[serde(default)]
150    pub output: Vec<ResponseOutputItem>,
151}
152
153#[derive(Deserialize, Debug, Default, Clone)]
154pub struct ResponseStatusDetails {
155    #[serde(default)]
156    pub reason: Option<String>,
157    #[serde(default)]
158    pub r#type: Option<String>,
159    #[serde(default)]
160    pub error: Option<Value>,
161}
162
163#[derive(Deserialize, Debug, Default, Clone)]
164pub struct ResponseUsage {
165    #[serde(default)]
166    pub input_tokens: Option<u64>,
167    #[serde(default)]
168    pub output_tokens: Option<u64>,
169    #[serde(default)]
170    pub total_tokens: Option<u64>,
171}
172
173#[derive(Deserialize, Debug, Clone)]
174#[serde(tag = "type", rename_all = "snake_case")]
175pub enum ResponseOutputItem {
176    Message(ResponseOutputMessage),
177    FunctionCall(ResponseFunctionToolCall),
178    #[serde(other)]
179    Unknown,
180}
181
/// A `message` output item.
///
/// `content` is kept as raw JSON; `stream_response` reads the `text` key of
/// each content entry when replaying a non-streaming response. Every field
/// falls back to its default when absent.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseOutputMessage {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub content: Vec<Value>,
    #[serde(default)]
    pub role: Option<String>,
    #[serde(default)]
    pub status: Option<String>,
}
193
/// A `function_call` output item.
///
/// Every field falls back to its default when absent; in particular a missing
/// `arguments` becomes an empty string rather than a parse error.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseFunctionToolCall {
    // Item id; stream_response only emits argument events when this is present.
    #[serde(default)]
    pub id: Option<String>,
    // Argument string forwarded as-is; presumably JSON-encoded — TODO confirm with the API docs.
    #[serde(default)]
    pub arguments: String,
    #[serde(default)]
    pub call_id: Option<String>,
    #[serde(default)]
    pub name: Option<String>,
    #[serde(default)]
    pub status: Option<String>,
}
207
208pub async fn stream_response(
209    client: &dyn HttpClient,
210    provider_name: &str,
211    api_url: &str,
212    api_key: &str,
213    request: Request,
214) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
215    let uri = format!("{api_url}/responses");
216    let request_builder = HttpRequest::builder()
217        .method(Method::POST)
218        .uri(uri)
219        .header("Content-Type", "application/json")
220        .header("Authorization", format!("Bearer {}", api_key.trim()));
221
222    let is_streaming = request.stream;
223    let request = request_builder
224        .body(AsyncBody::from(
225            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
226        ))
227        .map_err(|e| RequestError::Other(e.into()))?;
228
229    let mut response = client.send(request).await?;
230    if response.status().is_success() {
231        if is_streaming {
232            let reader = BufReader::new(response.into_body());
233            Ok(reader
234                .lines()
235                .filter_map(|line| async move {
236                    match line {
237                        Ok(line) => {
238                            let line = line
239                                .strip_prefix("data: ")
240                                .or_else(|| line.strip_prefix("data:"))?;
241                            if line == "[DONE]" || line.is_empty() {
242                                None
243                            } else {
244                                match serde_json::from_str::<StreamEvent>(line) {
245                                    Ok(event) => Some(Ok(event)),
246                                    Err(error) => {
247                                        log::error!(
248                                            "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
249                                            error,
250                                            line,
251                                        );
252                                        Some(Err(anyhow!(error)))
253                                    }
254                                }
255                            }
256                        }
257                        Err(error) => Some(Err(anyhow!(error))),
258                    }
259                })
260                .boxed())
261        } else {
262            let mut body = String::new();
263            response
264                .body_mut()
265                .read_to_string(&mut body)
266                .await
267                .map_err(|e| RequestError::Other(e.into()))?;
268
269            match serde_json::from_str::<ResponseSummary>(&body) {
270                Ok(response_summary) => {
271                    let events = vec![
272                        StreamEvent::Created {
273                            response: response_summary.clone(),
274                        },
275                        StreamEvent::InProgress {
276                            response: response_summary.clone(),
277                        },
278                    ];
279
280                    let mut all_events = events;
281                    for (output_index, item) in response_summary.output.iter().enumerate() {
282                        all_events.push(StreamEvent::OutputItemAdded {
283                            output_index,
284                            sequence_number: None,
285                            item: item.clone(),
286                        });
287
288                        match item {
289                            ResponseOutputItem::Message(message) => {
290                                for content_item in &message.content {
291                                    if let Some(text) = content_item.get("text") {
292                                        if let Some(text_str) = text.as_str() {
293                                            if let Some(ref item_id) = message.id {
294                                                all_events.push(StreamEvent::OutputTextDelta {
295                                                    item_id: item_id.clone(),
296                                                    output_index,
297                                                    content_index: None,
298                                                    delta: text_str.to_string(),
299                                                });
300                                            }
301                                        }
302                                    }
303                                }
304                            }
305                            ResponseOutputItem::FunctionCall(function_call) => {
306                                if let Some(ref item_id) = function_call.id {
307                                    all_events.push(StreamEvent::FunctionCallArgumentsDone {
308                                        item_id: item_id.clone(),
309                                        output_index,
310                                        arguments: function_call.arguments.clone(),
311                                        sequence_number: None,
312                                    });
313                                }
314                            }
315                            ResponseOutputItem::Unknown => {}
316                        }
317
318                        all_events.push(StreamEvent::OutputItemDone {
319                            output_index,
320                            sequence_number: None,
321                            item: item.clone(),
322                        });
323                    }
324
325                    all_events.push(StreamEvent::Completed {
326                        response: response_summary,
327                    });
328
329                    Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
330                }
331                Err(error) => {
332                    log::error!(
333                        "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
334                        error,
335                        body,
336                    );
337                    Err(RequestError::Other(anyhow!(error)))
338                }
339            }
340        }
341    } else {
342        let mut body = String::new();
343        response
344            .body_mut()
345            .read_to_string(&mut body)
346            .await
347            .map_err(|e| RequestError::Other(e.into()))?;
348
349        Err(RequestError::HttpResponseError {
350            provider: provider_name.to_owned(),
351            status_code: response.status(),
352            body,
353            headers: response.headers().clone(),
354        })
355    }
356}