copilot_responses.rs

  1use super::*;
  2use anyhow::{Result, anyhow};
  3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  4use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  5use serde::{Deserialize, Serialize};
  6use serde_json::Value;
  7pub use settings::OpenAiReasoningEffort as ReasoningEffort;
  8
  9#[derive(Serialize, Debug)]
 10pub struct Request {
 11    pub model: String,
 12    pub input: Vec<ResponseInputItem>,
 13    #[serde(default)]
 14    pub stream: bool,
 15    #[serde(skip_serializing_if = "Option::is_none")]
 16    pub temperature: Option<f32>,
 17    #[serde(skip_serializing_if = "Vec::is_empty")]
 18    pub tools: Vec<ToolDefinition>,
 19    #[serde(skip_serializing_if = "Option::is_none")]
 20    pub tool_choice: Option<ToolChoice>,
 21    #[serde(skip_serializing_if = "Option::is_none")]
 22    pub reasoning: Option<ReasoningConfig>,
 23    #[serde(skip_serializing_if = "Option::is_none")]
 24    pub include: Option<Vec<ResponseIncludable>>,
 25}
 26
 27#[derive(Serialize, Deserialize, Debug, Clone)]
 28#[serde(rename_all = "snake_case")]
 29pub enum ResponseIncludable {
 30    #[serde(rename = "reasoning.encrypted_content")]
 31    ReasoningEncryptedContent,
 32}
 33
 34#[derive(Serialize, Deserialize, Debug)]
 35#[serde(tag = "type", rename_all = "snake_case")]
 36pub enum ToolDefinition {
 37    Function {
 38        name: String,
 39        #[serde(skip_serializing_if = "Option::is_none")]
 40        description: Option<String>,
 41        #[serde(skip_serializing_if = "Option::is_none")]
 42        parameters: Option<Value>,
 43        #[serde(skip_serializing_if = "Option::is_none")]
 44        strict: Option<bool>,
 45    },
 46}
 47
 48#[derive(Serialize, Deserialize, Debug)]
 49#[serde(rename_all = "lowercase")]
 50pub enum ToolChoice {
 51    Auto,
 52    Any,
 53    None,
 54    #[serde(untagged)]
 55    Other(ToolDefinition),
 56}
 57
 58#[derive(Serialize, Deserialize, Debug)]
 59#[serde(rename_all = "lowercase")]
 60pub enum ReasoningSummary {
 61    Auto,
 62    Concise,
 63    Detailed,
 64}
 65
 66#[derive(Serialize, Debug)]
 67pub struct ReasoningConfig {
 68    pub effort: ReasoningEffort,
 69    #[serde(skip_serializing_if = "Option::is_none")]
 70    pub summary: Option<ReasoningSummary>,
 71}
 72
 73#[derive(Serialize, Deserialize, Debug, Clone, Default)]
 74#[serde(rename_all = "snake_case")]
 75pub enum ResponseImageDetail {
 76    Low,
 77    High,
 78    #[default]
 79    Auto,
 80}
 81
 82#[derive(Serialize, Deserialize, Debug, Clone)]
 83#[serde(tag = "type", rename_all = "snake_case")]
 84pub enum ResponseInputContent {
 85    InputText {
 86        text: String,
 87    },
 88    OutputText {
 89        text: String,
 90    },
 91    InputImage {
 92        #[serde(skip_serializing_if = "Option::is_none")]
 93        image_url: Option<String>,
 94        #[serde(default)]
 95        detail: ResponseImageDetail,
 96    },
 97}
 98
 99#[derive(Serialize, Deserialize, Debug, Clone)]
