1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
/// Endpoint that receives chat completion requests (POSTed by `stream_completion`).
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint used to exchange the GitHub OAuth token for a Copilot API token that
/// carries an expiry time (queried by `request_api_token`).
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26}
27
// A chat model selectable through Copilot Chat.
//
// Each variant's serde `rename` is the name used when (de)serializing, and `alias`
// lets one additional spelling deserialize to the same variant. These serde names do
// not always equal `Model::id` (e.g. `Gpt4o` serializes as "gpt-4o-2024-05-13" while
// its id is "gpt-4o").
//
// Plain `//` comments are used instead of `///` because doc comments would
// presumably be picked up by the optional `schemars` derive and change the schema.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
    Gpt4o,
    #[serde(alias = "gpt-4", rename = "gpt-4")]
    Gpt4,
    #[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
    Gpt4_1,
    #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
    Gpt3_5Turbo,
    #[serde(alias = "o1", rename = "o1")]
    O1,
    // NOTE(review): "o1-mini" deserializes to O3Mini, presumably so previously saved
    // "o1-mini" values keep loading after the model switch — confirm this is intended.
    #[serde(alias = "o1-mini", rename = "o3-mini")]
    O3Mini,
    // The alias is the hyphenated form returned by `Model::id`; the serde name uses dots.
    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
    Claude3_5Sonnet,
    #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
    Claude3_7Sonnet,
    #[serde(
        alias = "claude-3.7-sonnet-thought",
        rename = "claude-3.7-sonnet-thought"
    )]
    Claude3_7SonnetThinking,
    #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
    Gemini20Flash,
    #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
    Gemini25Pro,
}
58
59impl Model {
60 pub fn uses_streaming(&self) -> bool {
61 match self {
62 Self::Gpt4o
63 | Self::Gpt4
64 | Self::Gpt4_1
65 | Self::Gpt3_5Turbo
66 | Self::Claude3_5Sonnet
67 | Self::Claude3_7Sonnet
68 | Self::Claude3_7SonnetThinking => true,
69 Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
70 }
71 }
72
73 pub fn from_id(id: &str) -> Result<Self> {
74 match id {
75 "gpt-4o" => Ok(Self::Gpt4o),
76 "gpt-4" => Ok(Self::Gpt4),
77 "gpt-4.1" => Ok(Self::Gpt4_1),
78 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
79 "o1" => Ok(Self::O1),
80 "o3-mini" => Ok(Self::O3Mini),
81 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
82 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
83 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
84 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
85 "gemini-2.5-pro" => Ok(Self::Gemini25Pro),
86 _ => Err(anyhow!("Invalid model id: {}", id)),
87 }
88 }
89
90 pub fn id(&self) -> &'static str {
91 match self {
92 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
93 Self::Gpt4 => "gpt-4",
94 Self::Gpt4_1 => "gpt-4.1",
95 Self::Gpt4o => "gpt-4o",
96 Self::O3Mini => "o3-mini",
97 Self::O1 => "o1",
98 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
99 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
100 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
101 Self::Gemini20Flash => "gemini-2.0-flash-001",
102 Self::Gemini25Pro => "gemini-2.5-pro",
103 }
104 }
105
106 pub fn display_name(&self) -> &'static str {
107 match self {
108 Self::Gpt3_5Turbo => "GPT-3.5",
109 Self::Gpt4 => "GPT-4",
110 Self::Gpt4_1 => "GPT-4.1",
111 Self::Gpt4o => "GPT-4o",
112 Self::O3Mini => "o3-mini",
113 Self::O1 => "o1",
114 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
115 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
116 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
117 Self::Gemini20Flash => "Gemini 2.0 Flash",
118 Self::Gemini25Pro => "Gemini 2.5 Pro",
119 }
120 }
121
122 pub fn max_token_count(&self) -> usize {
123 match self {
124 Self::Gpt4o => 64_000,
125 Self::Gpt4 => 32_768,
126 Self::Gpt4_1 => 1_047_576,
127 Self::Gpt3_5Turbo => 12_288,
128 Self::O3Mini => 64_000,
129 Self::O1 => 20_000,
130 Self::Claude3_5Sonnet => 200_000,
131 Self::Claude3_7Sonnet => 90_000,
132 Self::Claude3_7SonnetThinking => 90_000,
133 Self::Gemini20Flash => 128_000,
134 Self::Gemini25Pro => 128_000,
135 }
136 }
137}
138
/// Body of a chat completion request, serialized to JSON and POSTed to
/// [`COPILOT_CHAT_COMPLETION_URL`] by `stream_completion`.
#[derive(Serialize, Deserialize)]
pub struct Request {
    // NOTE(review): `intent`, `n` and `temperature` are forwarded verbatim; their
    // meaning is defined by the Copilot chat API, not by this file.
    pub intent: bool,
    pub n: usize,
    /// When `true`, the response body is read line by line as `data:` events;
    /// otherwise it is parsed as a single JSON response.
    pub stream: bool,
    pub temperature: f32,
    /// Serialized using the variant's serde name (see the `rename` attributes on [`Model`]).
    pub model: Model,
    /// Conversation sent to the model.
    pub messages: Vec<ChatMessage>,
    /// Tools offered to the model; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Optional tool-selection directive; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
152
/// A function the model may call: its name, description and parameter description.
// NOTE(review): `parameters` is presumably a JSON Schema object — confirm with callers.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

/// A tool offered to the model, serialized with a `"type"` tag:
/// `{"type": "function", "function": { ... }}`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// How the model should choose among the offered tools.
///
/// Serialized as an internally tagged object with lowercase type names:
/// `{"type": "auto"}`, `{"type": "any"}` or `{"type": "tool", "name": "..."}`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
173
/// One message of the conversation, tagged by a lowercase `role` field,
/// e.g. `{"role": "user", "content": "..."}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// Model output; `content` may be absent, and `tool_calls` is omitted when empty.
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// Presumably the result of the tool call identified by `tool_call_id` — confirm.
    Tool {
        content: String,
        tool_call_id: String,
    },
}

