responses.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use serde::{Deserialize, Serialize};
  7use serde_json::Value;
  8pub use settings::OpenAiReasoningEffort as ReasoningEffort;
  9
 10#[derive(Serialize, Debug)]
 11pub struct Request {
 12    pub model: String,
 13    pub input: Vec<ResponseInputItem>,
 14    #[serde(default)]
 15    pub stream: bool,
 16    #[serde(skip_serializing_if = "Option::is_none")]
 17    pub temperature: Option<f32>,
 18    #[serde(skip_serializing_if = "Vec::is_empty")]
 19    pub tools: Vec<ToolDefinition>,
 20    #[serde(skip_serializing_if = "Option::is_none")]
 21    pub tool_choice: Option<ToolChoice>,
 22    #[serde(skip_serializing_if = "Option::is_none")]
 23    pub reasoning: Option<ReasoningConfig>,
 24    #[serde(skip_serializing_if = "Option::is_none")]
 25    pub include: Option<Vec<ResponseIncludable>>,
 26}
 27
 28#[derive(Serialize, Deserialize, Debug, Clone)]
 29#[serde(rename_all = "snake_case")]
 30pub enum ResponseIncludable {
 31    #[serde(rename = "reasoning.encrypted_content")]
 32    ReasoningEncryptedContent,
 33}
 34
 35#[derive(Serialize, Deserialize, Debug)]
 36#[serde(tag = "type", rename_all = "snake_case")]
 37pub enum ToolDefinition {
 38    Function {
 39        name: String,
 40        #[serde(skip_serializing_if = "Option::is_none")]
 41        description: Option<String>,
 42        #[serde(skip_serializing_if = "Option::is_none")]
 43        parameters: Option<Value>,
 44        #[serde(skip_serializing_if = "Option::is_none")]
 45        strict: Option<bool>,
 46    },
 47}
 48
 49#[derive(Serialize, Deserialize, Debug)]
 50#[serde(rename_all = "lowercase")]
 51pub enum ToolChoice {
 52    Auto,
 53    Any,
 54    None,
 55    #[serde(untagged)]
 56    Other(ToolDefinition),
 57}
 58
 59#[derive(Serialize, Deserialize, Debug)]
 60#[serde(rename_all = "lowercase")]
 61pub enum ReasoningSummary {
 62    Auto,
 63    Concise,
 64    Detailed,
 65}
 66
 67#[derive(Serialize, Debug)]
 68pub struct ReasoningConfig {
 69    pub effort: ReasoningEffort,
 70    #[serde(skip_serializing_if = "Option::is_none")]
 71    pub summary: Option<ReasoningSummary>,
 72}
 73
 74#[derive(Serialize, Deserialize, Debug, Clone, Default)]
 75#[serde(rename_all = "snake_case")]
 76pub enum ResponseImageDetail {
 77    Low,
 78    High,
 79    #[default]
 80    Auto,
 81}
 82
 83#[derive(Serialize, Deserialize, Debug, Clone)]
 84#[serde(tag = "type", rename_all = "snake_case")]
 85pub enum ResponseInputContent {
 86    InputText {
 87        text: String,
 88    },
 89    OutputText {
 90        text: String,
 91    },
 92    InputImage {
 93        #[serde(skip_serializing_if = "Option::is_none")]
 94        image_url: Option<String>,
 95        #[serde(default)]
 96        detail: ResponseImageDetail,
 97    },
 98}
 99
