responses.rs

  1use std::sync::Arc;
  2
  3use super::{ChatLocation, copilot_request_headers};
  4use anyhow::{Result, anyhow};
  5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  6use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
  7use serde::{Deserialize, Serialize};
  8use serde_json::Value;
  9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 10
 11#[derive(Serialize, Debug)]
 12pub struct Request {
 13    pub model: String,
 14    pub input: Vec<ResponseInputItem>,
 15    #[serde(default)]
 16    pub stream: bool,
 17    #[serde(skip_serializing_if = "Option::is_none")]
 18    pub temperature: Option<f32>,
 19    #[serde(skip_serializing_if = "Vec::is_empty")]
 20    pub tools: Vec<ToolDefinition>,
 21    #[serde(skip_serializing_if = "Option::is_none")]
 22    pub tool_choice: Option<ToolChoice>,
 23    #[serde(skip_serializing_if = "Option::is_none")]
 24    pub reasoning: Option<ReasoningConfig>,
 25    #[serde(skip_serializing_if = "Option::is_none")]
 26    pub include: Option<Vec<ResponseIncludable>>,
 27    pub store: bool,
 28}
 29
 30#[derive(Serialize, Deserialize, Debug, Clone)]
 31#[serde(rename_all = "snake_case")]
 32pub enum ResponseIncludable {
 33    #[serde(rename = "reasoning.encrypted_content")]
 34    ReasoningEncryptedContent,
 35}
 36
/// A tool the model may call, serialized as an object tagged by `"type"`
/// (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A function tool. Optional fields are omitted from the JSON when `None`.
    Function {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        /// Arbitrary JSON; presumably a JSON Schema for the arguments — confirm against the API docs.
        #[serde(skip_serializing_if = "Option::is_none")]
        parameters: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        strict: Option<bool>,
    },
}
 50
/// Value of [`Request::tool_choice`].
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"any"` and `"none"`.
/// `Other` is untagged, so it serializes as the wrapped [`ToolDefinition`] object itself.
// NOTE(review): OpenAI's Responses API documents `"required"` rather than `"any"`; confirm the
// Copilot endpoint accepts `"any"`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
 60
/// Requested reasoning-summary mode, serialized as `"auto"`, `"concise"` or `"detailed"`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningSummary {
    Auto,
    Concise,
    Detailed,
}
 68
/// Reasoning options sent as [`Request::reasoning`].
#[derive(Serialize, Debug)]
pub struct ReasoningConfig {
    /// Effort level, using the `OpenAiReasoningEffort` type re-exported from `settings`.
    pub effort: ReasoningEffort,
    /// Summary mode; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<ReasoningSummary>,
}
 75
/// Detail level attached to an input image, serialized as `"low"`, `"high"` or `"auto"`.
///
/// Defaults to `Auto`. That default is also used when the `detail` field is missing while
/// deserializing [`ResponseInputContent::InputImage`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResponseImageDetail {
    Low,
    High,
    #[default]
    Auto,
}
 84
/// A content part of an input message or of a function-call output, tagged by `"type"`
/// (`"input_text"`, `"output_text"` or `"input_image"`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputContent {
    InputText {
        text: String,
    },
    OutputText {
        text: String,
    },
    /// An image part. Its presence makes `stream_response` send the vision-request header.
    InputImage {
        /// Omitted from the JSON when `None`; presumably a URL or data URI — confirm with callers.
        #[serde(skip_serializing_if = "Option::is_none")]
        image_url: Option<String>,
        #[serde(default)]
        detail: ResponseImageDetail,
    },
}
101
/// Status of an input or output item, serialized as `"in_progress"`, `"completed"` or
/// `"incomplete"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ItemStatus {
    InProgress,
    Completed,
    Incomplete,
}
109
/// Output of a function call. It is serialized untagged, as either a bare JSON string or an array
/// of content parts.
///
/// When deserializing, variants are tried in declaration order, so a JSON string becomes `Text`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum ResponseFunctionOutput {
    Text(String),
    Content(Vec<ResponseInputContent>),
}
116
/// One item of [`Request::input`], tagged by `"type"` (`"message"`, `"function_call"`,
/// `"function_call_output"` or `"reasoning"`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputItem {
    /// A conversation message; `content` and `status` are omitted from the JSON when `None`.
    Message {
        role: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ResponseInputContent>>,
        // NOTE(review): this is a raw string while the other variants use `ItemStatus`; confirm
        // whether that difference is intentional.
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    /// A function call made by the model; `arguments` is kept as a raw string.
    FunctionCall {
        call_id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<ItemStatus>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        thought_signature: Option<String>,
    },
    /// The result of a function call, linked to it through `call_id`.
    FunctionCallOutput {
        call_id: String,
        output: ResponseFunctionOutput,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<ItemStatus>,
    },
    /// Reasoning content sent back as input; `encrypted_content` is always serialized.
    /// Presumably echoed from an earlier `ResponseOutputItem::Reasoning` — confirm with callers.
    Reasoning {
        #[serde(skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        summary: Vec<ResponseReasoningItem>,
        encrypted_content: String,
    },
}
149
150#[derive(Deserialize, Debug, Clone)]
151#[serde(rename_all = "snake_case")]
152pub enum IncompleteReason {
153    #[serde(rename = "max_output_tokens")]
154    MaxOutputTokens,
155    #[serde(rename = "content_filter")]
156    ContentFilter,
157}
158
159#[derive(Deserialize, Debug, Clone)]
160pub struct IncompleteDetails {
161    #[serde(skip_serializing_if = "Option::is_none")]
162    pub reason: Option<IncompleteReason>,
163}
164
/// One reasoning-summary part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseReasoningItem {
    /// The part's type, (de)serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub kind: String,
    pub text: String,
}
171
/// A server-sent event from the Responses stream, selected by its `"type"` field.
///
/// Unrecognized `"type"` values deserialize to [`StreamEvent::Unknown`] instead of failing.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
pub enum StreamEvent {
    /// An `"error"` event whose payload is nested under `error`.
    // NOTE(review): OpenAI's Responses stream emits `code`/`message` at the top level of
    // `error` events; confirm Copilot nests them under `error`.
    #[serde(rename = "error")]
    GenericError { error: ResponseError },