/// A tool call made by the assistant. `content` is flattened, so the JSON shape is
/// `{"id": "...", "type": "function", "function": {"name": ..., "arguments": ...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a [`ToolCall`], discriminated by a lowercase `"type"` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name of the called function and its arguments.
// NOTE(review): `arguments` is presumably a JSON-encoded string — confirm.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
212
/// One decoded response payload: a single `data:` event when streaming, or the whole
/// response body otherwise (see `stream_completion`).
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
// NOTE(review): `tag`/`rename_all` on a plain struct look like leftovers (the fields
// are already snake_case); confirm they are unneeded before removing, as they are
// serde attributes.
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub created: u64,
    pub id: String,
}

/// One choice in a [`ResponseEvent`].
///
/// NOTE(review): presumably `delta` is filled for streamed chunks and `message` for
/// non-streamed responses — confirm against the API.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Content of a choice; used for both the `delta` and `message` fields.
/// `tool_calls` defaults to empty when absent from the JSON.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}

/// Fragment of a tool call; `id` and `function` may be absent in any given fragment.
// NOTE(review): `index` presumably identifies which tool call the fragment extends.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function name and/or arguments carried by a [`ToolCallChunk`].
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
249
/// Raw JSON response returned by [`COPILOT_CHAT_AUTH_URL`].
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp in seconds; converted with `DateTime::from_timestamp(_, 0)`.
    expires_at: i64,
}

/// Copilot API token, sent as the `Bearer` credential on completion requests.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
}
261
262impl ApiToken {
263 pub fn remaining_seconds(&self) -> i64 {
264 self.expires_at
265 .timestamp()
266 .saturating_sub(chrono::Utc::now().timestamp())
267 }
268}
269
270impl TryFrom<ApiTokenResponse> for ApiToken {
271 type Error = anyhow::Error;
272
273 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
274 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
275 .ok_or_else(|| anyhow!("invalid expires_at"))?;
276
277 Ok(Self {
278 api_key: response.token,
279 expires_at,
280 })
281 }
282}
283
/// gpui global that holds the shared [`CopilotChat`] entity (installed by [`init`]).
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