100#[serde(rename_all = "snake_case")]
101pub enum ItemStatus {
102    InProgress,
103    Completed,
104    Incomplete,
105}
106
107#[derive(Serialize, Deserialize, Debug, Clone)]
108#[serde(untagged)]
109pub enum ResponseFunctionOutput {
110    Text(String),
111    Content(Vec<ResponseInputContent>),
112}
113
114#[derive(Serialize, Deserialize, Debug, Clone)]
115#[serde(tag = "type", rename_all = "snake_case")]
116pub enum ResponseInputItem {
117    Message {
118        role: String,
119        #[serde(skip_serializing_if = "Option::is_none")]
120        content: Option<Vec<ResponseInputContent>>,
121        #[serde(skip_serializing_if = "Option::is_none")]
122        status: Option<String>,
123    },
124    FunctionCall {
125        call_id: String,
126        name: String,
127        arguments: String,
128        #[serde(skip_serializing_if = "Option::is_none")]
129        status: Option<ItemStatus>,
130    },
131    FunctionCallOutput {
132        call_id: String,
133        output: ResponseFunctionOutput,
134        #[serde(skip_serializing_if = "Option::is_none")]
135        status: Option<ItemStatus>,
136    },
137    Reasoning {
138        #[serde(skip_serializing_if = "Option::is_none")]
139        id: Option<String>,
140        summary: Vec<ResponseReasoningItem>,
141        encrypted_content: String,
142    },
143}
144
145#[derive(Deserialize, Debug, Clone)]
146#[serde(rename_all = "snake_case")]
147pub enum IncompleteReason {
148    #[serde(rename = "max_output_tokens")]
149    MaxOutputTokens,
150    #[serde(rename = "content_filter")]
151    ContentFilter,
152}
153
154#[derive(Deserialize, Debug, Clone)]
155pub struct IncompleteDetails {
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub reason: Option<IncompleteReason>,
158}
159
160#[derive(Serialize, Deserialize, Debug, Clone)]
161pub struct ResponseReasoningItem {
162    #[serde(rename = "type")]
163    pub kind: String,
164    pub text: String,
165}
166
167#[derive(Deserialize, Debug)]
168#[serde(tag = "type")]
169pub enum StreamEvent {
170    #[serde(rename = "error")]
171    GenericError { error: ResponseError },
172
173    #[serde(rename = "response.created")]
174    Created { response: Response },
175
176    #[serde(rename = "response.output_item.added")]
177    OutputItemAdded {
178        output_index: usize,
179        #[serde(default)]
180        sequence_number: Option<u64>,
181        item: ResponseOutputItem,
182    },
183
184    #[serde(rename = "response.output_text.delta")]
185    OutputTextDelta {
186        item_id: String,
187        output_index: usize,
188        delta: String,
189    },
190
191    #[serde(rename = "response.output_item.done")]
192    OutputItemDone {
193        output_index: usize,
194        #[serde(default)]
195        sequence_number: Option<u64>,
196        item: ResponseOutputItem,
197    },
198
199    #[serde(rename = "response.incomplete")]
200    Incomplete { response: Response },
201
202    #[serde(rename = "response.completed")]
203    Completed { response: Response },
204
205    #[serde(rename = "response.failed")]
206    Failed { response: Response },
207
208    #[serde(other)]
209    Unknown,
210}
211
212#[derive(Deserialize, Debug, Clone)]
213pub struct ResponseError {
214    pub code: String,
215    pub message: String,
216}
217
218#[derive(Deserialize, Debug, Default, Clone)]
219pub struct Response {
220    pub id: Option<String>,
221    pub status: Option<String>,
222    pub usage: Option<ResponseUsage>,
223    pub output: Vec<ResponseOutputItem>,
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub incomplete_details: Option<IncompleteDetails>,
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub error: Option<ResponseError>,
228}
229
230#[derive(Deserialize, Debug, Default, Clone)]
231pub struct ResponseUsage {
232    pub input_tokens: Option<u64>,
233    pub output_tokens: Option<u64>,
234    pub total_tokens: Option<u64>,
235}
236
237#[derive(Deserialize, Debug, Clone)]
238#[serde(tag = "type", rename_all = "snake_case")]
239pub enum ResponseOutputItem {
240    Message {
241        id: String,
242        role: String,
243        #[serde(skip_serializing_if = "Option::is_none")]
244        content: Option<Vec<ResponseOutputContent>>,
245    },
246    FunctionCall {
247        #[serde(skip_serializing_if = "Option::is_none")]
248        id: Option<String>,
249        call_id: String,
250        name: String,
251        arguments: String,
252        #[serde(skip_serializing_if = "Option::is_none")]
253        status: Option<ItemStatus>,
254    },
255    Reasoning {
256        id: String,
257        #[serde(skip_serializing_if = "Option::is_none")]
258        summary: Option<Vec<ResponseReasoningItem>>,
259        #[serde(skip_serializing_if = "Option::is_none")]
260        encrypted_content: Option<String>,
261    },
262}
263
264#[derive(Deserialize, Debug, Clone)]
265#[serde(tag = "type", rename_all = "snake_case")]
266pub enum ResponseOutputContent {
267    OutputText { text: String },
268    Refusal { refusal: String },
269}
270
271pub async fn stream_response(
272    client: Arc<dyn HttpClient>,
273    api_key: String,
274    api_url: String,
275    request: Request,
276    is_user_initiated: bool,
277) -> Result<BoxStream<'static, Result<StreamEvent>>> {
278    let is_vision_request = request.input.iter().any(|item| match item {
279        ResponseInputItem::Message {
280            content: Some(parts),
281            ..
282        } => parts
283            .iter()
284            .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
285        _ => false,
286    });
287
288    let request_initiator = if is_user_initiated { "user" } else { "agent" };
289
290    let request_builder = HttpRequest::builder()
291        .method(Method::POST)
292        .uri(&api_url)
293        .header(
294            "Editor-Version",
295            format!(
296                "Zed/{}",
297                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
298            ),
299        )
300        .header("Authorization", format!("Bearer {}", api_key))
301        .header("Content-Type", "application/json")
302        .header("Copilot-Integration-Id", "vscode-chat")
303        .header("X-Initiator", request_initiator);
304
305    let request_builder = if is_vision_request {
306        request_builder.header("Copilot-Vision-Request", "true")
307    } else {
308        request_builder
309    };
310
311    let is_streaming = request.stream;
312    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
313    let mut response = client.send(request).await?;
314
315    if !response.status().is_success() {
316        let mut body = String::new();
317        response.body_mut().read_to_string(&mut body).await?;
318        anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
319    }
320
321    if is_streaming {
322        let reader = BufReader::new(response.into_body());
323        Ok(reader
324            .lines()
325            .filter_map(|line| async move {
326                match line {
327                    Ok(line) => {
328                        let line = line.strip_prefix("data: ")?;
329                        if line.starts_with("[DONE]") || line.is_empty() {
330                            return None;
331                        }
332
333                        match serde_json::from_str::<StreamEvent>(line) {
334                            Ok(event) => Some(Ok(event)),
335                            Err(error) => {
336                                log::error!(
337                                    "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
338                                    error,
339                                    line,
340                                );
341                                Some(Err(anyhow!(error)))
342                            }
343                        }
344                    }
345                    Err(error) => Some(Err(anyhow!(error))),
346                }
347            })
348            .boxed())
349    } else {
350        // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
351        // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
352        let mut body = String::new();
353        response.body_mut().read_to_string(&mut body).await?;
354
355        match serde_json::from_str::<Response>(&body) {
356            Ok(response) => {
357                let events = vec![StreamEvent::Created {
358                    response: response.clone(),
359                }];
360
361                let mut all_events = events;
362                for (output_index, item) in response.output.iter().enumerate() {
363                    all_events.push(StreamEvent::OutputItemAdded {
364                        output_index,
365                        sequence_number: None,
366                        item: item.clone(),
367                    });
368
369                    if let ResponseOutputItem::Message {
370                        id,
371                        content: Some(content),
372                        ..
373                    } = item
374                    {
375                        for part in content {
376                            if let ResponseOutputContent::OutputText { text } = part {
377                                all_events.push(StreamEvent::OutputTextDelta {
378                                    item_id: id.clone(),
379                                    output_index,
380                                    delta: text.clone(),
381                                });
382                            }
383                        }
384                    }
385
386                    all_events.push(StreamEvent::OutputItemDone {
387                        output_index,
388                        sequence_number: None,
389                        item: item.clone(),
390                    });
391                }
392
393                let final_event = if response.error.is_some() {
394                    StreamEvent::Failed { response }
395                } else if response.incomplete_details.is_some() {
396                    StreamEvent::Incomplete { response }
397                } else {
398                    StreamEvent::Completed { response }
399                };
400                all_events.push(final_event);
401
402                Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
403            }
404            Err(error) => {
405                log::error!(
406                    "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
407                    error,
408                    body,
409                );
410                Err(anyhow!(error))
411            }
412        }
413    }
414}