100#[derive(Serialize, Deserialize, Debug, Clone)]
101#[serde(rename_all = "snake_case")]
102pub enum ItemStatus {
103    InProgress,
104    Completed,
105    Incomplete,
106}
107
108#[derive(Serialize, Deserialize, Debug, Clone)]
109#[serde(untagged)]
110pub enum ResponseFunctionOutput {
111    Text(String),
112    Content(Vec<ResponseInputContent>),
113}
114
115#[derive(Serialize, Deserialize, Debug, Clone)]
116#[serde(tag = "type", rename_all = "snake_case")]
117pub enum ResponseInputItem {
118    Message {
119        role: String,
120        #[serde(skip_serializing_if = "Option::is_none")]
121        content: Option<Vec<ResponseInputContent>>,
122        #[serde(skip_serializing_if = "Option::is_none")]
123        status: Option<String>,
124    },
125    FunctionCall {
126        call_id: String,
127        name: String,
128        arguments: String,
129        #[serde(skip_serializing_if = "Option::is_none")]
130        status: Option<ItemStatus>,
131        #[serde(default, skip_serializing_if = "Option::is_none")]
132        thought_signature: Option<String>,
133    },
134    FunctionCallOutput {
135        call_id: String,
136        output: ResponseFunctionOutput,
137        #[serde(skip_serializing_if = "Option::is_none")]
138        status: Option<ItemStatus>,
139    },
140    Reasoning {
141        #[serde(skip_serializing_if = "Option::is_none")]
142        id: Option<String>,
143        summary: Vec<ResponseReasoningItem>,
144        encrypted_content: String,
145    },
146}
147
148#[derive(Deserialize, Debug, Clone)]
149#[serde(rename_all = "snake_case")]
150pub enum IncompleteReason {
151    #[serde(rename = "max_output_tokens")]
152    MaxOutputTokens,
153    #[serde(rename = "content_filter")]
154    ContentFilter,
155}
156
157#[derive(Deserialize, Debug, Clone)]
158pub struct IncompleteDetails {
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub reason: Option<IncompleteReason>,
161}
162
163#[derive(Serialize, Deserialize, Debug, Clone)]
164pub struct ResponseReasoningItem {
165    #[serde(rename = "type")]
166    pub kind: String,
167    pub text: String,
168}
169
170#[derive(Deserialize, Debug)]
171#[serde(tag = "type")]
172pub enum StreamEvent {
173    #[serde(rename = "error")]
174    GenericError { error: ResponseError },
175
176    #[serde(rename = "response.created")]
177    Created { response: Response },
178
179    #[serde(rename = "response.output_item.added")]
180    OutputItemAdded {
181        output_index: usize,
182        #[serde(default)]
183        sequence_number: Option<u64>,
184        item: ResponseOutputItem,
185    },
186
187    #[serde(rename = "response.output_text.delta")]
188    OutputTextDelta {
189        item_id: String,
190        output_index: usize,
191        delta: String,
192    },
193
194    #[serde(rename = "response.output_item.done")]
195    OutputItemDone {
196        output_index: usize,
197        #[serde(default)]
198        sequence_number: Option<u64>,
199        item: ResponseOutputItem,
200    },
201
202    #[serde(rename = "response.incomplete")]
203    Incomplete { response: Response },
204
205    #[serde(rename = "response.completed")]
206    Completed { response: Response },
207
208    #[serde(rename = "response.failed")]
209    Failed { response: Response },
210
211    #[serde(other)]
212    Unknown,
213}
214
215#[derive(Deserialize, Debug, Clone)]
216pub struct ResponseError {
217    pub code: String,
218    pub message: String,
219}
220
221#[derive(Deserialize, Debug, Default, Clone)]
222pub struct Response {
223    pub id: Option<String>,
224    pub status: Option<String>,
225    pub usage: Option<ResponseUsage>,
226    pub output: Vec<ResponseOutputItem>,
227    #[serde(skip_serializing_if = "Option::is_none")]
228    pub incomplete_details: Option<IncompleteDetails>,
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub error: Option<ResponseError>,
231}
232
233#[derive(Deserialize, Debug, Default, Clone)]
234pub struct ResponseUsage {
235    pub input_tokens: Option<u64>,
236    pub output_tokens: Option<u64>,
237    pub total_tokens: Option<u64>,
238}
239
240#[derive(Deserialize, Debug, Clone)]
241#[serde(tag = "type", rename_all = "snake_case")]
242pub enum ResponseOutputItem {
243    Message {
244        id: String,
245        role: String,
246        #[serde(skip_serializing_if = "Option::is_none")]
247        content: Option<Vec<ResponseOutputContent>>,
248    },
249    FunctionCall {
250        #[serde(skip_serializing_if = "Option::is_none")]
251        id: Option<String>,
252        call_id: String,
253        name: String,
254        arguments: String,
255        #[serde(skip_serializing_if = "Option::is_none")]
256        status: Option<ItemStatus>,
257        #[serde(default, skip_serializing_if = "Option::is_none")]
258        thought_signature: Option<String>,
259    },
260    Reasoning {
261        id: String,
262        #[serde(skip_serializing_if = "Option::is_none")]
263        summary: Option<Vec<ResponseReasoningItem>>,
264        #[serde(skip_serializing_if = "Option::is_none")]
265        encrypted_content: Option<String>,
266    },
267}
268
269#[derive(Deserialize, Debug, Clone)]
270#[serde(tag = "type", rename_all = "snake_case")]
271pub enum ResponseOutputContent {
272    OutputText { text: String },
273    Refusal { refusal: String },
274}
275
276pub async fn stream_response(
277    client: Arc<dyn HttpClient>,
278    api_key: String,
279    api_url: String,
280    request: Request,
281    is_user_initiated: bool,
282) -> Result<BoxStream<'static, Result<StreamEvent>>> {
283    let is_vision_request = request.input.iter().any(|item| match item {
284        ResponseInputItem::Message {
285            content: Some(parts),
286            ..
287        } => parts
288            .iter()
289            .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
290        _ => false,
291    });
292
293    let request_initiator = if is_user_initiated { "user" } else { "agent" };
294
295    let request_builder = HttpRequest::builder()
296        .method(Method::POST)
297        .uri(&api_url)
298        .header(
299            "Editor-Version",
300            format!(
301                "Zed/{}",
302                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
303            ),
304        )
305        .header("Authorization", format!("Bearer {}", api_key))
306        .header("Content-Type", "application/json")
307        .header("Copilot-Integration-Id", "vscode-chat")
308        .header("X-Initiator", request_initiator);
309
310    let request_builder = if is_vision_request {
311        request_builder.header("Copilot-Vision-Request", "true")
312    } else {
313        request_builder
314    };
315
316    let is_streaming = request.stream;
317    let json = serde_json::to_string(&request)?;
318    let request = request_builder.body(AsyncBody::from(json))?;
319    let mut response = client.send(request).await?;
320
321    if !response.status().is_success() {
322        let mut body = String::new();
323        response.body_mut().read_to_string(&mut body).await?;
324        anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
325    }
326
327    if is_streaming {
328        let reader = BufReader::new(response.into_body());
329        Ok(reader
330            .lines()
331            .filter_map(|line| async move {
332                match line {
333                    Ok(line) => {
334                        let line = line.strip_prefix("data: ")?;
335                        if line.starts_with("[DONE]") || line.is_empty() {
336                            return None;
337                        }
338
339                        match serde_json::from_str::<StreamEvent>(line) {
340                            Ok(event) => Some(Ok(event)),
341                            Err(error) => {
342                                log::error!(
343                                    "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
344                                    error,
345                                    line,
346                                );
347                                Some(Err(anyhow!(error)))
348                            }
349                        }
350                    }
351                    Err(error) => Some(Err(anyhow!(error))),
352                }
353            })
354            .boxed())
355    } else {
356        // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
357        // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
358        let mut body = String::new();
359        response.body_mut().read_to_string(&mut body).await?;
360
361        match serde_json::from_str::<Response>(&body) {
362            Ok(response) => {
363                let events = vec![StreamEvent::Created {
364                    response: response.clone(),
365                }];
366
367                let mut all_events = events;
368                for (output_index, item) in response.output.iter().enumerate() {
369                    all_events.push(StreamEvent::OutputItemAdded {
370                        output_index,
371                        sequence_number: None,
372                        item: item.clone(),
373                    });
374
375                    if let ResponseOutputItem::Message {
376                        id,
377                        content: Some(content),
378                        ..
379                    } = item
380                    {
381                        for part in content {
382                            if let ResponseOutputContent::OutputText { text } = part {
383                                all_events.push(StreamEvent::OutputTextDelta {
384                                    item_id: id.clone(),
385                                    output_index,
386                                    delta: text.clone(),
387                                });
388                            }
389                        }
390                    }
391
392                    all_events.push(StreamEvent::OutputItemDone {
393                        output_index,
394                        sequence_number: None,
395                        item: item.clone(),
396                    });
397                }
398
399                let final_event = if response.error.is_some() {
400                    StreamEvent::Failed { response }
401                } else if response.incomplete_details.is_some() {
402                    StreamEvent::Incomplete { response }
403                } else {
404                    StreamEvent::Completed { response }
405                };
406                all_events.push(final_event);
407
408                Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
409            }
410            Err(error) => {
411                log::error!(
412                    "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
413                    error,
414                    body,
415                );
416                Err(anyhow!(error))
417            }
418        }
419    }
420}