copilot_responses.rs

  1use super::*;
  2use anyhow::{Result, anyhow};
  3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  4use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  5use serde::{Deserialize, Serialize};
  6use serde_json::Value;
  7pub use settings::OpenAiReasoningEffort as ReasoningEffort;
  8
  9#[derive(Serialize, Debug)]
 10pub struct Request {
 11    pub model: String,
 12    pub input: Vec<ResponseInputItem>,
 13    #[serde(default)]
 14    pub stream: bool,
 15    #[serde(skip_serializing_if = "Option::is_none")]
 16    pub temperature: Option<f32>,
 17    #[serde(skip_serializing_if = "Vec::is_empty")]
 18    pub tools: Vec<ToolDefinition>,
 19    #[serde(skip_serializing_if = "Option::is_none")]
 20    pub tool_choice: Option<ToolChoice>,
 21    #[serde(skip_serializing_if = "Option::is_none")]
 22    pub reasoning: Option<ReasoningConfig>,
 23    #[serde(skip_serializing_if = "Option::is_none")]
 24    pub include: Option<Vec<ResponseIncludable>>,
 25}
 26
 27#[derive(Serialize, Deserialize, Debug, Clone)]
 28#[serde(rename_all = "snake_case")]
 29pub enum ResponseIncludable {
 30    #[serde(rename = "reasoning.encrypted_content")]
 31    ReasoningEncryptedContent,
 32}
 33
 34#[derive(Serialize, Deserialize, Debug)]
 35#[serde(tag = "type", rename_all = "snake_case")]
 36pub enum ToolDefinition {
 37    Function {
 38        name: String,
 39        #[serde(skip_serializing_if = "Option::is_none")]
 40        description: Option<String>,
 41        #[serde(skip_serializing_if = "Option::is_none")]
 42        parameters: Option<Value>,
 43        #[serde(skip_serializing_if = "Option::is_none")]
 44        strict: Option<bool>,
 45    },
 46}
 47
 48#[derive(Serialize, Deserialize, Debug)]
 49#[serde(rename_all = "lowercase")]
 50pub enum ToolChoice {
 51    Auto,
 52    Any,
 53    None,
 54    #[serde(untagged)]
 55    Other(ToolDefinition),
 56}
 57
 58#[derive(Serialize, Deserialize, Debug)]
 59#[serde(rename_all = "lowercase")]
 60pub enum ReasoningSummary {
 61    Auto,
 62    Concise,
 63    Detailed,
 64}
 65
 66#[derive(Serialize, Debug)]
 67pub struct ReasoningConfig {
 68    pub effort: ReasoningEffort,
 69    #[serde(skip_serializing_if = "Option::is_none")]
 70    pub summary: Option<ReasoningSummary>,
 71}
 72
 73#[derive(Serialize, Deserialize, Debug, Clone, Default)]
 74#[serde(rename_all = "snake_case")]
 75pub enum ResponseImageDetail {
 76    Low,
 77    High,
 78    #[default]
 79    Auto,
 80}
 81
 82#[derive(Serialize, Deserialize, Debug, Clone)]
 83#[serde(tag = "type", rename_all = "snake_case")]
 84pub enum ResponseInputContent {
 85    InputText {
 86        text: String,
 87    },
 88    OutputText {
 89        text: String,
 90    },
 91    InputImage {
 92        #[serde(skip_serializing_if = "Option::is_none")]
 93        image_url: Option<String>,
 94        #[serde(default)]
 95        detail: ResponseImageDetail,
 96    },
 97}
 98
 99#[derive(Serialize, Deserialize, Debug, Clone)]
