1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
17pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
18pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26}
27
28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
30pub enum Model {
31 #[default]
32 #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
33 Gpt4o,
34 #[serde(alias = "gpt-4", rename = "gpt-4")]
35 Gpt4,
36 #[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
37 Gpt4_1,
38 #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
39 Gpt3_5Turbo,
40 #[serde(alias = "o1", rename = "o1")]
41 O1,
42 #[serde(alias = "o1-mini", rename = "o3-mini")]
43 O3Mini,
44 #[serde(alias = "o3", rename = "o3")]
45 O3,
46 #[serde(alias = "o4-mini", rename = "o4-mini")]
47 O4Mini,
48 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
49 Claude3_5Sonnet,
50 #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
51 Claude3_7Sonnet,
52 #[serde(
53 alias = "claude-3.7-sonnet-thought",
54 rename = "claude-3.7-sonnet-thought"
55 )]
56 Claude3_7SonnetThinking,
57 #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
58 Gemini20Flash,
59 #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
60 Gemini25Pro,
61}
62
63impl Model {
64 pub fn default_fast() -> Self {
65 Self::Claude3_7Sonnet
66 }
67
68 pub fn uses_streaming(&self) -> bool {
69 match self {
70 Self::Gpt4o
71 | Self::Gpt4
72 | Self::Gpt4_1
73 | Self::Gpt3_5Turbo
74 | Self::O3
75 | Self::O4Mini
76 | Self::Claude3_5Sonnet
77 | Self::Claude3_7Sonnet
78 | Self::Claude3_7SonnetThinking => true,
79 Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
80 }
81 }
82
83 pub fn from_id(id: &str) -> Result<Self> {
84 match id {
85 "gpt-4o" => Ok(Self::Gpt4o),
86 "gpt-4" => Ok(Self::Gpt4),
87 "gpt-4.1" => Ok(Self::Gpt4_1),
88 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
89 "o1" => Ok(Self::O1),
90 "o3-mini" => Ok(Self::O3Mini),
91 "o3" => Ok(Self::O3),
92 "o4-mini" => Ok(Self::O4Mini),
93 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
94 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
95 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
96 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
97 "gemini-2.5-pro" => Ok(Self::Gemini25Pro),
98 _ => Err(anyhow!("Invalid model id: {}", id)),
99 }
100 }
101
102 pub fn id(&self) -> &'static str {
103 match self {
104 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
105 Self::Gpt4 => "gpt-4",
106 Self::Gpt4_1 => "gpt-4.1",
107 Self::Gpt4o => "gpt-4o",
108 Self::O3Mini => "o3-mini",
109 Self::O1 => "o1",
110 Self::O3 => "o3",
111 Self::O4Mini => "o4-mini",
112 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
113 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
114 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
115 Self::Gemini20Flash => "gemini-2.0-flash-001",
116 Self::Gemini25Pro => "gemini-2.5-pro",
117 }
118 }
119
120 pub fn display_name(&self) -> &'static str {
121 match self {
122 Self::Gpt3_5Turbo => "GPT-3.5",
123 Self::Gpt4 => "GPT-4",
124 Self::Gpt4_1 => "GPT-4.1",
125 Self::Gpt4o => "GPT-4o",
126 Self::O3Mini => "o3-mini",
127 Self::O1 => "o1",
128 Self::O3 => "o3",
129 Self::O4Mini => "o4-mini",
130 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
131 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
132 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
133 Self::Gemini20Flash => "Gemini 2.0 Flash",
134 Self::Gemini25Pro => "Gemini 2.5 Pro",
135 }
136 }
137
138 pub fn max_token_count(&self) -> usize {
139 match self {
140 Self::Gpt4o => 64_000,
141 Self::Gpt4 => 32_768,
142 Self::Gpt4_1 => 128_000,
143 Self::Gpt3_5Turbo => 12_288,
144 Self::O3Mini => 64_000,
145 Self::O1 => 20_000,
146 Self::O3 => 128_000,
147 Self::O4Mini => 128_000,
148 Self::Claude3_5Sonnet => 200_000,
149 Self::Claude3_7Sonnet => 90_000,
150 Self::Claude3_7SonnetThinking => 90_000,
151 Self::Gemini20Flash => 128_000,
152 Self::Gemini25Pro => 128_000,
153 }
154 }
155}
156
157#[derive(Serialize, Deserialize)]
158pub struct Request {
159 pub intent: bool,
160 pub n: usize,
161 pub stream: bool,
162 pub temperature: f32,
163 pub model: Model,
164 pub messages: Vec<ChatMessage>,
165 #[serde(default, skip_serializing_if = "Vec::is_empty")]
166 pub tools: Vec<Tool>,
167 #[serde(default, skip_serializing_if = "Option::is_none")]
168 pub tool_choice: Option<ToolChoice>,
169}
170
171#[derive(Serialize, Deserialize)]
172pub struct Function {
173 pub name: String,
174 pub description: String,
175 pub parameters: serde_json::Value,
176}
177
178#[derive(Serialize, Deserialize)]
179#[serde(tag = "type", rename_all = "snake_case")]
180pub enum Tool {
181 Function { function: Function },
182}
183
184#[derive(Serialize, Deserialize)]
185#[serde(rename_all = "lowercase")]
186pub enum ToolChoice {
187 Auto,
188 Any,
189 None,
190}
191
192#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
193#[serde(tag = "role", rename_all = "lowercase")]
194pub enum ChatMessage {
195 Assistant {
196 content: Option<String>,
197 #[serde(default, skip_serializing_if = "Vec::is_empty")]
198 tool_calls: Vec<ToolCall>,
199 },
200 User {
201 content: String,
202 },
203 System {
204 content: String,
205 },
206 Tool {
207 content: String,
208 tool_call_id: String,
209 },
210}
211
212#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
213pub struct ToolCall {
214 pub id: String,
215 #[serde(flatten)]
216 pub content: ToolCallContent,
217}
218
219#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
220#[serde(tag = "type", rename_all = "lowercase")]
221pub enum ToolCallContent {
222 Function { function: FunctionContent },
223}
224
225#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
226pub struct FunctionContent {
227 pub name: String,
228 pub arguments: String,
229}
230
231#[derive(Deserialize, Debug)]
232#[serde(tag = "type", rename_all = "snake_case")]
233pub struct ResponseEvent {
234 pub choices: Vec<ResponseChoice>,
235 pub created: u64,
236 pub id: String,
237}
238
239#[derive(Debug, Deserialize)]
240pub struct ResponseChoice {
241 pub index: usize,
242 pub finish_reason: Option<String>,
243 pub delta: Option<ResponseDelta>,
244 pub message: Option<ResponseDelta>,
245}
246
247#[derive(Debug, Deserialize)]
248pub struct ResponseDelta {
249 pub content: Option<String>,
250 pub role: Option<Role>,
251 #[serde(default)]
252 pub tool_calls: Vec<ToolCallChunk>,
253}
254
255#[derive(Deserialize, Debug, Eq, PartialEq)]
256pub struct ToolCallChunk {
257 pub index: usize,
258 pub id: Option<String>,
259 pub function: Option<FunctionChunk>,
260}
261
262#[derive(Deserialize, Debug, Eq, PartialEq)]
263pub struct FunctionChunk {
264 pub name: Option<String>,
265 pub arguments: Option<String>,
266}
267
268#[derive(Deserialize)]
269struct ApiTokenResponse {
270 token: String,
271 expires_at: i64,
272}
273
274#[derive(Clone)]
275struct ApiToken {
276 api_key: String,
277 expires_at: DateTime<chrono::Utc>,
278}
279
280impl ApiToken {
281 pub fn remaining_seconds(&self) -> i64 {
282 self.expires_at
283 .timestamp()
284 .saturating_sub(chrono::Utc::now().timestamp())
285 }
286}
287
288impl TryFrom<ApiTokenResponse> for ApiToken {
289 type Error = anyhow::Error;
290
291 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
292 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
293 .ok_or_else(|| anyhow!("invalid expires_at"))?;
294
295 Ok(Self {
296 api_key: response.token,
297 expires_at,
298 })
299 }
300}
301
302struct GlobalCopilotChat(gpui::Entity<CopilotChat>);
303
304impl Global for GlobalCopilotChat {}
305
306pub struct CopilotChat {
307 oauth_token: Option<String>,
308 api_token: Option<ApiToken>,
309 client: Arc<dyn HttpClient>,
310}
311
312pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
313 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
314 cx.set_global(GlobalCopilotChat(copilot_chat));
315}
316
317pub fn copilot_chat_config_dir() -> &'static PathBuf {
318 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
319
320 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
321 if cfg!(target_os = "windows") {
322 home_dir().join("AppData").join("Local")
323 } else {
324 home_dir().join(".config")
325 }
326 .join("github-copilot")
327 })
328}
329
330fn copilot_chat_config_paths() -> [PathBuf; 2] {
331 let base_dir = copilot_chat_config_dir();
332 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
333}
334
335impl CopilotChat {
336 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
337 cx.try_global::<GlobalCopilotChat>()
338 .map(|model| model.0.clone())
339 }
340
341 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
342 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
343 let dir_path = copilot_chat_config_dir();
344
345 cx.spawn(async move |cx| {
346 let mut parent_watch_rx = watch_config_dir(
347 cx.background_executor(),
348 fs.clone(),
349 dir_path.clone(),
350 config_paths,
351 );
352 while let Some(contents) = parent_watch_rx.next().await {
353 let oauth_token = extract_oauth_token(contents);
354 cx.update(|cx| {
355 if let Some(this) = Self::global(cx).as_ref() {
356 this.update(cx, |this, cx| {
357 this.oauth_token = oauth_token;
358 cx.notify();
359 });
360 }
361 })?;
362 }
363 anyhow::Ok(())
364 })
365 .detach_and_log_err(cx);
366
367 Self {
368 oauth_token: None,
369 api_token: None,
370 client,
371 }
372 }
373
374 pub fn is_authenticated(&self) -> bool {
375 self.oauth_token.is_some()
376 }
377
378 pub async fn stream_completion(
379 request: Request,
380 mut cx: AsyncApp,
381 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
382 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
383 return Err(anyhow!("Copilot chat is not enabled"));
384 };
385
386 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
387 (
388 this.oauth_token.clone(),
389 this.api_token.clone(),
390 this.client.clone(),
391 )
392 })?;
393
394 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
395
396 let token = match api_token {
397 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
398 _ => {
399 let token = request_api_token(&oauth_token, client.clone()).await?;
400 this.update(&mut cx, |this, cx| {
401 this.api_token = Some(token.clone());
402 cx.notify();
403 })?;
404 token
405 }
406 };
407
408 stream_completion(client.clone(), token.api_key, request).await
409 }
410}
411
412async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
413 let request_builder = HttpRequest::builder()
414 .method(Method::GET)
415 .uri(COPILOT_CHAT_AUTH_URL)
416 .header("Authorization", format!("token {}", oauth_token))
417 .header("Accept", "application/json");
418
419 let request = request_builder.body(AsyncBody::empty())?;
420
421 let mut response = client.send(request).await?;
422
423 if response.status().is_success() {
424 let mut body = Vec::new();
425 response.body_mut().read_to_end(&mut body).await?;
426
427 let body_str = std::str::from_utf8(&body)?;
428
429 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
430 ApiToken::try_from(parsed)
431 } else {
432 let mut body = Vec::new();
433 response.body_mut().read_to_end(&mut body).await?;
434
435 let body_str = std::str::from_utf8(&body)?;
436
437 Err(anyhow!("Failed to request API token: {}", body_str))
438 }
439}
440
441fn extract_oauth_token(contents: String) -> Option<String> {
442 serde_json::from_str::<serde_json::Value>(&contents)
443 .map(|v| {
444 v.as_object().and_then(|obj| {
445 obj.iter().find_map(|(key, value)| {
446 if key.starts_with("github.com") {
447 value["oauth_token"].as_str().map(|v| v.to_string())
448 } else {
449 None
450 }
451 })
452 })
453 })
454 .ok()
455 .flatten()
456}
457
458async fn stream_completion(
459 client: Arc<dyn HttpClient>,
460 api_key: String,
461 request: Request,
462) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
463 let request_builder = HttpRequest::builder()
464 .method(Method::POST)
465 .uri(COPILOT_CHAT_COMPLETION_URL)
466 .header(
467 "Editor-Version",
468 format!(
469 "Zed/{}",
470 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
471 ),
472 )
473 .header("Authorization", format!("Bearer {}", api_key))
474 .header("Content-Type", "application/json")
475 .header("Copilot-Integration-Id", "vscode-chat");
476
477 let is_streaming = request.stream;
478
479 let json = serde_json::to_string(&request)?;
480 let request = request_builder.body(AsyncBody::from(json))?;
481 let mut response = client.send(request).await?;
482
483 if !response.status().is_success() {
484 let mut body = Vec::new();
485 response.body_mut().read_to_end(&mut body).await?;
486 let body_str = std::str::from_utf8(&body)?;
487 return Err(anyhow!(
488 "Failed to connect to API: {} {}",
489 response.status(),
490 body_str
491 ));
492 }
493
494 if is_streaming {
495 let reader = BufReader::new(response.into_body());
496 Ok(reader
497 .lines()
498 .filter_map(|line| async move {
499 match line {
500 Ok(line) => {
501 let line = line.strip_prefix("data: ")?;
502 if line.starts_with("[DONE]") {
503 return None;
504 }
505
506 match serde_json::from_str::<ResponseEvent>(line) {
507 Ok(response) => {
508 if response.choices.is_empty() {
509 None
510 } else {
511 Some(Ok(response))
512 }
513 }
514 Err(error) => Some(Err(anyhow!(error))),
515 }
516 }
517 Err(error) => Some(Err(anyhow!(error))),
518 }
519 })
520 .boxed())
521 } else {
522 let mut body = Vec::new();
523 response.body_mut().read_to_end(&mut body).await?;
524 let body_str = std::str::from_utf8(&body)?;
525 let response: ResponseEvent = serde_json::from_str(body_str)?;
526
527 Ok(futures::stream::once(async move { Ok(response) }).boxed())
528 }
529}