responses.rs

  1use std::sync::Arc;
  2
  3use super::copilot_request_headers;
  4use anyhow::{Result, anyhow};
  5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  7use serde::{Deserialize, Serialize};
  8use serde_json::Value;
  9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 10
 11#[derive(Serialize, Debug)]
 12pub struct Request {
 13    pub model: String,
 14    pub input: Vec<ResponseInputItem>,
 15    #[serde(default)]
 16    pub stream: bool,
 17    #[serde(skip_serializing_if = "Option::is_none")]
 18    pub temperature: Option<f32>,
 19    #[serde(skip_serializing_if = "Vec::is_empty")]
 20    pub tools: Vec<ToolDefinition>,
 21    #[serde(skip_serializing_if = "Option::is_none")]
 22    pub tool_choice: Option<ToolChoice>,
 23    #[serde(skip_serializing_if = "Option::is_none")]
 24    pub reasoning: Option<ReasoningConfig>,
 25    #[serde(skip_serializing_if = "Option::is_none")]
 26    pub include: Option<Vec<ResponseIncludable>>,
 27}
 28
 29#[derive(Serialize, Deserialize, Debug, Clone)]
 30#[serde(rename_all = "snake_case")]
 31pub enum ResponseIncludable {
 32    #[serde(rename = "reasoning.encrypted_content")]
 33    ReasoningEncryptedContent,
 34}
 35
 36#[derive(Serialize, Deserialize, Debug)]
 37#[serde(tag = "type", rename_all = "snake_case")]
 38pub enum ToolDefinition {
 39    Function {
 40        name: String,
 41        #[serde(skip_serializing_if = "Option::is_none")]
 42        description: Option<String>,
 43        #[serde(skip_serializing_if = "Option::is_none")]
 44        parameters: Option<Value>,
 45        #[serde(skip_serializing_if = "Option::is_none")]
 46        strict: Option<bool>,
 47    },
 48}
 49
 50#[derive(Serialize, Deserialize, Debug)]
 51#[serde(rename_all = "lowercase")]
 52pub enum ToolChoice {
 53    Auto,
 54    Any,
 55    None,
 56    #[serde(untagged)]
 57    Other(ToolDefinition),
 58}
 59
 60#[derive(Serialize, Deserialize, Debug)]
 61#[serde(rename_all = "lowercase")]
 62pub enum ReasoningSummary {
 63    Auto,
 64    Concise,
 65    Detailed,
 66}
 67
 68#[derive(Serialize, Debug)]
 69pub struct ReasoningConfig {
 70    pub effort: ReasoningEffort,
 71    #[serde(skip_serializing_if = "Option::is_none")]
 72    pub summary: Option<ReasoningSummary>,
 73}
 74
 75#[derive(Serialize, Deserialize, Debug, Clone, Default)]
 76#[serde(rename_all = "snake_case")]
 77pub enum ResponseImageDetail {
 78    Low,
 79    High,
 80    #[default]
 81    Auto,
 82}
 83
 84#[derive(Serialize, Deserialize, Debug, Clone)]
 85#[serde(tag = "type", rename_all = "snake_case")]
 86pub enum ResponseInputContent {
 87    InputText {
 88        text: String,
 89    },
 90    OutputText {
 91        text: String,
 92    },
 93    InputImage {
 94        #[serde(skip_serializing_if = "Option::is_none")]
 95        image_url: Option<String>,
 96        #[serde(default)]
 97        detail: ResponseImageDetail,
 98    },
 99}
