copilot_responses.rs

  1use super::*;
  2use anyhow::{Result, anyhow};
  3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  4use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  5use serde::{Deserialize, Serialize};
  6use serde_json::Value;
  7pub use settings::OpenAiReasoningEffort as ReasoningEffort;
  8
  9#[derive(Serialize, Debug)]
 10pub struct Request {
 11    pub model: String,
 12    pub input: Vec<ResponseInputItem>,
 13    #[serde(default)]
 14    pub stream: bool,
 15    #[serde(skip_serializing_if = "Option::is_none")]
 16    pub temperature: Option<f32>,
 17    #[serde(skip_serializing_if = "Vec::is_empty")]
 18    pub tools: Vec<ToolDefinition>,
 19    #[serde(skip_serializing_if = "Option::is_none")]
 20    pub tool_choice: Option<ToolChoice>,
 21    #[serde(skip_serializing_if = "Option::is_none")]
 22    pub reasoning: Option<ReasoningConfig>,
 23    #[serde(skip_serializing_if = "Option::is_none")]
 24    pub include: Option<Vec<ResponseIncludable>>,
 25}
 26
 27#[derive(Serialize, Deserialize, Debug, Clone)]
 28#[serde(rename_all = "snake_case")]
 29pub enum ResponseIncludable {
 30    #[serde(rename = "reasoning.encrypted_content")]
 31    ReasoningEncryptedContent,
 32}
 33
 34#[derive(Serialize, Deserialize, Debug)]
 35#[serde(tag = "type", rename_all = "snake_case")]
 36pub enum ToolDefinition {
 37    Function {
 38        name: String,
 39        #[serde(skip_serializing_if = "Option::is_none")]
 40        description: Option<String>,
 41        #[serde(skip_serializing_if = "Option::is_none")]
 42        parameters: Option<Value>,
 43        #[serde(skip_serializing_if = "Option::is_none")]
 44        strict: Option<bool>,
 45    },
 46}
 47
 48#[derive(Serialize, Deserialize, Debug)]
 49#[serde(rename_all = "lowercase")]
 50pub enum ToolChoice {
 51    Auto,
 52    Any,
 53    None,
 54    #[serde(untagged)]
 55    Other(ToolDefinition),
 56}
 57
 58#[derive(Serialize, Deserialize, Debug)]
 59#[serde(rename_all = "lowercase")]
 60pub enum ReasoningSummary {
 61    Auto,
 62    Concise,
 63    Detailed,
 64}
 65
 66#[derive(Serialize, Debug)]
 67pub struct ReasoningConfig {
 68    pub effort: ReasoningEffort,
 69    #[serde(skip_serializing_if = "Option::is_none")]
 70    pub summary: Option<ReasoningSummary>,
 71}
 72
 73#[derive(Serialize, Deserialize, Debug, Clone, Default)]
 74#[serde(rename_all = "snake_case")]
 75pub enum ResponseImageDetail {
 76    Low,
 77    High,
 78    #[default]
 79    Auto,
 80}
 81
 82#[derive(Serialize, Deserialize, Debug, Clone)]
 83#[serde(tag = "type", rename_all = "snake_case")]
 84pub enum ResponseInputContent {
 85    InputText {
 86        text: String,
 87    },
 88    OutputText {
 89        text: String,
 90    },
 91    InputImage {
 92        #[serde(skip_serializing_if = "Option::is_none")]
 93        image_url: Option<String>,
 94        #[serde(default)]
 95        detail: ResponseImageDetail,
 96    },
 97}
 98
 99#[derive(Serialize, Deserialize, Debug, Clone)]
