1use std::sync::Arc;
2
3use super::{ChatLocation, copilot_request_headers};
4use anyhow::{Result, anyhow};
5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
6use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
10
11#[derive(Serialize, Debug)]
12pub struct Request {
13 pub model: String,
14 pub input: Vec<ResponseInputItem>,
15 #[serde(default)]
16 pub stream: bool,
17 #[serde(skip_serializing_if = "Option::is_none")]
18 pub temperature: Option<f32>,
19 #[serde(skip_serializing_if = "Vec::is_empty")]
20 pub tools: Vec<ToolDefinition>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub tool_choice: Option<ToolChoice>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub reasoning: Option<ReasoningConfig>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub include: Option<Vec<ResponseIncludable>>,
27 pub store: bool,
28}
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31#[serde(rename_all = "snake_case")]
32pub enum ResponseIncludable {
33 #[serde(rename = "reasoning.encrypted_content")]
34 ReasoningEncryptedContent,
35}
36
37#[derive(Serialize, Deserialize, Debug)]
38#[serde(tag = "type", rename_all = "snake_case")]
39pub enum ToolDefinition {
40 Function {
41 name: String,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 description: Option<String>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 parameters: Option<Value>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 strict: Option<bool>,
48 },
49}
50
51#[derive(Serialize, Deserialize, Debug)]
52#[serde(rename_all = "lowercase")]
53pub enum ToolChoice {
54 Auto,
55 Any,
56 None,
57 #[serde(untagged)]
58 Other(ToolDefinition),
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62#[serde(rename_all = "lowercase")]
63pub enum ReasoningSummary {
64 Auto,
65 Concise,
66 Detailed,
67}
68
69#[derive(Serialize, Debug)]
70pub struct ReasoningConfig {
71 pub effort: ReasoningEffort,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub summary: Option<ReasoningSummary>,
74}
75
76#[derive(Serialize, Deserialize, Debug, Clone, Default)]
77#[serde(rename_all = "snake_case")]
78pub enum ResponseImageDetail {
79 Low,
80 High,
81 #[default]
82 Auto,
83}
84
85#[derive(Serialize, Deserialize, Debug, Clone)]
86#[serde(tag = "type", rename_all = "snake_case")]
87pub enum ResponseInputContent {
88 InputText {
89 text: String,
90 },
91 OutputText {
92 text: String,
93 },
94 InputImage {
95 #[serde(skip_serializing_if = "Option::is_none")]
96 image_url: Option<String>,
97 #[serde(default)]
98 detail: ResponseImageDetail,
99 },
100}
101
102#[derive(Serialize, Deserialize, Debug, Clone)]
103#[serde(rename_all = "snake_case")]
104pub enum ItemStatus {
105 InProgress,
106 Completed,
107 Incomplete,
108}
109
110#[derive(Serialize, Deserialize, Debug, Clone)]
111#[serde(untagged)]
112pub enum ResponseFunctionOutput {
113 Text(String),
114 Content(Vec<ResponseInputContent>),
115}
116
117#[derive(Serialize, Deserialize, Debug, Clone)]
118#[serde(tag = "type", rename_all = "snake_case")]
119pub enum ResponseInputItem {
120 Message {
121 role: String,
122 #[serde(skip_serializing_if = "Option::is_none")]
123 content: Option<Vec<ResponseInputContent>>,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 status: Option<String>,
126 },
127 FunctionCall {
128 call_id: String,
129 name: String,
130 arguments: String,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 status: Option<ItemStatus>,
133 #[serde(default, skip_serializing_if = "Option::is_none")]
134 thought_signature: Option<String>,
135 },
136 FunctionCallOutput {
137 call_id: String,
138 output: ResponseFunctionOutput,
139 #[serde(skip_serializing_if = "Option::is_none")]
140 status: Option<ItemStatus>,
141 },
142 Reasoning {
143 #[serde(skip_serializing_if = "Option::is_none")]
144 id: Option<String>,
145 summary: Vec<ResponseReasoningItem>,
146 encrypted_content: String,
147 },
148}
149
150#[derive(Deserialize, Debug, Clone)]
151#[serde(rename_all = "snake_case")]
152pub enum IncompleteReason {
153 #[serde(rename = "max_output_tokens")]
154 MaxOutputTokens,
155 #[serde(rename = "content_filter")]
156 ContentFilter,
157}
158
159#[derive(Deserialize, Debug, Clone)]
160pub struct IncompleteDetails {
161 #[serde(skip_serializing_if = "Option::is_none")]
162 pub reason: Option<IncompleteReason>,
163}
164
165#[derive(Serialize, Deserialize, Debug, Clone)]
166pub struct ResponseReasoningItem {
167 #[serde(rename = "type")]
168 pub kind: String,
169 pub text: String,
170}
171
172#[derive(Deserialize, Debug)]
173#[serde(tag = "type")]
174pub enum StreamEvent {
175 #[serde(rename = "error")]
176 GenericError { error: ResponseError },
177
178 #[serde(rename = "response.created")]
179 Created { response: Response },
180
181 #[serde(rename = "response.output_item.added")]
182 OutputItemAdded {
183 output_index: usize,
184 #[serde(default)]
185 sequence_number: Option<u64>,
186 item: ResponseOutputItem,
187 },
188
189 #[serde(rename = "response.output_text.delta")]
190 OutputTextDelta {
191 item_id: String,
192 output_index: usize,
193 delta: String,
194 },
195
196 #[serde(rename = "response.output_item.done")]
197 OutputItemDone {
198 output_index: usize,
199 #[serde(default)]
200 sequence_number: Option<u64>,
201 item: ResponseOutputItem,
202 },
203
204 #[serde(rename = "response.incomplete")]
205 Incomplete { response: Response },
206
207 #[serde(rename = "response.completed")]
208 Completed { response: Response },
209
210 #[serde(rename = "response.failed")]
211 Failed { response: Response },
212
213 #[serde(other)]
214 Unknown,
215}
216
217#[derive(Deserialize, Debug, Clone)]
218pub struct ResponseError {
219 pub code: String,
220 pub message: String,
221}
222
223#[derive(Deserialize, Debug, Default, Clone)]
224pub struct Response {
225 pub id: Option<String>,
226 pub status: Option<String>,
227 pub usage: Option<ResponseUsage>,
228 pub output: Vec<ResponseOutputItem>,
229 #[serde(skip_serializing_if = "Option::is_none")]
230 pub incomplete_details: Option<IncompleteDetails>,
231 #[serde(skip_serializing_if = "Option::is_none")]
232 pub error: Option<ResponseError>,
233}
234
235#[derive(Deserialize, Debug, Default, Clone)]
236pub struct ResponseUsage {
237 pub input_tokens: Option<u64>,
238 pub output_tokens: Option<u64>,
239 pub total_tokens: Option<u64>,
240}
241
242#[derive(Deserialize, Debug, Clone)]
243#[serde(tag = "type", rename_all = "snake_case")]
244pub enum ResponseOutputItem {
245 Message {
246 id: String,
247 role: String,
248 #[serde(skip_serializing_if = "Option::is_none")]
249 content: Option<Vec<ResponseOutputContent>>,
250 },
251 FunctionCall {
252 #[serde(skip_serializing_if = "Option::is_none")]
253 id: Option<String>,
254 call_id: String,
255 name: String,
256 arguments: String,
257 #[serde(skip_serializing_if = "Option::is_none")]
258 status: Option<ItemStatus>,
259 #[serde(default, skip_serializing_if = "Option::is_none")]
260 thought_signature: Option<String>,
261 },
262 Reasoning {
263 id: String,
264 #[serde(skip_serializing_if = "Option::is_none")]
265 summary: Option<Vec<ResponseReasoningItem>>,
266 #[serde(skip_serializing_if = "Option::is_none")]
267 encrypted_content: Option<String>,
268 },
269}
270
271#[derive(Deserialize, Debug, Clone)]
272#[serde(tag = "type", rename_all = "snake_case")]
273pub enum ResponseOutputContent {
274 OutputText { text: String },
275 Refusal { refusal: String },
276}
277
278pub async fn stream_response(
279 client: Arc<dyn HttpClient>,
280 oauth_token: String,
281 api_url: String,
282 request: Request,
283 is_user_initiated: bool,
284 location: ChatLocation,
285) -> Result<BoxStream<'static, Result<StreamEvent>>> {
286 let is_vision_request = request.input.iter().any(|item| match item {
287 ResponseInputItem::Message {
288 content: Some(parts),
289 ..
290 } => parts
291 .iter()
292 .any(|p| matches!(p, ResponseInputContent::InputImage { .. })),
293 _ => false,
294 });
295
296 let request_builder = copilot_request_headers(
297 HttpRequest::builder().method(Method::POST).uri(&api_url),
298 &oauth_token,
299 Some(is_user_initiated),
300 Some(location),
301 )
302 .when(is_vision_request, |builder| {
303 builder.header("Copilot-Vision-Request", "true")
304 });
305
306 let is_streaming = request.stream;
307 let json = serde_json::to_string(&request)?;
308 let request = request_builder.body(AsyncBody::from(json))?;
309 let mut response = client.send(request).await?;
310
311 if !response.status().is_success() {
312 let mut body = String::new();
313 response.body_mut().read_to_string(&mut body).await?;
314 anyhow::bail!("Failed to connect to API: {} {}", response.status(), body);
315 }
316
317 if is_streaming {
318 let reader = BufReader::new(response.into_body());
319 Ok(reader
320 .lines()
321 .filter_map(|line| async move {
322 match line {
323 Ok(line) => {
324 let line = line.strip_prefix("data: ")?;
325 if line.starts_with("[DONE]") || line.is_empty() {
326 return None;
327 }
328
329 match serde_json::from_str::<StreamEvent>(line) {
330 Ok(event) => Some(Ok(event)),
331 Err(error) => {
332 log::error!(
333 "Failed to parse Copilot responses stream event: `{}`\nResponse: `{}`",
334 error,
335 line,
336 );
337 Some(Err(anyhow!(error)))
338 }
339 }
340 }
341 Err(error) => Some(Err(anyhow!(error))),
342 }
343 })
344 .boxed())
345 } else {
346 // Simulate streaming this makes the mapping of this function return more straight-forward to handle if all callers assume it streams.
347 // Removes the need of having a method to map StreamEvent and another to map Response to a LanguageCompletionEvent
348 let mut body = String::new();
349 response.body_mut().read_to_string(&mut body).await?;
350
351 match serde_json::from_str::<Response>(&body) {
352 Ok(response) => {
353 let events = vec![StreamEvent::Created {
354 response: response.clone(),
355 }];
356
357 let mut all_events = events;
358 for (output_index, item) in response.output.iter().enumerate() {
359 all_events.push(StreamEvent::OutputItemAdded {
360 output_index,
361 sequence_number: None,
362 item: item.clone(),
363 });
364
365 if let ResponseOutputItem::Message {
366 id,
367 content: Some(content),
368 ..
369 } = item
370 {
371 for part in content {
372 if let ResponseOutputContent::OutputText { text } = part {
373 all_events.push(StreamEvent::OutputTextDelta {
374 item_id: id.clone(),
375 output_index,
376 delta: text.clone(),
377 });
378 }
379 }
380 }
381
382 all_events.push(StreamEvent::OutputItemDone {
383 output_index,
384 sequence_number: None,
385 item: item.clone(),
386 });
387 }
388
389 let final_event = if response.error.is_some() {
390 StreamEvent::Failed { response }
391 } else if response.incomplete_details.is_some() {
392 StreamEvent::Incomplete { response }
393 } else {
394 StreamEvent::Completed { response }
395 };
396 all_events.push(final_event);
397
398 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
399 }
400 Err(error) => {
401 log::error!(
402 "Failed to parse Copilot non-streaming response: `{}`\nResponse: `{}`",
403 error,
404 body,
405 );
406 Err(anyhow!(error))
407 }
408 }
409 }
410}