100
101#[derive(Serialize, Deserialize, Debug, Clone)]
102#[serde(rename_all = "snake_case")]
103pub enum ItemStatus {
104    InProgress,
105    Completed,
106    Incomplete,
107}
108
109#[derive(Serialize, Deserialize, Debug, Clone)]
110#[serde(untagged)]
111pub enum ResponseFunctionOutput {
112    Text(String),
113    Content(Vec<ResponseInputContent>),
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
117#[serde(tag = "type", rename_all = "snake_case")]
118pub enum ResponseInputItem {
119    Message {
120        role: String,
121        #[serde(skip_serializing_if = "Option::is_none")]
122        content: Option<Vec<ResponseInputContent>>,
123        #[serde(skip_serializing_if = "Option::is_none")]
124        status: Option<String>,
125    },
126    FunctionCall {
127        call_id: String,
128        name: String,
129        arguments: String,
130        #[serde(skip_serializing_if = "Option::is_none")]
131        status: Option<ItemStatus>,
132        #[serde(default, skip_serializing_if = "Option::is_none")]
133        thought_signature: Option<String>,
134    },
135    FunctionCallOutput {
136        call_id: String,
137        output: ResponseFunctionOutput,
138        #[serde(skip_serializing_if = "Option::is_none")]
139        status: Option<ItemStatus>,
140    },
141    Reasoning {
142        #[serde(skip_serializing_if = "Option::is_none")]
143        id: Option<String>,
144        summary: Vec<ResponseReasoningItem>,
145        encrypted_content: String,
146    },
147}
148
149#[derive(Deserialize, Debug, Clone)]
150#[serde(rename_all = "snake_case")]
151pub enum IncompleteReason {
152    #[serde(rename = "max_output_tokens")]
153    MaxOutputTokens,
154    #[serde(rename = "content_filter")]
155    ContentFilter,
156}
157
158#[derive(Deserialize, Debug, Clone)]
159pub struct IncompleteDetails {
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub reason: Option<IncompleteReason>,
162}
163
164#[derive(Serialize, Deserialize, Debug, Clone)]
165pub struct ResponseReasoningItem {
166    #[serde(rename = "type")]
167    pub kind: String,
168    pub text: String,
169}
170
171#[derive(Deserialize, Debug)]
172#[serde(tag = "type")]
173pub enum StreamEvent {
174    #[serde(rename = "error")]
175    GenericError { error: ResponseError },
176
177    #[serde(rename = "response.created")]
178    Created { response: Response },
179
180    #[serde(rename = "response.output_item.added")]
181    OutputItemAdded {
182        output_index: usize,
183        #[serde(default)]
184        sequence_number: Option<u64>,
185        item: ResponseOutputItem,
186    },
187
188    #[serde(rename = "response.output_text.delta")]
189    OutputTextDelta {
190        item_id: String,
191        output_index: usize,
192        delta: String,
193    },
194
195    #[serde(rename = "response.output_item.done")]
196    OutputItemDone {
197        output_index: usize,
198        #[serde(default)]
199        sequence_number: Option<u64>,
200        item: ResponseOutputItem,
201    },
202
203    #[serde(rename = "response.incomplete")]
204    Incomplete { response: Response },
205
206    #[serde(rename = "response.completed")]
207    Completed { response: Response },
208
209    #[serde(rename = "response.failed")]
210    Failed { response: Response },
211
212    #[serde(other)]
213    Unknown,
214}
215
216#[derive(Deserialize, Debug, Clone)]
217pub struct ResponseError {
218    pub code: String,
219    pub message: String,
220}
221
222#[derive(Deserialize, Debug, Default, Clone)]
223pub struct Response {
224    pub id: Option<String>,
225    pub status: Option<String>,
226    pub usage: Option<ResponseUsage>,
227    pub output: Vec<ResponseOutputItem>,
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub incomplete_details: Option<IncompleteDetails>,
230    #[serde(skip_serializing_if = "Option::is_none")]
231    pub error: Option<ResponseError>,
232}
233
234#[derive(Deserialize, Debug, Default, Clone)]
235pub struct ResponseUsage {
236    pub input_tokens: Option<u64>,
237    pub output_tokens: Option<u64>,
238    pub total_tokens: Option<u64>,
239}
240
241#[derive(Deserialize, Debug, Clone)]
242#[serde(tag = "type", rename_all = "snake_case")]
243pub enum ResponseOutputItem {
244    Message {
245        id: String,
246        role: String,
247        #[serde(skip_serializing_if = "Option::is_none")]
248        content: Option<Vec<ResponseOutputContent>>,
249    },
250    FunctionCall {
251        #[serde(skip_serializing_if = "Option::is_none")]
252        id: Option<String>,
253        call_id: String,
254        name: String,
255        arguments: String,
256        #[serde(skip_serializing_if = "Option::is_none")]
257        status: Option<ItemStatus>,
258        #[serde(default, skip_serializing_if = "Option::is_none")]
259        thought_signature: Option<String>,
260    },
261    Reasoning {
262        id: String,
263        #[serde(skip_serializing_if = "Option::is_none")]
264        summary: Option<Vec<ResponseReasoningItem>>,
265        #[serde(skip_serializing_if = "Option::is_none")]
266        encrypted_content: Option<String>,
267    },
268}
269
270#[derive(Deserialize, Debug, Clone)]
271#[serde(tag = "type", rename_all = "snake_case")]
272pub enum ResponseOutputContent {
273    OutputText { text: String },
274    Refusal { refusal: String },
275}
276
277pub async fn stream_response(
278    client: Arc<dyn HttpClient>,
279    oauth_token: String,
280    api_url: String,
281    request: Request,
282    is_user_initiated: bool,
283) -> Result<BoxStream<'static, Result<StreamEvent>>> {
284    let is_vision_request = request.input.iter().any(|item| match item {
285        ResponseInputItem::Message {
286            content: Some(parts),
287            ..
288        } => parts
289            .iter()
290            .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
291        _ => false,
292    });
293
294    let request_builder = copilot_request_headers(
295        HttpRequest::builder().method(Method::POST).uri(&api_url),
296        &oauth_token,
297        Some(is_user_initiated),
298    );
299
300    let request_builder = if is_vision_request {
301        request_builder.header("Copilot-Vision-Request", "true")
302    } else {
303        request_builder
304    };
305
306    let is_streaming = request.stream;
307    let json = serde_json::to_string(&request)?;
308    let request = request_builder.body(AsyncBody::from(json))?;
309    let mut response = client.send(request).await?;
310
311    if !response.status().is_success() {
312        let mut body = String::new();
313        response.body_mut().read_to_string(&mut body).await?;
314        anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
315    }
316
317    if is_streaming {
318        let reader = BufReader::new(response.into_body());
319        Ok(reader
320            .lines()
321            .filter_map(|line| async move {
322                match line {
323                    Ok(line) => {
324                        let line = line.strip_prefix("data: ")?;
325                        if line.starts_with("[DONE]") || line.is_empty() {
326                            return None;
327                        }
328
329                        match serde_json::from_str::<StreamEvent>(line) {
330                            Ok(event) => Some(Ok(event)),
331                            Err(error) => {
332                                log::error!(
333                                    "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
334                                    error,
335                                    line,
336                                );
337                                Some(Err(anyhow!(error)))
338                            }
339                        }
340                    }
341                    Err(error) => Some(Err(anyhow!(error))),
342                }
343            })
344            .boxed())
345    } else {
346        // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
347        // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
348        let mut body = String::new();
349        response.body_mut().read_to_string(&mut body).await?;
350
351        match serde_json::from_str::<Response>(&body) {
352            Ok(response) => {
353                let events = vec![StreamEvent::Created {
354                    response: response.clone(),
355                }];
356
357                let mut all_events = events;
358                for (output_index, item) in response.output.iter().enumerate() {
359                    all_events.push(StreamEvent::OutputItemAdded {
360                        output_index,
361                        sequence_number: None,
362                        item: item.clone(),
363                    });
364
365                    if let ResponseOutputItem::Message {
366                        id,
367                        content: Some(content),
368                        ..
369                    } = item
370                    {
371                        for part in content {
372                            if let ResponseOutputContent::OutputText { text } = part {
373                                all_events.push(StreamEvent::OutputTextDelta {
374                                    item_id: id.clone(),
375                                    output_index,
376                                    delta: text.clone(),
377                                });
378                            }
379                        }
380                    }
381
382                    all_events.push(StreamEvent::OutputItemDone {
383                        output_index,
384                        sequence_number: None,
385                        item: item.clone(),
386                    });
387                }
388
389                let final_event = if response.error.is_some() {
390                    StreamEvent::Failed { response }
391                } else if response.incomplete_details.is_some() {
392                    StreamEvent::Incomplete { response }
393                } else {
394                    StreamEvent::Completed { response }
395                };
396                all_events.push(final_event);
397
398                Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
399            }
400            Err(error) => {
401                log::error!(
402                    "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
403                    error,
404                    body,
405                );
406                Err(anyhow!(error))
407            }
408        }
409    }
410}