/// Copilot Chat authentication state plus the HTTP client used for API calls.
pub struct CopilotChat {
    /// GitHub OAuth token read from the Copilot config files; `None` until one is found.
    oauth_token: Option<String>,
    /// Cached API token obtained by exchanging `oauth_token`; refreshed in
    /// `stream_completion` when missing or close to expiry.
    api_token: Option<ApiToken>,
    /// Client used for both the token exchange and completion requests.
    client: Arc<dyn HttpClient>,
}
293
294pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
295 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
296 cx.set_global(GlobalCopilotChat(copilot_chat));
297}
298
299pub fn copilot_chat_config_dir() -> &'static PathBuf {
300 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
301
302 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
303 if cfg!(target_os = "windows") {
304 home_dir().join("AppData").join("Local")
305 } else {
306 home_dir().join(".config")
307 }
308 .join("github-copilot")
309 })
310}
311
312fn copilot_chat_config_paths() -> [PathBuf; 2] {
313 let base_dir = copilot_chat_config_dir();
314 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
315}
316
317impl CopilotChat {
318 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
319 cx.try_global::<GlobalCopilotChat>()
320 .map(|model| model.0.clone())
321 }
322
323 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
324 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
325 let dir_path = copilot_chat_config_dir();
326
327 cx.spawn(async move |cx| {
328 let mut parent_watch_rx = watch_config_dir(
329 cx.background_executor(),
330 fs.clone(),
331 dir_path.clone(),
332 config_paths,
333 );
334 while let Some(contents) = parent_watch_rx.next().await {
335 let oauth_token = extract_oauth_token(contents);
336 cx.update(|cx| {
337 if let Some(this) = Self::global(cx).as_ref() {
338 this.update(cx, |this, cx| {
339 this.oauth_token = oauth_token;
340 cx.notify();
341 });
342 }
343 })?;
344 }
345 anyhow::Ok(())
346 })
347 .detach_and_log_err(cx);
348
349 Self {
350 oauth_token: None,
351 api_token: None,
352 client,
353 }
354 }
355
356 pub fn is_authenticated(&self) -> bool {
357 self.oauth_token.is_some()
358 }
359
360 pub async fn stream_completion(
361 request: Request,
362 mut cx: AsyncApp,
363 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
364 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
365 return Err(anyhow!("Copilot chat is not enabled"));
366 };
367
368 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
369 (
370 this.oauth_token.clone(),
371 this.api_token.clone(),
372 this.client.clone(),
373 )
374 })?;
375
376 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
377
378 let token = match api_token {
379 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
380 _ => {
381 let token = request_api_token(&oauth_token, client.clone()).await?;
382 this.update(&mut cx, |this, cx| {
383 this.api_token = Some(token.clone());
384 cx.notify();
385 })?;
386 token
387 }
388 };
389
390 stream_completion(client.clone(), token.api_key, request).await
391 }
392}
393
394async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
395 let request_builder = HttpRequest::builder()
396 .method(Method::GET)
397 .uri(COPILOT_CHAT_AUTH_URL)
398 .header("Authorization", format!("token {}", oauth_token))
399 .header("Accept", "application/json");
400
401 let request = request_builder.body(AsyncBody::empty())?;
402
403 let mut response = client.send(request).await?;
404
405 if response.status().is_success() {
406 let mut body = Vec::new();
407 response.body_mut().read_to_end(&mut body).await?;
408
409 let body_str = std::str::from_utf8(&body)?;
410
411 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
412 ApiToken::try_from(parsed)
413 } else {
414 let mut body = Vec::new();
415 response.body_mut().read_to_end(&mut body).await?;
416
417 let body_str = std::str::from_utf8(&body)?;
418
419 Err(anyhow!("Failed to request API token: {}", body_str))
420 }
421}
422
423fn extract_oauth_token(contents: String) -> Option<String> {
424 serde_json::from_str::<serde_json::Value>(&contents)
425 .map(|v| {
426 v.as_object().and_then(|obj| {
427 obj.iter().find_map(|(key, value)| {
428 if key.starts_with("github.com") {
429 value["oauth_token"].as_str().map(|v| v.to_string())
430 } else {
431 None
432 }
433 })
434 })
435 })
436 .ok()
437 .flatten()
438}
439
440async fn stream_completion(
441 client: Arc<dyn HttpClient>,
442 api_key: String,
443 request: Request,
444) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
445 let request_builder = HttpRequest::builder()
446 .method(Method::POST)
447 .uri(COPILOT_CHAT_COMPLETION_URL)
448 .header(
449 "Editor-Version",
450 format!(
451 "Zed/{}",
452 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
453 ),
454 )
455 .header("Authorization", format!("Bearer {}", api_key))
456 .header("Content-Type", "application/json")
457 .header("Copilot-Integration-Id", "vscode-chat");
458
459 let is_streaming = request.stream;
460
461 let json = serde_json::to_string(&request)?;
462 let request = request_builder.body(AsyncBody::from(json))?;
463 let mut response = client.send(request).await?;
464
465 if !response.status().is_success() {
466 let mut body = Vec::new();
467 response.body_mut().read_to_end(&mut body).await?;
468 let body_str = std::str::from_utf8(&body)?;
469 return Err(anyhow!(
470 "Failed to connect to API: {} {}",
471 response.status(),
472 body_str
473 ));
474 }
475
476 if is_streaming {
477 let reader = BufReader::new(response.into_body());
478 Ok(reader
479 .lines()
480 .filter_map(|line| async move {
481 match line {
482 Ok(line) => {
483 let line = line.strip_prefix("data: ")?;
484 if line.starts_with("[DONE]") {
485 return None;
486 }
487
488 match serde_json::from_str::<ResponseEvent>(line) {
489 Ok(response) => {
490 if response.choices.is_empty() {
491 None
492 } else {
493 Some(Ok(response))
494 }
495 }
496 Err(error) => Some(Err(anyhow!(error))),
497 }
498 }
499 Err(error) => Some(Err(anyhow!(error))),
500 }
501 })
502 .boxed())
503 } else {
504 let mut body = Vec::new();
505 response.body_mut().read_to_end(&mut body).await?;
506 let body_str = std::str::from_utf8(&body)?;
507 let response: ResponseEvent = serde_json::from_str(body_str)?;
508
509 Ok(futures::stream::once(async move { Ok(response) }).boxed())
510 }
511}