1use super::*;
2use anyhow::{Result, anyhow};
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::OpenAiReasoningEffort as ReasoningEffort;
8
9#[derive(Serialize, Debug)]
10pub struct Request {
11 pub model: String,
12 pub input: Vec<ResponseInputItem>,
13 #[serde(default)]
14 pub stream: bool,
15 #[serde(skip_serializing_if = "Option::is_none")]
16 pub temperature: Option<f32>,
17 #[serde(skip_serializing_if = "Vec::is_empty")]
18 pub tools: Vec<ToolDefinition>,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub tool_choice: Option<ToolChoice>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub reasoning: Option<ReasoningConfig>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub include: Option<Vec<ResponseIncludable>>,
25}
26
27#[derive(Serialize, Deserialize, Debug, Clone)]
28#[serde(rename_all = "snake_case")]
29pub enum ResponseIncludable {
30 #[serde(rename = "reasoning.encrypted_content")]
31 ReasoningEncryptedContent,
32}
33
/// A tool offered to the model, tagged by its `"type"` field (only `"function"` exists here).
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A callable function tool. Optional fields are omitted from the JSON when `None`.
    Function {
        /// Name the model uses to call the function.
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        // NOTE(review): presumably a JSON Schema for the arguments; passed through unvalidated.
        #[serde(skip_serializing_if = "Option::is_none")]
        parameters: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        strict: Option<bool>,
    },
}
47
/// How the model should pick among the offered tools.
///
/// Unit variants serialize as lowercase strings (`"auto"`, `"any"`, `"none"`). `Other` is
/// `#[serde(untagged)]`, so it serializes as the wrapped [`ToolDefinition`] object itself; it
/// must remain the last variant for that attribute to be accepted.
// NOTE(review): the OpenAI Responses API documents `"required"` (not `"any"`) for forcing a
// tool call — confirm which value the Copilot endpoint accepts.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
57
/// Reasoning summary mode sent in [`ReasoningConfig::summary`]; serialized as lowercase strings.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningSummary {
    Auto,
    Concise,
    Detailed,
}
65
/// Reasoning options attached to a [`Request`].
#[derive(Serialize, Debug)]
pub struct ReasoningConfig {
    /// Effort level (the `settings::OpenAiReasoningEffort` type re-exported as `ReasoningEffort`).
    pub effort: ReasoningEffort,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<ReasoningSummary>,
}
72
/// Detail level for an `input_image` content part; serialized in snake_case, defaults to `Auto`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResponseImageDetail {
    Low,
    High,
    #[default]
    Auto,
}
81
/// One content part of an input message or structured function output, tagged by `"type"`
/// (`input_text`, `output_text`, or `input_image`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputContent {
    InputText {
        text: String,
    },
    OutputText {
        text: String,
    },
    /// An image part. Its presence makes `stream_response` send the vision request header.
    InputImage {
        // NOTE(review): presumably a URL or data URI; not validated here. Omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        image_url: Option<String>,
        /// Falls back to `ResponseImageDetail::default()` when missing during deserialization.
        #[serde(default)]
        detail: ResponseImageDetail,
    },
}
98
/// Status of a function call item, serialized in snake_case
/// (`in_progress`, `completed`, `incomplete`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ItemStatus {
    InProgress,
    Completed,
    Incomplete,
}
106
/// Output of a function call: plain text or a list of content parts.
///
/// `#[serde(untagged)]`: `Text` serializes as a bare JSON string and `Content` as an array.
/// When deserializing, variants are tried in declaration order, so keep `Text` first.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum ResponseFunctionOutput {
    Text(String),
    Content(Vec<ResponseInputContent>),
}
113
/// One item of [`Request::input`], tagged by `"type"`
/// (`message`, `function_call`, `function_call_output`, or `reasoning`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputItem {
    /// A conversation message.
    Message {
        // NOTE(review): free-form role string (e.g. "user"/"assistant"); not validated here.
        role: String,
        /// Content parts; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ResponseInputContent>>,
        /// Untyped status string, unlike the [`ItemStatus`] used by the function call variants.
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    /// A function call item.
    FunctionCall {
        call_id: String,
        name: String,
        // NOTE(review): presumably JSON-encoded arguments; kept as an opaque string here.
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<ItemStatus>,
    },
    /// The output of a function call (presumably correlated with it through `call_id`).
    FunctionCallOutput {
        call_id: String,
        output: ResponseFunctionOutput,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<ItemStatus>,
    },
    /// A reasoning item; `id` is omitted when `None`, `encrypted_content` is always sent.
    Reasoning {
        #[serde(skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        summary: Vec<ResponseReasoningItem>,
        encrypted_content: String,
    },
}
144
145#[derive(Deserialize, Debug, Clone)]
146#[serde(rename_all = "snake_case")]
147pub enum IncompleteReason {
148 #[serde(rename = "max_output_tokens")]
149 MaxOutputTokens,
150 #[serde(rename = "content_filter")]
151 ContentFilter,
152}
153
154#[derive(Deserialize, Debug, Clone)]
155pub struct IncompleteDetails {
156 #[serde(skip_serializing_if = "Option::is_none")]
157 pub reason: Option<IncompleteReason>,
158}
159
/// One summary entry of a reasoning item.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseReasoningItem {
    /// Maps to the JSON `"type"` field (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub kind: String,
    pub text: String,
}
166
/// An event of the responses stream, selected by its `"type"` field.
///
/// Any `"type"` not listed here deserializes to [`StreamEvent::Unknown`] instead of failing.
/// The non-streaming path of `stream_response` synthesizes the same events from a full body.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
pub enum StreamEvent {
    /// `"error"`: an error reported as its own event.
    #[serde(rename = "error")]
    GenericError { error: ResponseError },

