1use super::*;
2use anyhow::{Result, anyhow};
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::OpenAiReasoningEffort as ReasoningEffort;
8
9#[derive(Serialize, Debug)]
10pub struct Request {
11 pub model: String,
12 pub input: Vec<ResponseInputItem>,
13 #[serde(default)]
14 pub stream: bool,
15 #[serde(skip_serializing_if = "Option::is_none")]
16 pub temperature: Option<f32>,
17 #[serde(skip_serializing_if = "Vec::is_empty")]
18 pub tools: Vec<ToolDefinition>,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub tool_choice: Option<ToolChoice>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub reasoning: Option<ReasoningConfig>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub include: Option<Vec<ResponseIncludable>>,
25}
26
27#[derive(Serialize, Deserialize, Debug, Clone)]
28#[serde(rename_all = "snake_case")]
29pub enum ResponseIncludable {
30 #[serde(rename = "reasoning.encrypted_content")]
31 ReasoningEncryptedContent,
32}
33
34#[derive(Serialize, Deserialize, Debug)]
35#[serde(tag = "type", rename_all = "snake_case")]
36pub enum ToolDefinition {
37 Function {
38 name: String,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 description: Option<String>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 parameters: Option<Value>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 strict: Option<bool>,
45 },
46}
47
48#[derive(Serialize, Deserialize, Debug)]
49#[serde(rename_all = "lowercase")]
50pub enum ToolChoice {
51 Auto,
52 Any,
53 None,
54 #[serde(untagged)]
55 Other(ToolDefinition),
56}
57
58#[derive(Serialize, Deserialize, Debug)]
59#[serde(rename_all = "lowercase")]
60pub enum ReasoningSummary {
61 Auto,
62 Concise,
63 Detailed,
64}
65
66#[derive(Serialize, Debug)]
67pub struct ReasoningConfig {
68 pub effort: ReasoningEffort,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub summary: Option<ReasoningSummary>,
71}
72
73#[derive(Serialize, Deserialize, Debug, Clone, Default)]
74#[serde(rename_all = "snake_case")]
75pub enum ResponseImageDetail {
76 Low,
77 High,
78 #[default]
79 Auto,
80}
81
82#[derive(Serialize, Deserialize, Debug, Clone)]
83#[serde(tag = "type", rename_all = "snake_case")]
84pub enum ResponseInputContent {
85 InputText {
86 text: String,
87 },
88 OutputText {
89 text: String,
90 },
91 InputImage {
92 #[serde(skip_serializing_if = "Option::is_none")]
93 image_url: Option<String>,
94 #[serde(default)]
95 detail: ResponseImageDetail,
96 },
97}
98
99#[derive(Serialize, Deserialize, Debug, Clone)]
100#[serde(rename_all = "snake_case")]
101pub enum ItemStatus {
102 InProgress,
103 Completed,
104 Incomplete,
105}
106
107#[derive(Serialize, Deserialize, Debug, Clone)]
108#[serde(untagged)]
109pub enum ResponseFunctionOutput {
110 Text(String),
111 Content(Vec<ResponseInputContent>),
112}
113
114#[derive(Serialize, Deserialize, Debug, Clone)]
115#[serde(tag = "type", rename_all = "snake_case")]
116pub enum ResponseInputItem {
117 Message {
118 role: String,
119 #[serde(skip_serializing_if = "Option::is_none")]
120 content: Option<Vec<ResponseInputContent>>,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 status: Option<String>,
123 },
124 FunctionCall {
125 call_id: String,
126 name: String,
127 arguments: String,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 status: Option<ItemStatus>,
130 #[serde(default, skip_serializing_if = "Option::is_none")]
131 thought_signature: Option<String>,
132 },
133 FunctionCallOutput {
134 call_id: String,
135 output: ResponseFunctionOutput,
136 #[serde(skip_serializing_if = "Option::is_none")]
137 status: Option<ItemStatus>,
138 },
139 Reasoning {
140 #[serde(skip_serializing_if = "Option::is_none")]
141 id: Option<String>,
142 summary: Vec<ResponseReasoningItem>,
143 encrypted_content: String,
144 },
145}
146
147#[derive(Deserialize, Debug, Clone)]
148#[serde(rename_all = "snake_case")]
149pub enum IncompleteReason {
150 #[serde(rename = "max_output_tokens")]
151 MaxOutputTokens,
152 #[serde(rename = "content_filter")]
153 ContentFilter,
154}
155
156#[derive(Deserialize, Debug, Clone)]
157pub struct IncompleteDetails {
158 #[serde(skip_serializing_if = "Option::is_none")]
159 pub reason: Option<IncompleteReason>,
160}
161
162#[derive(Serialize, Deserialize, Debug, Clone)]
163pub struct ResponseReasoningItem {
164 #[serde(rename = "type")]
165 pub kind: String,
166 pub text: String,
167}
168
169#[derive(Deserialize, Debug)]
170#[serde(tag = "type")]
171pub enum StreamEvent {
172 #[serde(rename = "error")]
173 GenericError { error: ResponseError },
174
175 #[serde(rename = "response.created")]
176 Created { response: Response },
177
178 #[serde(rename = "response.output_item.added")]
179 OutputItemAdded {
180 output_index: usize,
181 #[serde(default)]
182 sequence_number: Option<u64>,
183 item: ResponseOutputItem,
184 },
185
186 #[serde(rename = "response.output_text.delta")]
187 OutputTextDelta {
188 item_id: String,
189 output_index: usize,
190 delta: String,
191 },
192
193 #[serde(rename = "response.output_item.done")]
194 OutputItemDone {
195 output_index: usize,
196 #[serde(default)]
197 sequence_number: Option<u64>,
198 item: ResponseOutputItem,
199 },
200
201 #[serde(rename = "response.incomplete")]
202 Incomplete { response: Response },
203
204 #[serde(rename = "response.completed")]
205 Completed { response: Response },
206
207 #[serde(rename = "response.failed")]
208 Failed { response: Response },
209
210 #[serde(other)]
211 Unknown,
212}
213
214#[derive(Deserialize, Debug, Clone)]
215pub struct ResponseError {
216 pub code: String,
217 pub message: String,
218}
219
220#[derive(Deserialize, Debug, Default, Clone)]
221pub struct Response {
222 pub id: Option<String>,
223 pub status: Option<String>,
224 pub usage: Option<ResponseUsage>,
225 pub output: Vec<ResponseOutputItem>,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 pub incomplete_details: Option<IncompleteDetails>,
228 #[serde(skip_serializing_if = "Option::is_none")]
229 pub error: Option<ResponseError>,
230}
231
232#[derive(Deserialize, Debug, Default, Clone)]
233pub struct ResponseUsage {
234 pub input_tokens: Option<u64>,
235 pub output_tokens: Option<u64>,
236 pub total_tokens: Option<u64>,
237}
238
239#[derive(Deserialize, Debug, Clone)]
240#[serde(tag = "type", rename_all = "snake_case")]
241pub enum ResponseOutputItem {
242 Message {
243 id: String,
244 role: String,
245 #[serde(skip_serializing_if = "Option::is_none")]
246 content: Option<Vec<ResponseOutputContent>>,
247 },
248 FunctionCall {
249 #[serde(skip_serializing_if = "Option::is_none")]
250 id: Option<String>,
251 call_id: String,
252 name: String,
253 arguments: String,
254 #[serde(skip_serializing_if = "Option::is_none")]
255 status: Option<ItemStatus>,
256 #[serde(default, skip_serializing_if = "Option::is_none")]
257 thought_signature: Option<String>,
258 },
259 Reasoning {
260 id: String,
261 #[serde(skip_serializing_if = "Option::is_none")]
262 summary: Option<Vec<ResponseReasoningItem>>,
263 #[serde(skip_serializing_if = "Option::is_none")]
264 encrypted_content: Option<String>,
265 },
266}
267
268#[derive(Deserialize, Debug, Clone)]
269#[serde(tag = "type", rename_all = "snake_case")]
270pub enum ResponseOutputContent {
271 OutputText { text: String },
272 Refusal { refusal: String },
273}
274
275pub async fn stream_response(
276 client: Arc<dyn HttpClient>,
277 api_key: String,
278 api_url: String,
279 request: Request,
280 is_user_initiated: bool,
281) -> Result<BoxStream<'static, Result<StreamEvent>>> {
282 let is_vision_request = request.input.iter().any(|item| match item {
283 ResponseInputItem::Message {
284 content: Some(parts),
285 ..
286 } => parts
287 .iter()
288 .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
289 _ => false,
290 });
291
292 let request_initiator = if is_user_initiated { "user" } else { "agent" };
293
294 let request_builder = HttpRequest::builder()
295 .method(Method::POST)
296 .uri(&api_url)
297 .header(
298 "Editor-Version",
299 format!(
300 "Zed/{}",
301 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
302 ),
303 )
304 .header("Authorization", format!("Bearer {}", api_key))
305 .header("Content-Type", "application/json")
306 .header("Copilot-Integration-Id", "vscode-chat")
307 .header("X-Initiator", request_initiator);
308
309 let request_builder = if is_vision_request {
310 request_builder.header("Copilot-Vision-Request", "true")
311 } else {
312 request_builder
313 };
314
315 let is_streaming = request.stream;
316 let json = serde_json::to_string(&request)?;
317 let request = request_builder.body(AsyncBody::from(json))?;
318 let mut response = client.send(request).await?;
319
320 if !response.status().is_success() {
321 let mut body = String::new();
322 response.body_mut().read_to_string(&mut body).await?;
323 anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
324 }
325
326 if is_streaming {
327 let reader = BufReader::new(response.into_body());
328 Ok(reader
329 .lines()
330 .filter_map(|line| async move {
331 match line {
332 Ok(line) => {
333 let line = line.strip_prefix("data: ")?;
334 if line.starts_with("[DONE]") || line.is_empty() {
335 return None;
336 }
337
338 match serde_json::from_str::<StreamEvent>(line) {
339 Ok(event) => Some(Ok(event)),
340 Err(error) => {
341 log::error!(
342 "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
343 error,
344 line,
345 );
346 Some(Err(anyhow!(error)))
347 }
348 }
349 }
350 Err(error) => Some(Err(anyhow!(error))),
351 }
352 })
353 .boxed())
354 } else {
355 // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
356 // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
357 let mut body = String::new();
358 response.body_mut().read_to_string(&mut body).await?;
359
360 match serde_json::from_str::<Response>(&body) {
361 Ok(response) => {
362 let events = vec![StreamEvent::Created {
363 response: response.clone(),
364 }];
365
366 let mut all_events = events;
367 for (output_index, item) in response.output.iter().enumerate() {
368 all_events.push(StreamEvent::OutputItemAdded {
369 output_index,
370 sequence_number: None,
371 item: item.clone(),
372 });
373
374 if let ResponseOutputItem::Message {
375 id,
376 content: Some(content),
377 ..
378 } = item
379 {
380 for part in content {
381 if let ResponseOutputContent::OutputText { text } = part {
382 all_events.push(StreamEvent::OutputTextDelta {
383 item_id: id.clone(),
384 output_index,
385 delta: text.clone(),
386 });
387 }
388 }
389 }
390
391 all_events.push(StreamEvent::OutputItemDone {
392 output_index,
393 sequence_number: None,
394 item: item.clone(),
395 });
396 }
397
398 let final_event = if response.error.is_some() {
399 StreamEvent::Failed { response }
400 } else if response.incomplete_details.is_some() {
401 StreamEvent::Incomplete { response }
402 } else {
403 StreamEvent::Completed { response }
404 };
405 all_events.push(final_event);
406
407 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
408 }
409 Err(error) => {
410 log::error!(
411 "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
412 error,
413 body,
414 );
415 Err(anyhow!(error))
416 }
417 }
418 }
419}