1use std::sync::Arc;
2
3use super::copilot_request_headers;
4use anyhow::{Result, anyhow};
5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
10
11#[derive(Serialize, Debug)]
12pub struct Request {
13 pub model: String,
14 pub input: Vec<ResponseInputItem>,
15 #[serde(default)]
16 pub stream: bool,
17 #[serde(skip_serializing_if = "Option::is_none")]
18 pub temperature: Option<f32>,
19 #[serde(skip_serializing_if = "Vec::is_empty")]
20 pub tools: Vec<ToolDefinition>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub tool_choice: Option<ToolChoice>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub reasoning: Option<ReasoningConfig>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub include: Option<Vec<ResponseIncludable>>,
27}
28
29#[derive(Serialize, Deserialize, Debug, Clone)]
30#[serde(rename_all = "snake_case")]
31pub enum ResponseIncludable {
32 #[serde(rename = "reasoning.encrypted_content")]
33 ReasoningEncryptedContent,
34}
35
36#[derive(Serialize, Deserialize, Debug)]
37#[serde(tag = "type", rename_all = "snake_case")]
38pub enum ToolDefinition {
39 Function {
40 name: String,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 description: Option<String>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 parameters: Option<Value>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 strict: Option<bool>,
47 },
48}
49
50#[derive(Serialize, Deserialize, Debug)]
51#[serde(rename_all = "lowercase")]
52pub enum ToolChoice {
53 Auto,
54 Any,
55 None,
56 #[serde(untagged)]
57 Other(ToolDefinition),
58}
59
60#[derive(Serialize, Deserialize, Debug)]
61#[serde(rename_all = "lowercase")]
62pub enum ReasoningSummary {
63 Auto,
64 Concise,
65 Detailed,
66}
67
68#[derive(Serialize, Debug)]
69pub struct ReasoningConfig {
70 pub effort: ReasoningEffort,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub summary: Option<ReasoningSummary>,
73}
74
75#[derive(Serialize, Deserialize, Debug, Clone, Default)]
76#[serde(rename_all = "snake_case")]
77pub enum ResponseImageDetail {
78 Low,
79 High,
80 #[default]
81 Auto,
82}
83
84#[derive(Serialize, Deserialize, Debug, Clone)]
85#[serde(tag = "type", rename_all = "snake_case")]
86pub enum ResponseInputContent {
87 InputText {
88 text: String,
89 },
90 OutputText {
91 text: String,
92 },
93 InputImage {
94 #[serde(skip_serializing_if = "Option::is_none")]
95 image_url: Option<String>,
96 #[serde(default)]
97 detail: ResponseImageDetail,
98 },
99}
100
101#[derive(Serialize, Deserialize, Debug, Clone)]
102#[serde(rename_all = "snake_case")]
103pub enum ItemStatus {
104 InProgress,
105 Completed,
106 Incomplete,
107}
108
109#[derive(Serialize, Deserialize, Debug, Clone)]
110#[serde(untagged)]
111pub enum ResponseFunctionOutput {
112 Text(String),
113 Content(Vec<ResponseInputContent>),
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
117#[serde(tag = "type", rename_all = "snake_case")]
118pub enum ResponseInputItem {
119 Message {
120 role: String,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 content: Option<Vec<ResponseInputContent>>,
123 #[serde(skip_serializing_if = "Option::is_none")]
124 status: Option<String>,
125 },
126 FunctionCall {
127 call_id: String,
128 name: String,
129 arguments: String,
130 #[serde(skip_serializing_if = "Option::is_none")]
131 status: Option<ItemStatus>,
132 #[serde(default, skip_serializing_if = "Option::is_none")]
133 thought_signature: Option<String>,
134 },
135 FunctionCallOutput {
136 call_id: String,
137 output: ResponseFunctionOutput,
138 #[serde(skip_serializing_if = "Option::is_none")]
139 status: Option<ItemStatus>,
140 },
141 Reasoning {
142 #[serde(skip_serializing_if = "Option::is_none")]
143 id: Option<String>,
144 summary: Vec<ResponseReasoningItem>,
145 encrypted_content: String,
146 },
147}
148
149#[derive(Deserialize, Debug, Clone)]
150#[serde(rename_all = "snake_case")]
151pub enum IncompleteReason {
152 #[serde(rename = "max_output_tokens")]
153 MaxOutputTokens,
154 #[serde(rename = "content_filter")]
155 ContentFilter,
156}
157
158#[derive(Deserialize, Debug, Clone)]
159pub struct IncompleteDetails {
160 #[serde(skip_serializing_if = "Option::is_none")]
161 pub reason: Option<IncompleteReason>,
162}
163
164#[derive(Serialize, Deserialize, Debug, Clone)]
165pub struct ResponseReasoningItem {
166 #[serde(rename = "type")]
167 pub kind: String,
168 pub text: String,
169}
170
171#[derive(Deserialize, Debug)]
172#[serde(tag = "type")]
173pub enum StreamEvent {
174 #[serde(rename = "error")]
175 GenericError { error: ResponseError },
176
177 #[serde(rename = "response.created")]
178 Created { response: Response },
179
180 #[serde(rename = "response.output_item.added")]
181 OutputItemAdded {
182 output_index: usize,
183 #[serde(default)]
184 sequence_number: Option<u64>,
185 item: ResponseOutputItem,
186 },
187
188 #[serde(rename = "response.output_text.delta")]
189 OutputTextDelta {
190 item_id: String,
191 output_index: usize,
192 delta: String,
193 },
194
195 #[serde(rename = "response.output_item.done")]
196 OutputItemDone {
197 output_index: usize,
198 #[serde(default)]
199 sequence_number: Option<u64>,
200 item: ResponseOutputItem,
201 },
202
203 #[serde(rename = "response.incomplete")]
204 Incomplete { response: Response },
205
206 #[serde(rename = "response.completed")]
207 Completed { response: Response },
208
209 #[serde(rename = "response.failed")]
210 Failed { response: Response },
211
212 #[serde(other)]
213 Unknown,
214}
215
216#[derive(Deserialize, Debug, Clone)]
217pub struct ResponseError {
218 pub code: String,
219 pub message: String,
220}
221
222#[derive(Deserialize, Debug, Default, Clone)]
223pub struct Response {
224 pub id: Option<String>,
225 pub status: Option<String>,
226 pub usage: Option<ResponseUsage>,
227 pub output: Vec<ResponseOutputItem>,
228 #[serde(skip_serializing_if = "Option::is_none")]
229 pub incomplete_details: Option<IncompleteDetails>,
230 #[serde(skip_serializing_if = "Option::is_none")]
231 pub error: Option<ResponseError>,
232}
233
234#[derive(Deserialize, Debug, Default, Clone)]
235pub struct ResponseUsage {
236 pub input_tokens: Option<u64>,
237 pub output_tokens: Option<u64>,
238 pub total_tokens: Option<u64>,
239}
240
241#[derive(Deserialize, Debug, Clone)]
242#[serde(tag = "type", rename_all = "snake_case")]
243pub enum ResponseOutputItem {
244 Message {
245 id: String,
246 role: String,
247 #[serde(skip_serializing_if = "Option::is_none")]
248 content: Option<Vec<ResponseOutputContent>>,
249 },
250 FunctionCall {
251 #[serde(skip_serializing_if = "Option::is_none")]
252 id: Option<String>,
253 call_id: String,
254 name: String,
255 arguments: String,
256 #[serde(skip_serializing_if = "Option::is_none")]
257 status: Option<ItemStatus>,
258 #[serde(default, skip_serializing_if = "Option::is_none")]
259 thought_signature: Option<String>,
260 },
261 Reasoning {
262 id: String,
263 #[serde(skip_serializing_if = "Option::is_none")]
264 summary: Option<Vec<ResponseReasoningItem>>,
265 #[serde(skip_serializing_if = "Option::is_none")]
266 encrypted_content: Option<String>,
267 },
268}
269
270#[derive(Deserialize, Debug, Clone)]
271#[serde(tag = "type", rename_all = "snake_case")]
272pub enum ResponseOutputContent {
273 OutputText { text: String },
274 Refusal { refusal: String },
275}
276
277pub async fn stream_response(
278 client: Arc<dyn HttpClient>,
279 oauth_token: String,
280 api_url: String,
281 request: Request,
282 is_user_initiated: bool,
283) -> Result<BoxStream<'static, Result<StreamEvent>>> {
284 let is_vision_request = request.input.iter().any(|item| match item {
285 ResponseInputItem::Message {
286 content: Some(parts),
287 ..
288 } => parts
289 .iter()
290 .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
291 _ => false,
292 });
293
294 let request_builder = copilot_request_headers(
295 HttpRequest::builder().method(Method::POST).uri(&api_url),
296 &oauth_token,
297 Some(is_user_initiated),
298 );
299
300 let request_builder = if is_vision_request {
301 request_builder.header("Copilot-Vision-Request", "true")
302 } else {
303 request_builder
304 };
305
306 let is_streaming = request.stream;
307 let json = serde_json::to_string(&request)?;
308 let request = request_builder.body(AsyncBody::from(json))?;
309 let mut response = client.send(request).await?;
310
311 if !response.status().is_success() {
312 let mut body = String::new();
313 response.body_mut().read_to_string(&mut body).await?;
314 anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
315 }
316
317 if is_streaming {
318 let reader = BufReader::new(response.into_body());
319 Ok(reader
320 .lines()
321 .filter_map(|line| async move {
322 match line {
323 Ok(line) => {
324 let line = line.strip_prefix("data: ")?;
325 if line.starts_with("[DONE]") || line.is_empty() {
326 return None;
327 }
328
329 match serde_json::from_str::<StreamEvent>(line) {
330 Ok(event) => Some(Ok(event)),
331 Err(error) => {
332 log::error!(
333 "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
334 error,
335 line,
336 );
337 Some(Err(anyhow!(error)))
338 }
339 }
340 }
341 Err(error) => Some(Err(anyhow!(error))),
342 }
343 })
344 .boxed())
345 } else {
346 // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
347 // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
348 let mut body = String::new();
349 response.body_mut().read_to_string(&mut body).await?;
350
351 match serde_json::from_str::<Response>(&body) {
352 Ok(response) => {
353 let events = vec![StreamEvent::Created {
354 response: response.clone(),
355 }];
356
357 let mut all_events = events;
358 for (output_index, item) in response.output.iter().enumerate() {
359 all_events.push(StreamEvent::OutputItemAdded {
360 output_index,
361 sequence_number: None,
362 item: item.clone(),
363 });
364
365 if let ResponseOutputItem::Message {
366 id,
367 content: Some(content),
368 ..
369 } = item
370 {
371 for part in content {
372 if let ResponseOutputContent::OutputText { text } = part {
373 all_events.push(StreamEvent::OutputTextDelta {
374 item_id: id.clone(),
375 output_index,
376 delta: text.clone(),
377 });
378 }
379 }
380 }
381
382 all_events.push(StreamEvent::OutputItemDone {
383 output_index,
384 sequence_number: None,
385 item: item.clone(),
386 });
387 }
388
389 let final_event = if response.error.is_some() {
390 StreamEvent::Failed { response }
391 } else if response.incomplete_details.is_some() {
392 StreamEvent::Incomplete { response }
393 } else {
394 StreamEvent::Completed { response }
395 };
396 all_events.push(final_event);
397
398 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
399 }
400 Err(error) => {
401 log::error!(
402 "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
403 error,
404 body,
405 );
406 Err(anyhow!(error))
407 }
408 }
409 }
410}