    /// `"response.created"`: the response object at creation time.
    #[serde(rename = "response.created")]
    Created { response: Response },

    /// `"response.output_item.added"`: a new output item at `output_index`.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {
        output_index: usize,
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },

    /// `"response.output_text.delta"`: a chunk of text for the item identified by `item_id`.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        item_id: String,
        output_index: usize,
        delta: String,
    },

    /// `"response.output_item.done"`: the finished output item at `output_index`.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {
        output_index: usize,
        #[serde(default)]
        sequence_number: Option<u64>,
        item: ResponseOutputItem,
    },

    /// `"response.incomplete"`: terminal event for a response that ended incomplete.
    #[serde(rename = "response.incomplete")]
    Incomplete { response: Response },

    /// `"response.completed"`: terminal event for a successful response.
    #[serde(rename = "response.completed")]
    Completed { response: Response },

    /// `"response.failed"`: terminal event for a failed response.
    #[serde(rename = "response.failed")]
    Failed { response: Response },

    /// Catch-all for event types this module does not model.
    #[serde(other)]
    Unknown,
}
211
/// Error payload carried by [`StreamEvent::GenericError`] and [`Response::error`].
// NOTE(review): both fields are required; a payload with a null/missing `code` will fail to
// deserialize — confirm the Copilot endpoint always sends one.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponseError {
    pub code: String,
    pub message: String,
}
217
218#[derive(Deserialize, Debug, Default, Clone)]
219pub struct Response {
220 pub id: Option<String>,
221 pub status: Option<String>,
222 pub usage: Option<ResponseUsage>,
223 pub output: Vec<ResponseOutputItem>,
224 #[serde(skip_serializing_if = "Option::is_none")]
225 pub incomplete_details: Option<IncompleteDetails>,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 pub error: Option<ResponseError>,
228}
229
/// Token usage reported for a response; each counter may be absent.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct ResponseUsage {
    pub input_tokens: Option<u64>,
    pub output_tokens: Option<u64>,
    pub total_tokens: Option<u64>,
}
236
/// One item of [`Response::output`], tagged by `"type"` (`message`, `function_call`, `reasoning`).
// NOTE(review): only `Deserialize` is derived, so the `skip_serializing_if` attributes below
// have no effect; they are kept as-is.
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseOutputItem {
    /// A message produced by the model; `id` is used as `item_id` for text deltas.
    Message {
        id: String,
        role: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ResponseOutputContent>>,
    },
    /// A function call requested by the model.
    FunctionCall {
        #[serde(skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        call_id: String,
        name: String,
        // NOTE(review): presumably JSON-encoded arguments; kept as an opaque string here.
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<ItemStatus>,
    },
    /// A reasoning item with an optional summary and optional encrypted content.
    Reasoning {
        id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        summary: Option<Vec<ResponseReasoningItem>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        encrypted_content: Option<String>,
    },
}
263
/// One content part of an output message, tagged by `"type"` (`output_text` or `refusal`).
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseOutputContent {
    OutputText { text: String },
    Refusal { refusal: String },
}
270
271pub async fn stream_response(
272 client: Arc<dyn HttpClient>,
273 api_key: String,
274 api_url: String,
275 request: Request,
276 is_user_initiated: bool,
277) -> Result<BoxStream<'static, Result<StreamEvent>>> {
278 let is_vision_request = request.input.iter().any(|item| match item {
279 ResponseInputItem::Message {
280 content: Some(parts),
281 ..
282 } => parts
283 .iter()
284 .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
285 _ => false,
286 });
287
288 let request_initiator = if is_user_initiated { "user" } else { "agent" };
289
290 let request_builder = HttpRequest::builder()
291 .method(Method::POST)
292 .uri(&api_url)
293 .header(
294 "Editor-Version",
295 format!(
296 "Zed/{}",
297 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
298 ),
299 )
300 .header("Authorization", format!("Bearer {}", api_key))
301 .header("Content-Type", "application/json")
302 .header("Copilot-Integration-Id", "vscode-chat")
303 .header("X-Initiator", request_initiator);
304
305 let request_builder = if is_vision_request {
306 request_builder.header("Copilot-Vision-Request", "true")
307 } else {
308 request_builder
309 };
310
311 let is_streaming = request.stream;
312 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
313 let mut response = client.send(request).await?;
314
315 if !response.status().is_success() {
316 let mut body = String::new();
317 response.body_mut().read_to_string(&mut body).await?;
318 anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
319 }
320
321 if is_streaming {
322 let reader = BufReader::new(response.into_body());
323 Ok(reader
324 .lines()
325 .filter_map(|line| async move {
326 match line {
327 Ok(line) => {
328 let line = line.strip_prefix("data: ")?;
329 if line.starts_with("[DONE]") || line.is_empty() {
330 return None;
331 }
332
333 match serde_json::from_str::<StreamEvent>(line) {
334 Ok(event) => Some(Ok(event)),
335 Err(error) => {
336 log::error!(
337 "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
338 error,
339 line,
340 );
341 Some(Err(anyhow!(error)))
342 }
343 }
344 }
345 Err(error) => Some(Err(anyhow!(error))),
346 }
347 })
348 .boxed())
349 } else {
350 // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
351 // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
352 let mut body = String::new();
353 response.body_mut().read_to_string(&mut body).await?;
354
355 match serde_json::from_str::<Response>(&body) {
356 Ok(response) => {
357 let events = vec![StreamEvent::Created {
358 response: response.clone(),
359 }];
360
361 let mut all_events = events;
362 for (output_index, item) in response.output.iter().enumerate() {
363 all_events.push(StreamEvent::OutputItemAdded {
364 output_index,
365 sequence_number: None,
366 item: item.clone(),
367 });
368
369 if let ResponseOutputItem::Message {
370 id,
371 content: Some(content),
372 ..
373 } = item
374 {
375 for part in content {
376 if let ResponseOutputContent::OutputText { text } = part {
377 all_events.push(StreamEvent::OutputTextDelta {
378 item_id: id.clone(),
379 output_index,
380 delta: text.clone(),
381 });
382 }
383 }
384 }
385
386 all_events.push(StreamEvent::OutputItemDone {
387 output_index,
388 sequence_number: None,
389 item: item.clone(),
390 });
391 }
392
393 let final_event = if response.error.is_some() {
394 StreamEvent::Failed { response }
395 } else if response.incomplete_details.is_some() {
396 StreamEvent::Incomplete { response }
397 } else {
398 StreamEvent::Completed { response }
399 };
400 all_events.push(final_event);
401
402 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
403 }
404 Err(error) => {
405 log::error!(
406 "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
407 error,
408 body,
409 );
410 Err(anyhow!(error))
411 }
412 }
413 }
414}