responses.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6
  7use crate::{ReasoningEffort, RequestError, Role, ToolChoice};
  8
  9#[derive(Serialize, Debug)]
 10pub struct Request {
 11    pub model: String,
 12    #[serde(skip_serializing_if = "Vec::is_empty")]
 13    pub input: Vec<ResponseInputItem>,
 14    #[serde(default)]
 15    pub stream: bool,
 16    #[serde(skip_serializing_if = "Option::is_none")]
 17    pub temperature: Option<f32>,
 18    #[serde(skip_serializing_if = "Option::is_none")]
 19    pub top_p: Option<f32>,
 20    #[serde(skip_serializing_if = "Option::is_none")]
 21    pub max_output_tokens: Option<u64>,
 22    #[serde(skip_serializing_if = "Option::is_none")]
 23    pub parallel_tool_calls: Option<bool>,
 24    #[serde(skip_serializing_if = "Option::is_none")]
 25    pub tool_choice: Option<ToolChoice>,
 26    #[serde(skip_serializing_if = "Vec::is_empty")]
 27    pub tools: Vec<ToolDefinition>,
 28    #[serde(skip_serializing_if = "Option::is_none")]
 29    pub prompt_cache_key: Option<String>,
 30    #[serde(skip_serializing_if = "Option::is_none")]
 31    pub reasoning: Option<ReasoningConfig>,
 32}
 33
 34#[derive(Debug, Serialize, Deserialize)]
 35#[serde(tag = "type", rename_all = "snake_case")]
 36pub enum ResponseInputItem {
 37    Message(ResponseMessageItem),
 38    FunctionCall(ResponseFunctionCallItem),
 39    FunctionCallOutput(ResponseFunctionCallOutputItem),
 40}
 41
 42#[derive(Debug, Serialize, Deserialize)]
 43pub struct ResponseMessageItem {
 44    pub role: Role,
 45    pub content: Vec<ResponseInputContent>,
 46}
 47
 48#[derive(Debug, Serialize, Deserialize)]
 49pub struct ResponseFunctionCallItem {
 50    pub call_id: String,
 51    pub name: String,
 52    pub arguments: String,
 53}
 54
 55#[derive(Debug, Serialize, Deserialize)]
 56pub struct ResponseFunctionCallOutputItem {
 57    pub call_id: String,
 58    pub output: String,
 59}
 60
 61#[derive(Debug, Clone, Serialize, Deserialize)]
 62#[serde(tag = "type")]
 63pub enum ResponseInputContent {
 64    #[serde(rename = "input_text")]
 65    Text { text: String },
 66    #[serde(rename = "input_image")]
 67    Image { image_url: String },
 68    #[serde(rename = "output_text")]
 69    OutputText {
 70        text: String,
 71        #[serde(default)]
 72        annotations: Vec<serde_json::Value>,
 73    },
 74    #[serde(rename = "refusal")]
 75    Refusal { refusal: String },
 76}
 77
 78#[derive(Serialize, Debug)]
 79pub struct ReasoningConfig {
 80    pub effort: ReasoningEffort,
 81}
 82
 83#[derive(Serialize, Debug)]
 84#[serde(tag = "type", rename_all = "snake_case")]
 85pub enum ToolDefinition {
 86    Function {
 87        name: String,
 88        #[serde(skip_serializing_if = "Option::is_none")]
 89        description: Option<String>,
 90        #[serde(skip_serializing_if = "Option::is_none")]
 91        parameters: Option<Value>,
 92        #[serde(skip_serializing_if = "Option::is_none")]
 93        strict: Option<bool>,
 94    },
 95}
 96
 97#[derive(Deserialize, Debug)]
 98pub struct Error {
 99    pub message: String,
100}
101
102#[derive(Deserialize, Debug)]
103#[serde(tag = "type")]
104pub enum StreamEvent {
105    #[serde(rename = "response.created")]
106    Created { response: ResponseSummary },
107    #[serde(rename = "response.in_progress")]
108    InProgress { response: ResponseSummary },
109    #[serde(rename = "response.output_item.added")]
110    OutputItemAdded {
111        output_index: usize,
112        #[serde(default)]
113        sequence_number: Option<u64>,
114        item: ResponseOutputItem,
115    },
116    #[serde(rename = "response.output_item.done")]
117    OutputItemDone {
118        output_index: usize,
119        #[serde(default)]
120        sequence_number: Option<u64>,
121        item: ResponseOutputItem,
122    },
123    #[serde(rename = "response.content_part.added")]
124    ContentPartAdded {
125        item_id: String,
126        output_index: usize,
127        content_index: usize,
128        part: Value,
129    },
130    #[serde(rename = "response.content_part.done")]
131    ContentPartDone {
132        item_id: String,
133        output_index: usize,
134        content_index: usize,
135        part: Value,
136    },
137    #[serde(rename = "response.output_text.delta")]
138    OutputTextDelta {
139        item_id: String,
140        output_index: usize,
141        #[serde(default)]
142        content_index: Option<usize>,
143        delta: String,
144    },
145    #[serde(rename = "response.output_text.done")]
146    OutputTextDone {
147        item_id: String,
148        output_index: usize,
149        #[serde(default)]
150        content_index: Option<usize>,
151        text: String,
152    },
153    #[serde(rename = "response.function_call_arguments.delta")]
154    FunctionCallArgumentsDelta {
155        item_id: String,
156        output_index: usize,
157        delta: String,
158        #[serde(default)]
159        sequence_number: Option<u64>,
160    },
161    #[serde(rename = "response.function_call_arguments.done")]
162    FunctionCallArgumentsDone {
163        item_id: String,
164        output_index: usize,
165        arguments: String,
166        #[serde(default)]
167        sequence_number: Option<u64>,
168    },
169    #[serde(rename = "response.completed")]
170    Completed { response: ResponseSummary },
171    #[serde(rename = "response.incomplete")]
172    Incomplete { response: ResponseSummary },
173    #[serde(rename = "response.failed")]
174    Failed { response: ResponseSummary },
175    #[serde(rename = "response.error")]
176    Error { error: Error },
177    #[serde(rename = "error")]
178    GenericError { error: Error },
179    #[serde(other)]
180    Unknown,
181}
182
183#[derive(Deserialize, Debug, Default, Clone)]
184pub struct ResponseSummary {
185    #[serde(default)]
186    pub id: Option<String>,
187    #[serde(default)]
188    pub status: Option<String>,
189    #[serde(default)]
190    pub status_details: Option<ResponseStatusDetails>,
191    #[serde(default)]
192    pub usage: Option<ResponseUsage>,
193    #[serde(default)]
194    pub output: Vec<ResponseOutputItem>,
195}
196
197#[derive(Deserialize, Debug, Default, Clone)]
198pub struct ResponseStatusDetails {
199    #[serde(default)]
200    pub reason: Option<String>,
201    #[serde(default)]
202    pub r#type: Option<String>,
203    #[serde(default)]
204    pub error: Option<Value>,
205}
206
207#[derive(Deserialize, Debug, Default, Clone)]
208pub struct ResponseUsage {
209    #[serde(default)]
210    pub input_tokens: Option<u64>,
211    #[serde(default)]
212    pub output_tokens: Option<u64>,
213    #[serde(default)]
214    pub total_tokens: Option<u64>,
215}
216
217#[derive(Deserialize, Debug, Clone)]
218#[serde(tag = "type", rename_all = "snake_case")]
219pub enum ResponseOutputItem {
220    Message(ResponseOutputMessage),
221    FunctionCall(ResponseFunctionToolCall),
222    #[serde(other)]
223    Unknown,
224}
225
226#[derive(Deserialize, Debug, Clone)]
227pub struct ResponseOutputMessage {
228    #[serde(default)]
229    pub id: Option<String>,
230    #[serde(default)]
231    pub content: Vec<Value>,
232    #[serde(default)]
233    pub role: Option<String>,
234    #[serde(default)]
235    pub status: Option<String>,
236}
237
238#[derive(Deserialize, Debug, Clone)]
239pub struct ResponseFunctionToolCall {
240    #[serde(default)]
241    pub id: Option<String>,
242    #[serde(default)]
243    pub arguments: String,
244    #[serde(default)]
245    pub call_id: Option<String>,
246    #[serde(default)]
247    pub name: Option<String>,
248    #[serde(default)]
249    pub status: Option<String>,
250}
251
252pub async fn stream_response(
253    client: &dyn HttpClient,
254    provider_name: &str,
255    api_url: &str,
256    api_key: &str,
257    request: Request,
258) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
259    let uri = format!("{api_url}/responses");
260    let request_builder = HttpRequest::builder()
261        .method(Method::POST)
262        .uri(uri)
263        .header("Content-Type", "application/json")
264        .header("Authorization", format!("Bearer {}", api_key.trim()));
265
266    let is_streaming = request.stream;
267    let request = request_builder
268        .body(AsyncBody::from(
269            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
270        ))
271        .map_err(|e| RequestError::Other(e.into()))?;
272
273    let mut response = client.send(request).await?;
274    if response.status().is_success() {
275        if is_streaming {
276            let reader = BufReader::new(response.into_body());
277            Ok(reader
278                .lines()
279                .filter_map(|line| async move {
280                    match line {
281                        Ok(line) => {
282                            let line = line
283                                .strip_prefix("data: ")
284                                .or_else(|| line.strip_prefix("data:"))?;
285                            if line == "[DONE]" || line.is_empty() {
286                                None
287                            } else {
288                                match serde_json::from_str::<StreamEvent>(line) {
289                                    Ok(event) => Some(Ok(event)),
290                                    Err(error) => {
291                                        log::error!(
292                                            "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
293                                            error,
294                                            line,
295                                        );
296                                        Some(Err(anyhow!(error)))
297                                    }
298                                }
299                            }
300                        }
301                        Err(error) => Some(Err(anyhow!(error))),
302                    }
303                })
304                .boxed())
305        } else {
306            let mut body = String::new();
307            response
308                .body_mut()
309                .read_to_string(&mut body)
310                .await
311                .map_err(|e| RequestError::Other(e.into()))?;
312
313            match serde_json::from_str::<ResponseSummary>(&body) {
314                Ok(response_summary) => {
315                    let events = vec![
316                        StreamEvent::Created {
317                            response: response_summary.clone(),
318                        },
319                        StreamEvent::InProgress {
320                            response: response_summary.clone(),
321                        },
322                    ];
323
324                    let mut all_events = events;
325                    for (output_index, item) in response_summary.output.iter().enumerate() {
326                        all_events.push(StreamEvent::OutputItemAdded {
327                            output_index,
328                            sequence_number: None,
329                            item: item.clone(),
330                        });
331
332                        match item {
333                            ResponseOutputItem::Message(message) => {
334                                for content_item in &message.content {
335                                    if let Some(text) = content_item.get("text") {
336                                        if let Some(text_str) = text.as_str() {
337                                            if let Some(ref item_id) = message.id {
338                                                all_events.push(StreamEvent::OutputTextDelta {
339                                                    item_id: item_id.clone(),
340                                                    output_index,
341                                                    content_index: None,
342                                                    delta: text_str.to_string(),
343                                                });
344                                            }
345                                        }
346                                    }
347                                }
348                            }
349                            ResponseOutputItem::FunctionCall(function_call) => {
350                                if let Some(ref item_id) = function_call.id {
351                                    all_events.push(StreamEvent::FunctionCallArgumentsDone {
352                                        item_id: item_id.clone(),
353                                        output_index,
354                                        arguments: function_call.arguments.clone(),
355                                        sequence_number: None,
356                                    });
357                                }
358                            }
359                            ResponseOutputItem::Unknown => {}
360                        }
361
362                        all_events.push(StreamEvent::OutputItemDone {
363                            output_index,
364                            sequence_number: None,
365                            item: item.clone(),
366                        });
367                    }
368
369                    all_events.push(StreamEvent::Completed {
370                        response: response_summary,
371                    });
372
373                    Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
374                }
375                Err(error) => {
376                    log::error!(
377                        "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
378                        error,
379                        body,
380                    );
381                    Err(RequestError::Other(anyhow!(error)))
382                }
383            }
384        }
385    } else {
386        let mut body = String::new();
387        response
388            .body_mut()
389            .read_to_string(&mut body)
390            .await
391            .map_err(|e| RequestError::Other(e.into()))?;
392
393        Err(RequestError::HttpResponseError {
394            provider: provider_name.to_owned(),
395            status_code: response.status(),
396            body,
397            headers: response.headers().clone(),
398        })
399    }
400}