100#[serde(rename_all = "snake_case")]
101pub enum ItemStatus {
102    InProgress,
103    Completed,
104    Incomplete,
105}
106
107#[derive(Serialize, Deserialize, Debug, Clone)]
108#[serde(untagged)]
109pub enum ResponseFunctionOutput {
110    Text(String),
111    Content(Vec<ResponseInputContent>),
112}
113
114#[derive(Serialize, Deserialize, Debug, Clone)]
115#[serde(tag = "type", rename_all = "snake_case")]
116pub enum ResponseInputItem {
117    Message {
118        role: String,
119        #[serde(skip_serializing_if = "Option::is_none")]
120        content: Option<Vec<ResponseInputContent>>,
121        #[serde(skip_serializing_if = "Option::is_none")]
122        status: Option<String>,
123    },
124    FunctionCall {
125        call_id: String,
126        name: String,
127        arguments: String,
128        #[serde(skip_serializing_if = "Option::is_none")]
129        status: Option<ItemStatus>,
130        #[serde(default, skip_serializing_if = "Option::is_none")]
131        thought_signature: Option<String>,
132    },
133    FunctionCallOutput {
134        call_id: String,
135        output: ResponseFunctionOutput,
136        #[serde(skip_serializing_if = "Option::is_none")]
137        status: Option<ItemStatus>,
138    },
139    Reasoning {
140        #[serde(skip_serializing_if = "Option::is_none")]
141        id: Option<String>,
142        summary: Vec<ResponseReasoningItem>,
143        encrypted_content: String,
144    },
145}
146
147#[derive(Deserialize, Debug, Clone)]
148#[serde(rename_all = "snake_case")]
149pub enum IncompleteReason {
150    #[serde(rename = "max_output_tokens")]
151    MaxOutputTokens,
152    #[serde(rename = "content_filter")]
153    ContentFilter,
154}
155
156#[derive(Deserialize, Debug, Clone)]
157pub struct IncompleteDetails {
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub reason: Option<IncompleteReason>,
160}
161
162#[derive(Serialize, Deserialize, Debug, Clone)]
163pub struct ResponseReasoningItem {
164    #[serde(rename = "type")]
165    pub kind: String,
166    pub text: String,
167}
168
169#[derive(Deserialize, Debug)]
170#[serde(tag = "type")]
171pub enum StreamEvent {
172    #[serde(rename = "error")]
173    GenericError { error: ResponseError },
174
175    #[serde(rename = "response.created")]
176    Created { response: Response },
177
178    #[serde(rename = "response.output_item.added")]
179    OutputItemAdded {
180        output_index: usize,
181        #[serde(default)]
182        sequence_number: Option<u64>,
183        item: ResponseOutputItem,
184    },
185
186    #[serde(rename = "response.output_text.delta")]
187    OutputTextDelta {
188        item_id: String,
189        output_index: usize,
190        delta: String,
191    },
192
193    #[serde(rename = "response.output_item.done")]
194    OutputItemDone {
195        output_index: usize,
196        #[serde(default)]
197        sequence_number: Option<u64>,
198        item: ResponseOutputItem,
199    },
200
201    #[serde(rename = "response.incomplete")]
202    Incomplete { response: Response },
203
204    #[serde(rename = "response.completed")]
205    Completed { response: Response },
206
207    #[serde(rename = "response.failed")]
208    Failed { response: Response },
209
210    #[serde(other)]
211    Unknown,
212}
213
214#[derive(Deserialize, Debug, Clone)]
215pub struct ResponseError {
216    pub code: String,
217    pub message: String,
218}
219
220#[derive(Deserialize, Debug, Default, Clone)]
221pub struct Response {
222    pub id: Option<String>,
223    pub status: Option<String>,
224    pub usage: Option<ResponseUsage>,
225    pub output: Vec<ResponseOutputItem>,
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub incomplete_details: Option<IncompleteDetails>,
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub error: Option<ResponseError>,
230}
231
232#[derive(Deserialize, Debug, Default, Clone)]
233pub struct ResponseUsage {
234    pub input_tokens: Option<u64>,
235    pub output_tokens: Option<u64>,
236    pub total_tokens: Option<u64>,
237}
238
239#[derive(Deserialize, Debug, Clone)]
240#[serde(tag = "type", rename_all = "snake_case")]
241pub enum ResponseOutputItem {
242    Message {
243        id: String,
244        role: String,
245        #[serde(skip_serializing_if = "Option::is_none")]
246        content: Option<Vec<ResponseOutputContent>>,
247    },
248    FunctionCall {
249        #[serde(skip_serializing_if = "Option::is_none")]
250        id: Option<String>,
251        call_id: String,
252        name: String,
253        arguments: String,
254        #[serde(skip_serializing_if = "Option::is_none")]
255        status: Option<ItemStatus>,
256        #[serde(default, skip_serializing_if = "Option::is_none")]
257        thought_signature: Option<String>,
258    },
259    Reasoning {
260        id: String,
261        #[serde(skip_serializing_if = "Option::is_none")]
262        summary: Option<Vec<ResponseReasoningItem>>,
263        #[serde(skip_serializing_if = "Option::is_none")]
264        encrypted_content: Option<String>,
265    },
266}
267
268#[derive(Deserialize, Debug, Clone)]
269#[serde(tag = "type", rename_all = "snake_case")]
270pub enum ResponseOutputContent {
271    OutputText { text: String },
272    Refusal { refusal: String },
273}
274
275pub async fn stream_response(
276    client: Arc<dyn HttpClient>,
277    api_key: String,
278    api_url: String,
279    request: Request,
280    is_user_initiated: bool,
281) -> Result<BoxStream<'static, Result<StreamEvent>>> {
282    let is_vision_request = request.input.iter().any(|item| match item {
283        ResponseInputItem::Message {
284            content: Some(parts),
285            ..
286        } => parts
287            .iter()
288            .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
289        _ => false,
290    });
291
292    let request_initiator = if is_user_initiated { "user" } else { "agent" };
293
294    let request_builder = HttpRequest::builder()
295        .method(Method::POST)
296        .uri(&api_url)
297        .header(
298            "Editor-Version",
299            format!(
300                "Zed/{}",
301                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
302            ),
303        )
304        .header("Authorization", format!("Bearer {}", api_key))
305        .header("Content-Type", "application/json")
306        .header("Copilot-Integration-Id", "vscode-chat")
307        .header("X-Initiator", request_initiator);
308
309    let request_builder = if is_vision_request {
310        request_builder.header("Copilot-Vision-Request", "true")
311    } else {
312        request_builder
313    };
314
315    let is_streaming = request.stream;
316    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
317    let mut response = client.send(request).await?;
318
319    if !response.status().is_success() {
320        let mut body = String::new();
321        response.body_mut().read_to_string(&mut body).await?;
322        anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
323    }
324
325    if is_streaming {
326        let reader = BufReader::new(response.into_body());
327        Ok(reader
328            .lines()
329            .filter_map(|line| async move {
330                match line {
331                    Ok(line) => {
332                        let line = line.strip_prefix("data: ")?;
333                        if line.starts_with("[DONE]") || line.is_empty() {
334                            return None;
335                        }
336
337                        match serde_json::from_str::<StreamEvent>(line) {
338                            Ok(event) => Some(Ok(event)),
339                            Err(error) => {
340                                log::error!(
341                                    "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
342                                    error,
343                                    line,
344                                );
345                                Some(Err(anyhow!(error)))
346                            }
347                        }
348                    }
349                    Err(error) => Some(Err(anyhow!(error))),
350                }
351            })
352            .boxed())
353    } else {
354        // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
355        // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
356        let mut body = String::new();
357        response.body_mut().read_to_string(&mut body).await?;
358
359        match serde_json::from_str::<Response>(&body) {
360            Ok(response) => {
361                let events = vec![StreamEvent::Created {
362                    response: response.clone(),
363                }];
364
365                let mut all_events = events;
366                for (output_index, item) in response.output.iter().enumerate() {
367                    all_events.push(StreamEvent::OutputItemAdded {
368                        output_index,
369                        sequence_number: None,
370                        item: item.clone(),
371                    });
372
373                    if let ResponseOutputItem::Message {
374                        id,
375                        content: Some(content),
376                        ..
377                    } = item
378                    {
379                        for part in content {
380                            if let ResponseOutputContent::OutputText { text } = part {
381                                all_events.push(StreamEvent::OutputTextDelta {
382                                    item_id: id.clone(),
383                                    output_index,
384                                    delta: text.clone(),
385                                });
386                            }
387                        }
388                    }
389
390                    all_events.push(StreamEvent::OutputItemDone {
391                        output_index,
392                        sequence_number: None,
393                        item: item.clone(),
394                    });
395                }
396
397                let final_event = if response.error.is_some() {
398                    StreamEvent::Failed { response }
399                } else if response.incomplete_details.is_some() {
400                    StreamEvent::Incomplete { response }
401                } else {
402                    StreamEvent::Completed { response }
403                };
404                all_events.push(final_event);
405
406                Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
407            }
408            Err(error) => {
409                log::error!(
410                    "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
411                    error,
412                    body,
413                );
414                Err(anyhow!(error))
415            }
416        }
417    }
418}