100#[serde(rename_all = "snake_case")]
101pub enum ItemStatus {
102    InProgress,
103    Completed,
104    Incomplete,
105}
106
107#[derive(Serialize, Deserialize, Debug, Clone)]
108#[serde(untagged)]
109pub enum ResponseFunctionOutput {
110    Text(String),
111    Content(Vec<ResponseInputContent>),
112}
113
114#[derive(Serialize, Deserialize, Debug, Clone)]
115#[serde(tag = "type", rename_all = "snake_case")]
116pub enum ResponseInputItem {
117    Message {
118        role: String,
119        #[serde(skip_serializing_if = "Option::is_none")]
120        content: Option<Vec<ResponseInputContent>>,
121        #[serde(skip_serializing_if = "Option::is_none")]
122        status: Option<String>,
123    },
124    FunctionCall {
125        call_id: String,
126        name: String,
127        arguments: String,
128        #[serde(skip_serializing_if = "Option::is_none")]
129        status: Option<ItemStatus>,
130        #[serde(default, skip_serializing_if = "Option::is_none")]
131        thought_signature: Option<String>,
132    },
133    FunctionCallOutput {
134        call_id: String,
135        output: ResponseFunctionOutput,
136        #[serde(skip_serializing_if = "Option::is_none")]
137        status: Option<ItemStatus>,
138    },
139    Reasoning {
140        #[serde(skip_serializing_if = "Option::is_none")]
141        id: Option<String>,
142        summary: Vec<ResponseReasoningItem>,
143        encrypted_content: String,
144    },
145}
146
147#[derive(Deserialize, Debug, Clone)]
148#[serde(rename_all = "snake_case")]
149pub enum IncompleteReason {
150    #[serde(rename = "max_output_tokens")]
151    MaxOutputTokens,
152    #[serde(rename = "content_filter")]
153    ContentFilter,
154}
155
156#[derive(Deserialize, Debug, Clone)]
157pub struct IncompleteDetails {
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub reason: Option<IncompleteReason>,
160}
161
162#[derive(Serialize, Deserialize, Debug, Clone)]
163pub struct ResponseReasoningItem {
164    #[serde(rename = "type")]
165    pub kind: String,
166    pub text: String,
167}
168
169#[derive(Deserialize, Debug)]
170#[serde(tag = "type")]
171pub enum StreamEvent {
172    #[serde(rename = "error")]
173    GenericError { error: ResponseError },
174
175    #[serde(rename = "response.created")]
176    Created { response: Response },
177
178    #[serde(rename = "response.output_item.added")]
179    OutputItemAdded {
180        output_index: usize,
181        #[serde(default)]
182        sequence_number: Option<u64>,
183        item: ResponseOutputItem,
184    },
185
186    #[serde(rename = "response.output_text.delta")]
187    OutputTextDelta {
188        item_id: String,
189        output_index: usize,
190        delta: String,
191    },
192
193    #[serde(rename = "response.output_item.done")]
194    OutputItemDone {
195        output_index: usize,
196        #[serde(default)]
197        sequence_number: Option<u64>,
198        item: ResponseOutputItem,
199    },
200
201    #[serde(rename = "response.incomplete")]
202    Incomplete { response: Response },
203
204    #[serde(rename = "response.completed")]
205    Completed { response: Response },
206
207    #[serde(rename = "response.failed")]
208    Failed { response: Response },
209
210    #[serde(other)]
211    Unknown,
212}
213
214#[derive(Deserialize, Debug, Clone)]
215pub struct ResponseError {
216    pub code: String,
217    pub message: String,
218}
219
220#[derive(Deserialize, Debug, Default, Clone)]
221pub struct Response {
222    pub id: Option<String>,
223    pub status: Option<String>,
224    pub usage: Option<ResponseUsage>,
225    pub output: Vec<ResponseOutputItem>,
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub incomplete_details: Option<IncompleteDetails>,
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub error: Option<ResponseError>,
230}
231
232#[derive(Deserialize, Debug, Default, Clone)]
233pub struct ResponseUsage {
234    pub input_tokens: Option<u64>,
235    pub output_tokens: Option<u64>,
236    pub total_tokens: Option<u64>,
237}
238
239#[derive(Deserialize, Debug, Clone)]
240#[serde(tag = "type", rename_all = "snake_case")]
241pub enum ResponseOutputItem {
242    Message {
243        id: String,
244        role: String,
245        #[serde(skip_serializing_if = "Option::is_none")]
246        content: Option<Vec<ResponseOutputContent>>,
247    },
248    FunctionCall {
249        #[serde(skip_serializing_if = "Option::is_none")]
250        id: Option<String>,
251        call_id: String,
252        name: String,
253        arguments: String,
254        #[serde(skip_serializing_if = "Option::is_none")]
255        status: Option<ItemStatus>,
256        #[serde(default, skip_serializing_if = "Option::is_none")]
257        thought_signature: Option<String>,
258    },
259    Reasoning {
260        id: String,
261        #[serde(skip_serializing_if = "Option::is_none")]
262        summary: Option<Vec<ResponseReasoningItem>>,
263        #[serde(skip_serializing_if = "Option::is_none")]
264        encrypted_content: Option<String>,
265    },
266}
267
268#[derive(Deserialize, Debug, Clone)]
269#[serde(tag = "type", rename_all = "snake_case")]
270pub enum ResponseOutputContent {
271    OutputText { text: String },
272    Refusal { refusal: String },
273}
274
275pub async fn stream_response(
276    client: Arc<dyn HttpClient>,
277    api_key: String,
278    api_url: String,
279    request: Request,
280    is_user_initiated: bool,
281) -> Result<BoxStream<'static, Result<StreamEvent>>> {
282    let is_vision_request = request.input.iter().any(|item| match item {
283        ResponseInputItem::Message {
284            content: Some(parts),
285            ..
286        } => parts
287            .iter()
288            .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
289        _ => false,
290    });
291
292    let request_initiator = if is_user_initiated { "user" } else { "agent" };
293
294    let request_builder = HttpRequest::builder()
295        .method(Method::POST)
296        .uri(&api_url)
297        .header(
298            "Editor-Version",
299            format!(
300                "Zed/{}",
301                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
302            ),
303        )
304        .header("Authorization", format!("Bearer {}", api_key))
305        .header("Content-Type", "application/json")
306        .header("Copilot-Integration-Id", "vscode-chat")
307        .header("X-Initiator", request_initiator);
308
309    let request_builder = if is_vision_request {
310        request_builder.header("Copilot-Vision-Request", "true")
311    } else {
312        request_builder
313    };
314
315    let is_streaming = request.stream;
316    let json = serde_json::to_string(&request)?;
317    let request = request_builder.body(AsyncBody::from(json))?;
318    let mut response = client.send(request).await?;
319
320    if !response.status().is_success() {
321        let mut body = String::new();
322        response.body_mut().read_to_string(&mut body).await?;
323        anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
324    }
325
326    if is_streaming {
327        let reader = BufReader::new(response.into_body());
328        Ok(reader
329            .lines()
330            .filter_map(|line| async move {
331                match line {
332                    Ok(line) => {
333                        let line = line.strip_prefix("data: ")?;
334                        if line.starts_with("[DONE]") || line.is_empty() {
335                            return None;
336                        }
337
338                        match serde_json::from_str::<StreamEvent>(line) {
339                            Ok(event) => Some(Ok(event)),
340                            Err(error) => {
341                                log::error!(
342                                    "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
343                                    error,
344                                    line,
345                                );
346                                Some(Err(anyhow!(error)))
347                            }
348                        }
349                    }
350                    Err(error) => Some(Err(anyhow!(error))),
351                }
352            })
353            .boxed())
354    } else {
355        // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
356        // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
357        let mut body = String::new();
358        response.body_mut().read_to_string(&mut body).await?;
359
360        match serde_json::from_str::<Response>(&body) {
361            Ok(response) => {
362                let events = vec![StreamEvent::Created {
363                    response: response.clone(),
364                }];
365
366                let mut all_events = events;
367                for (output_index, item) in response.output.iter().enumerate() {
368                    all_events.push(StreamEvent::OutputItemAdded {
369                        output_index,
370                        sequence_number: None,
371                        item: item.clone(),
372                    });
373
374                    if let ResponseOutputItem::Message {
375                        id,
376                        content: Some(content),
377                        ..
378                    } = item
379                    {
380                        for part in content {
381                            if let ResponseOutputContent::OutputText { text } = part {
382                                all_events.push(StreamEvent::OutputTextDelta {
383                                    item_id: id.clone(),
384                                    output_index,
385                                    delta: text.clone(),
386                                });
387                            }
388                        }
389                    }
390
391                    all_events.push(StreamEvent::OutputItemDone {
392                        output_index,
393                        sequence_number: None,
394                        item: item.clone(),
395                    });
396                }
397
398                let final_event = if response.error.is_some() {
399                    StreamEvent::Failed { response }
400                } else if response.incomplete_details.is_some() {
401                    StreamEvent::Incomplete { response }
402                } else {
403                    StreamEvent::Completed { response }
404                };
405                all_events.push(final_event);
406
407                Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
408            }
409            Err(error) => {
410                log::error!(
411                    "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
412                    error,
413                    body,
414                );
415                Err(anyhow!(error))
416            }
417        }
418    }
419}