    /// The response object was created.
    #[serde(rename = "response.created")]
    Created { response: Response },

    /// A new output item started at `output_index`.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {
        output_index: usize,
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },

    /// A text fragment for the message item `item_id`.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        item_id: String,
        output_index: usize,
        delta: String,
    },

    /// The output item at `output_index` is finished; `item` carries its final contents.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {
        output_index: usize,
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },

    /// Terminal event: the response ended incomplete.
    #[serde(rename = "response.incomplete")]
    Incomplete { response: Response },

    /// Terminal event: the response completed.
    #[serde(rename = "response.completed")]
    Completed { response: Response },

    /// Terminal event: the response failed.
    #[serde(rename = "response.failed")]
    Failed { response: Response },

    /// Any event type not listed above.
    #[serde(other)]
    Unknown,
}
216
/// Error payload reported by the server. Both `code` and `message` are required when
/// deserializing.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseError {
    pub code: String,
    pub message: String,
}
222
223#[derive(Deserialize, Debug, Default, Clone)]
224pub struct Response {
225    pub id: Option<String>,
226    pub status: Option<String>,
227    pub usage: Option<ResponseUsage>,
228    pub output: Vec<ResponseOutputItem>,
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub incomplete_details: Option<IncompleteDetails>,
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub error: Option<ResponseError>,
233}
234
/// Token counts reported for a response; each count is `None` when the server omits it.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct ResponseUsage {
    pub input_tokens: Option<u64>,
    pub output_tokens: Option<u64>,
    pub total_tokens: Option<u64>,
}
241
242#[derive(Deserialize, Debug, Clone)]
243#[serde(tag = "type", rename_all = "snake_case")]
244pub enum ResponseOutputItem {
245    Message {
246        id: String,
247        role: String,
248        #[serde(skip_serializing_if = "Option::is_none")]
249        content: Option<Vec<ResponseOutputContent>>,
250    },
251    FunctionCall {
252        #[serde(skip_serializing_if = "Option::is_none")]
253        id: Option<String>,
254        call_id: String,
255        name: String,
256        arguments: String,
257        #[serde(skip_serializing_if = "Option::is_none")]
258        status: Option<ItemStatus>,
259        #[serde(default, skip_serializing_if = "Option::is_none")]
260        thought_signature: Option<String>,
261    },
262    Reasoning {
263        id: String,
264        #[serde(skip_serializing_if = "Option::is_none")]
265        summary: Option<Vec<ResponseReasoningItem>>,
266        #[serde(skip_serializing_if = "Option::is_none")]
267        encrypted_content: Option<String>,
268    },
269}
270
/// A content part of an output message, tagged by `"type"` (`"output_text"` or `"refusal"`).
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseOutputContent {
    OutputText { text: String },
    Refusal { refusal: String },
}
277
278pub async fn stream_response(
279    client: Arc<dyn HttpClient>,
280    oauth_token: String,
281    api_url: String,
282    request: Request,
283    is_user_initiated: bool,
284    location: ChatLocation,
285) -> Result<BoxStream<'static, Result<StreamEvent>>> {
286    let is_vision_request = request.input.iter().any(|item| match item {
287        ResponseInputItem::Message {
288            content: Some(parts),
289            ..
290        } => parts
291            .iter()
292            .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
293        _ => false,
294    });
295
296    let request_builder = copilot_request_headers(
297        HttpRequest::builder().method(Method::POST).uri(&api_url),
298        &oauth_token,
299        Some(is_user_initiated),
300        Some(location),
301    )
302    .when(is_vision_request, |builder| {
303        builder.header("Copilot-Vision-Request", "true")
304    });
305
306    let is_streaming = request.stream;
307    let json = serde_json::to_string(&request)?;
308    let request = request_builder.body(AsyncBody::from(json))?;
309    let mut response = client.send(request).await?;
310
311    if !response.status().is_success() {
312        let mut body = String::new();
313        response.body_mut().read_to_string(&mut body).await?;
314        anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
315    }
316
317    if is_streaming {
318        let reader = BufReader::new(response.into_body());
319        Ok(reader
320            .lines()
321            .filter_map(|line| async move {
322                match line {
323                    Ok(line) => {
324                        let line = line.strip_prefix("data: ")?;
325                        if line.starts_with("[DONE]") || line.is_empty() {
326                            return None;
327                        }
328
329                        match serde_json::from_str::<StreamEvent>(line) {
330                            Ok(event) => Some(Ok(event)),
331                            Err(error) => {
332                                log::error!(
333                                    "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
334                                    error,
335                                    line,
336                                );
337                                Some(Err(anyhow!(error)))
338                            }
339                        }
340                    }
341                    Err(error) => Some(Err(anyhow!(error))),
342                }
343            })
344            .boxed())
345    } else {
346        // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
347        // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
348        let mut body = String::new();
349        response.body_mut().read_to_string(&mut body).await?;
350
351        match serde_json::from_str::<Response>(&body) {
352            Ok(response) => {
353                let events = vec![StreamEvent::Created {
354                    response: response.clone(),
355                }];
356
357                let mut all_events = events;
358                for (output_index, item) in response.output.iter().enumerate() {
359                    all_events.push(StreamEvent::OutputItemAdded {
360                        output_index,
361                        sequence_number: None,
362                        item: item.clone(),
363                    });
364
365                    if let ResponseOutputItem::Message {
366                        id,
367                        content: Some(content),
368                        ..
369                    } = item
370                    {
371                        for part in content {
372                            if let ResponseOutputContent::OutputText { text } = part {
373                                all_events.push(StreamEvent::OutputTextDelta {
374                                    item_id: id.clone(),
375                                    output_index,
376                                    delta: text.clone(),
377                                });
378                            }
379                        }
380                    }
381
382                    all_events.push(StreamEvent::OutputItemDone {
383                        output_index,
384                        sequence_number: None,
385                        item: item.clone(),
386                    });
387                }
388
389                let final_event = if response.error.is_some() {
390                    StreamEvent::Failed { response }
391                } else if response.incomplete_details.is_some() {
392                    StreamEvent::Incomplete { response }
393                } else {
394                    StreamEvent::Completed { response }
395                };
396                all_events.push(final_event);
397
398                Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
399            }
400            Err(error) => {
401                log::error!(
402                    "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
403                    error,
404                    body,
405                );
406                Err(anyhow!(error))
407            }
408        }
409    }
410}