1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8pub use settings::OpenAiReasoningEffort as ReasoningEffort;
9
10#[derive(Serialize, Debug)]
11pub struct Request {
12 pub model: String,
13 pub input: Vec<ResponseInputItem>,
14 #[serde(default)]
15 pub stream: bool,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub temperature: Option<f32>,
18 #[serde(skip_serializing_if = "Vec::is_empty")]
19 pub tools: Vec<ToolDefinition>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub tool_choice: Option<ToolChoice>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub reasoning: Option<ReasoningConfig>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub include: Option<Vec<ResponseIncludable>>,
26}
27
28#[derive(Serialize, Deserialize, Debug, Clone)]
29#[serde(rename_all = "snake_case")]
30pub enum ResponseIncludable {
31 #[serde(rename = "reasoning.encrypted_content")]
32 ReasoningEncryptedContent,
33}
34
35#[derive(Serialize, Deserialize, Debug)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum ToolDefinition {
38 Function {
39 name: String,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 description: Option<String>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 parameters: Option<Value>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 strict: Option<bool>,
46 },
47}
48
49#[derive(Serialize, Deserialize, Debug)]
50#[serde(rename_all = "lowercase")]
51pub enum ToolChoice {
52 Auto,
53 Any,
54 None,
55 #[serde(untagged)]
56 Other(ToolDefinition),
57}
58
59#[derive(Serialize, Deserialize, Debug)]
60#[serde(rename_all = "lowercase")]
61pub enum ReasoningSummary {
62 Auto,
63 Concise,
64 Detailed,
65}
66
67#[derive(Serialize, Debug)]
68pub struct ReasoningConfig {
69 pub effort: ReasoningEffort,
70 #[serde(skip_serializing_if = "Option::is_none")]
71 pub summary: Option<ReasoningSummary>,
72}
73
74#[derive(Serialize, Deserialize, Debug, Clone, Default)]
75#[serde(rename_all = "snake_case")]
76pub enum ResponseImageDetail {
77 Low,
78 High,
79 #[default]
80 Auto,
81}
82
83#[derive(Serialize, Deserialize, Debug, Clone)]
84#[serde(tag = "type", rename_all = "snake_case")]
85pub enum ResponseInputContent {
86 InputText {
87 text: String,
88 },
89 OutputText {
90 text: String,
91 },
92 InputImage {
93 #[serde(skip_serializing_if = "Option::is_none")]
94 image_url: Option<String>,
95 #[serde(default)]
96 detail: ResponseImageDetail,
97 },
98}
99
100#[derive(Serialize, Deserialize, Debug, Clone)]
101#[serde(rename_all = "snake_case")]
102pub enum ItemStatus {
103 InProgress,
104 Completed,
105 Incomplete,
106}
107
108#[derive(Serialize, Deserialize, Debug, Clone)]
109#[serde(untagged)]
110pub enum ResponseFunctionOutput {
111 Text(String),
112 Content(Vec<ResponseInputContent>),
113}
114
115#[derive(Serialize, Deserialize, Debug, Clone)]
116#[serde(tag = "type", rename_all = "snake_case")]
117pub enum ResponseInputItem {
118 Message {
119 role: String,
120 #[serde(skip_serializing_if = "Option::is_none")]
121 content: Option<Vec<ResponseInputContent>>,
122 #[serde(skip_serializing_if = "Option::is_none")]
123 status: Option<String>,
124 },
125 FunctionCall {
126 call_id: String,
127 name: String,
128 arguments: String,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 status: Option<ItemStatus>,
131 #[serde(default, skip_serializing_if = "Option::is_none")]
132 thought_signature: Option<String>,
133 },
134 FunctionCallOutput {
135 call_id: String,
136 output: ResponseFunctionOutput,
137 #[serde(skip_serializing_if = "Option::is_none")]
138 status: Option<ItemStatus>,
139 },
140 Reasoning {
141 #[serde(skip_serializing_if = "Option::is_none")]
142 id: Option<String>,
143 summary: Vec<ResponseReasoningItem>,
144 encrypted_content: String,
145 },
146}
147
148#[derive(Deserialize, Debug, Clone)]
149#[serde(rename_all = "snake_case")]
150pub enum IncompleteReason {
151 #[serde(rename = "max_output_tokens")]
152 MaxOutputTokens,
153 #[serde(rename = "content_filter")]
154 ContentFilter,
155}
156
157#[derive(Deserialize, Debug, Clone)]
158pub struct IncompleteDetails {
159 #[serde(skip_serializing_if = "Option::is_none")]
160 pub reason: Option<IncompleteReason>,
161}
162
163#[derive(Serialize, Deserialize, Debug, Clone)]
164pub struct ResponseReasoningItem {
165 #[serde(rename = "type")]
166 pub kind: String,
167 pub text: String,
168}
169
170#[derive(Deserialize, Debug)]
171#[serde(tag = "type")]
172pub enum StreamEvent {
173 #[serde(rename = "error")]
174 GenericError { error: ResponseError },
175
176 #[serde(rename = "response.created")]
177 Created { response: Response },
178
179 #[serde(rename = "response.output_item.added")]
180 OutputItemAdded {
181 output_index: usize,
182 #[serde(default)]
183 sequence_number: Option<u64>,
184 item: ResponseOutputItem,
185 },
186
187 #[serde(rename = "response.output_text.delta")]
188 OutputTextDelta {
189 item_id: String,
190 output_index: usize,
191 delta: String,
192 },
193
194 #[serde(rename = "response.output_item.done")]
195 OutputItemDone {
196 output_index: usize,
197 #[serde(default)]
198 sequence_number: Option<u64>,
199 item: ResponseOutputItem,
200 },
201
202 #[serde(rename = "response.incomplete")]
203 Incomplete { response: Response },
204
205 #[serde(rename = "response.completed")]
206 Completed { response: Response },
207
208 #[serde(rename = "response.failed")]
209 Failed { response: Response },
210
211 #[serde(other)]
212 Unknown,
213}
214
215#[derive(Deserialize, Debug, Clone)]
216pub struct ResponseError {
217 pub code: String,
218 pub message: String,
219}
220
221#[derive(Deserialize, Debug, Default, Clone)]
222pub struct Response {
223 pub id: Option<String>,
224 pub status: Option<String>,
225 pub usage: Option<ResponseUsage>,
226 pub output: Vec<ResponseOutputItem>,
227 #[serde(skip_serializing_if = "Option::is_none")]
228 pub incomplete_details: Option<IncompleteDetails>,
229 #[serde(skip_serializing_if = "Option::is_none")]
230 pub error: Option<ResponseError>,
231}
232
233#[derive(Deserialize, Debug, Default, Clone)]
234pub struct ResponseUsage {
235 pub input_tokens: Option<u64>,
236 pub output_tokens: Option<u64>,
237 pub total_tokens: Option<u64>,
238}
239
240#[derive(Deserialize, Debug, Clone)]
241#[serde(tag = "type", rename_all = "snake_case")]
242pub enum ResponseOutputItem {
243 Message {
244 id: String,
245 role: String,
246 #[serde(skip_serializing_if = "Option::is_none")]
247 content: Option<Vec<ResponseOutputContent>>,
248 },
249 FunctionCall {
250 #[serde(skip_serializing_if = "Option::is_none")]
251 id: Option<String>,
252 call_id: String,
253 name: String,
254 arguments: String,
255 #[serde(skip_serializing_if = "Option::is_none")]
256 status: Option<ItemStatus>,
257 #[serde(default, skip_serializing_if = "Option::is_none")]
258 thought_signature: Option<String>,
259 },
260 Reasoning {
261 id: String,
262 #[serde(skip_serializing_if = "Option::is_none")]
263 summary: Option<Vec<ResponseReasoningItem>>,
264 #[serde(skip_serializing_if = "Option::is_none")]
265 encrypted_content: Option<String>,
266 },
267}
268
269#[derive(Deserialize, Debug, Clone)]
270#[serde(tag = "type", rename_all = "snake_case")]
271pub enum ResponseOutputContent {
272 OutputText { text: String },
273 Refusal { refusal: String },
274}
275
276pub async fn stream_response(
277 client: Arc<dyn HttpClient>,
278 api_key: String,
279 api_url: String,
280 request: Request,
281 is_user_initiated: bool,
282) -> Result<BoxStream<'static, Result<StreamEvent>>> {
283 let is_vision_request = request.input.iter().any(|item| match item {
284 ResponseInputItem::Message {
285 content: Some(parts),
286 ..
287 } => parts
288 .iter()
289 .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
290 _ => false,
291 });
292
293 let request_initiator = if is_user_initiated { "user" } else { "agent" };
294
295 let request_builder = HttpRequest::builder()
296 .method(Method::POST)
297 .uri(&api_url)
298 .header(
299 "Editor-Version",
300 format!(
301 "Zed/{}",
302 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
303 ),
304 )
305 .header("Authorization", format!("Bearer {}", api_key))
306 .header("Content-Type", "application/json")
307 .header("Copilot-Integration-Id", "vscode-chat")
308 .header("X-Initiator", request_initiator);
309
310 let request_builder = if is_vision_request {
311 request_builder.header("Copilot-Vision-Request", "true")
312 } else {
313 request_builder
314 };
315
316 let is_streaming = request.stream;
317 let json = serde_json::to_string(&request)?;
318 let request = request_builder.body(AsyncBody::from(json))?;
319 let mut response = client.send(request).await?;
320
321 if !response.status().is_success() {
322 let mut body = String::new();
323 response.body_mut().read_to_string(&mut body).await?;
324 anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
325 }
326
327 if is_streaming {
328 let reader = BufReader::new(response.into_body());
329 Ok(reader
330 .lines()
331 .filter_map(|line| async move {
332 match line {
333 Ok(line) => {
334 let line = line.strip_prefix("data: ")?;
335 if line.starts_with("[DONE]") || line.is_empty() {
336 return None;
337 }
338
339 match serde_json::from_str::<StreamEvent>(line) {
340 Ok(event) => Some(Ok(event)),
341 Err(error) => {
342 log::error!(
343 "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
344 error,
345 line,
346 );
347 Some(Err(anyhow!(error)))
348 }
349 }
350 }
351 Err(error) => Some(Err(anyhow!(error))),
352 }
353 })
354 .boxed())
355 } else {
356 // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
357 // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
358 let mut body = String::new();
359 response.body_mut().read_to_string(&mut body).await?;
360
361 match serde_json::from_str::<Response>(&body) {
362 Ok(response) => {
363 let events = vec![StreamEvent::Created {
364 response: response.clone(),
365 }];
366
367 let mut all_events = events;
368 for (output_index, item) in response.output.iter().enumerate() {
369 all_events.push(StreamEvent::OutputItemAdded {
370 output_index,
371 sequence_number: None,
372 item: item.clone(),
373 });
374
375 if let ResponseOutputItem::Message {
376 id,
377 content: Some(content),
378 ..
379 } = item
380 {
381 for part in content {
382 if let ResponseOutputContent::OutputText { text } = part {
383 all_events.push(StreamEvent::OutputTextDelta {
384 item_id: id.clone(),
385 output_index,
386 delta: text.clone(),
387 });
388 }
389 }
390 }
391
392 all_events.push(StreamEvent::OutputItemDone {
393 output_index,
394 sequence_number: None,
395 item: item.clone(),
396 });
397 }
398
399 let final_event = if response.error.is_some() {
400 StreamEvent::Failed { response }
401 } else if response.incomplete_details.is_some() {
402 StreamEvent::Incomplete { response }
403 } else {
404 StreamEvent::Completed { response }
405 };
406 all_events.push(final_event);
407
408 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
409 }
410 Err(error) => {
411 log::error!(
412 "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
413 error,
414 body,
415 );
416 Err(anyhow!(error))
417 }
418 }
419 }
420}