1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
/// Endpoint that chat-completion requests are POSTed to (see `stream_completion`).
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint queried with a GET to exchange an OAuth token for a Copilot API token
/// (see `request_api_token`).
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26}
27
28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
30pub enum Model {
31 #[default]
32 #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
33 Gpt4o,
34 #[serde(alias = "gpt-4", rename = "gpt-4")]
35 Gpt4,
36 #[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
37 Gpt4_1,
38 #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
39 Gpt3_5Turbo,
40 #[serde(alias = "o1", rename = "o1")]
41 O1,
42 #[serde(alias = "o1-mini", rename = "o3-mini")]
43 O3Mini,
44 #[serde(alias = "o3", rename = "o3")]
45 O3,
46 #[serde(alias = "o4-mini", rename = "o4-mini")]
47 O4Mini,
48 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
49 Claude3_5Sonnet,
50 #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
51 Claude3_7Sonnet,
52 #[serde(
53 alias = "claude-3.7-sonnet-thought",
54 rename = "claude-3.7-sonnet-thought"
55 )]
56 Claude3_7SonnetThinking,
57 #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
58 Gemini20Flash,
59 #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
60 Gemini25Pro,
61}
62
63impl Model {
64 pub fn uses_streaming(&self) -> bool {
65 match self {
66 Self::Gpt4o
67 | Self::Gpt4
68 | Self::Gpt4_1
69 | Self::Gpt3_5Turbo
70 | Self::O3
71 | Self::O4Mini
72 | Self::Claude3_5Sonnet
73 | Self::Claude3_7Sonnet
74 | Self::Claude3_7SonnetThinking => true,
75 Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
76 }
77 }
78
79 pub fn from_id(id: &str) -> Result<Self> {
80 match id {
81 "gpt-4o" => Ok(Self::Gpt4o),
82 "gpt-4" => Ok(Self::Gpt4),
83 "gpt-4.1" => Ok(Self::Gpt4_1),
84 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
85 "o1" => Ok(Self::O1),
86 "o3-mini" => Ok(Self::O3Mini),
87 "o3" => Ok(Self::O3),
88 "o4-mini" => Ok(Self::O4Mini),
89 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
90 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
91 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
92 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
93 "gemini-2.5-pro" => Ok(Self::Gemini25Pro),
94 _ => Err(anyhow!("Invalid model id: {}", id)),
95 }
96 }
97
98 pub fn id(&self) -> &'static str {
99 match self {
100 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
101 Self::Gpt4 => "gpt-4",
102 Self::Gpt4_1 => "gpt-4.1",
103 Self::Gpt4o => "gpt-4o",
104 Self::O3Mini => "o3-mini",
105 Self::O1 => "o1",
106 Self::O3 => "o3",
107 Self::O4Mini => "o4-mini",
108 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
109 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
110 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
111 Self::Gemini20Flash => "gemini-2.0-flash-001",
112 Self::Gemini25Pro => "gemini-2.5-pro",
113 }
114 }
115
116 pub fn display_name(&self) -> &'static str {
117 match self {
118 Self::Gpt3_5Turbo => "GPT-3.5",
119 Self::Gpt4 => "GPT-4",
120 Self::Gpt4_1 => "GPT-4.1",
121 Self::Gpt4o => "GPT-4o",
122 Self::O3Mini => "o3-mini",
123 Self::O1 => "o1",
124 Self::O3 => "o3",
125 Self::O4Mini => "o4-mini",
126 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
127 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
128 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
129 Self::Gemini20Flash => "Gemini 2.0 Flash",
130 Self::Gemini25Pro => "Gemini 2.5 Pro",
131 }
132 }
133
134 pub fn max_token_count(&self) -> usize {
135 match self {
136 Self::Gpt4o => 64_000,
137 Self::Gpt4 => 32_768,
138 Self::Gpt4_1 => 128_000,
139 Self::Gpt3_5Turbo => 12_288,
140 Self::O3Mini => 64_000,
141 Self::O1 => 20_000,
142 Self::O3 => 128_000,
143 Self::O4Mini => 128_000,
144 Self::Claude3_5Sonnet => 200_000,
145 Self::Claude3_7Sonnet => 90_000,
146 Self::Claude3_7SonnetThinking => 90_000,
147 Self::Gemini20Flash => 128_000,
148 Self::Gemini25Pro => 128_000,
149 }
150 }
151}
152
153#[derive(Serialize, Deserialize)]
154pub struct Request {
155 pub intent: bool,
156 pub n: usize,
157 pub stream: bool,
158 pub temperature: f32,
159 pub model: Model,
160 pub messages: Vec<ChatMessage>,
161 #[serde(default, skip_serializing_if = "Vec::is_empty")]
162 pub tools: Vec<Tool>,
163 #[serde(default, skip_serializing_if = "Option::is_none")]
164 pub tool_choice: Option<ToolChoice>,
165}
166
167#[derive(Serialize, Deserialize)]
168pub struct Function {
169 pub name: String,
170 pub description: String,
171 pub parameters: serde_json::Value,
172}
173
174#[derive(Serialize, Deserialize)]
175#[serde(tag = "type", rename_all = "snake_case")]
176pub enum Tool {
177 Function { function: Function },
178}
179
180#[derive(Serialize, Deserialize)]
181#[serde(tag = "type", rename_all = "lowercase")]
182pub enum ToolChoice {
183 Auto,
184 Any,
185 Tool { name: String },
186}
187
188#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
189#[serde(tag = "role", rename_all = "lowercase")]
190pub enum ChatMessage {
191 Assistant {
192 content: Option<String>,
193 #[serde(default, skip_serializing_if = "Vec::is_empty")]
194 tool_calls: Vec<ToolCall>,
195 },
196 User {
197 content: String,
198 },
199 System {
200 content: String,
201 },
202 Tool {
203 content: String,
204 tool_call_id: String,
205 },
206}
207
208#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
209pub struct ToolCall {
210 pub id: String,
211 #[serde(flatten)]
212 pub content: ToolCallContent,
213}
214
215#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
216#[serde(tag = "type", rename_all = "lowercase")]
217pub enum ToolCallContent {
218 Function { function: FunctionContent },
219}
220
221#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
222pub struct FunctionContent {
223 pub name: String,
224 pub arguments: String,
225}
226
227#[derive(Deserialize, Debug)]
228#[serde(tag = "type", rename_all = "snake_case")]
229pub struct ResponseEvent {
230 pub choices: Vec<ResponseChoice>,
231 pub created: u64,
232 pub id: String,
233}
234
235#[derive(Debug, Deserialize)]
236pub struct ResponseChoice {
237 pub index: usize,
238 pub finish_reason: Option<String>,
239 pub delta: Option<ResponseDelta>,
240 pub message: Option<ResponseDelta>,
241}
242
243#[derive(Debug, Deserialize)]
244pub struct ResponseDelta {
245 pub content: Option<String>,
246 pub role: Option<Role>,
247 #[serde(default)]
248 pub tool_calls: Vec<ToolCallChunk>,
249}
250
251#[derive(Deserialize, Debug, Eq, PartialEq)]
252pub struct ToolCallChunk {
253 pub index: usize,
254 pub id: Option<String>,
255 pub function: Option<FunctionChunk>,
256}
257
258#[derive(Deserialize, Debug, Eq, PartialEq)]
259pub struct FunctionChunk {
260 pub name: Option<String>,
261 pub arguments: Option<String>,
262}
263
/// Raw JSON response from [`COPILOT_CHAT_AUTH_URL`]; converted into an [`ApiToken`].
#[derive(Deserialize)]
struct ApiTokenResponse {
    /// The API token, later sent as the `Bearer` credential.
    token: String,
    /// Expiry as whole seconds since the Unix epoch (passed to `DateTime::from_timestamp`).
    expires_at: i64,
}
269
/// Copilot API token together with its absolute expiry time in UTC.
#[derive(Clone)]
struct ApiToken {
    /// Sent as `Authorization: Bearer <api_key>` on completion requests.
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
}
275
276impl ApiToken {
277 pub fn remaining_seconds(&self) -> i64 {
278 self.expires_at
279 .timestamp()
280 .saturating_sub(chrono::Utc::now().timestamp())
281 }
282}
283
284impl TryFrom<ApiTokenResponse> for ApiToken {
285 type Error = anyhow::Error;
286
287 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
288 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
289 .ok_or_else(|| anyhow!("invalid expires_at"))?;
290
291 Ok(Self {
292 api_key: response.token,
293 expires_at,
294 })
295 }
296}
297
/// gpui global that holds the app-wide [`CopilotChat`] entity.
/// It is set by `init` and read back by `CopilotChat::global`.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
301
/// Copilot Chat client state: credentials plus the HTTP client used to talk to the API.
pub struct CopilotChat {
    /// OAuth token extracted from the watched Copilot config files; `None` until one is found.
    oauth_token: Option<String>,
    /// Most recently fetched API token. `stream_completion` refreshes it when it
    /// is missing or close to expiry.
    api_token: Option<ApiToken>,
    /// HTTP client used for both the token exchange and completion requests.
    client: Arc<dyn HttpClient>,
}
307
308pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
309 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
310 cx.set_global(GlobalCopilotChat(copilot_chat));
311}
312
313pub fn copilot_chat_config_dir() -> &'static PathBuf {
314 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
315
316 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
317 if cfg!(target_os = "windows") {
318 home_dir().join("AppData").join("Local")
319 } else {
320 home_dir().join(".config")
321 }
322 .join("github-copilot")
323 })
324}
325
326fn copilot_chat_config_paths() -> [PathBuf; 2] {
327 let base_dir = copilot_chat_config_dir();
328 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
329}
330
331impl CopilotChat {
332 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
333 cx.try_global::<GlobalCopilotChat>()
334 .map(|model| model.0.clone())
335 }
336
337 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
338 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
339 let dir_path = copilot_chat_config_dir();
340
341 cx.spawn(async move |cx| {
342 let mut parent_watch_rx = watch_config_dir(
343 cx.background_executor(),
344 fs.clone(),
345 dir_path.clone(),
346 config_paths,
347 );
348 while let Some(contents) = parent_watch_rx.next().await {
349 let oauth_token = extract_oauth_token(contents);
350 cx.update(|cx| {
351 if let Some(this) = Self::global(cx).as_ref() {
352 this.update(cx, |this, cx| {
353 this.oauth_token = oauth_token;
354 cx.notify();
355 });
356 }
357 })?;
358 }
359 anyhow::Ok(())
360 })
361 .detach_and_log_err(cx);
362
363 Self {
364 oauth_token: None,
365 api_token: None,
366 client,
367 }
368 }
369
370 pub fn is_authenticated(&self) -> bool {
371 self.oauth_token.is_some()
372 }
373
374 pub async fn stream_completion(
375 request: Request,
376 mut cx: AsyncApp,
377 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
378 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
379 return Err(anyhow!("Copilot chat is not enabled"));
380 };
381
382 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
383 (
384 this.oauth_token.clone(),
385 this.api_token.clone(),
386 this.client.clone(),
387 )
388 })?;
389
390 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
391
392 let token = match api_token {
393 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
394 _ => {
395 let token = request_api_token(&oauth_token, client.clone()).await?;
396 this.update(&mut cx, |this, cx| {
397 this.api_token = Some(token.clone());
398 cx.notify();
399 })?;
400 token
401 }
402 };
403
404 stream_completion(client.clone(), token.api_key, request).await
405 }
406}
407
408async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
409 let request_builder = HttpRequest::builder()
410 .method(Method::GET)
411 .uri(COPILOT_CHAT_AUTH_URL)
412 .header("Authorization", format!("token {}", oauth_token))
413 .header("Accept", "application/json");
414
415 let request = request_builder.body(AsyncBody::empty())?;
416
417 let mut response = client.send(request).await?;
418
419 if response.status().is_success() {
420 let mut body = Vec::new();
421 response.body_mut().read_to_end(&mut body).await?;
422
423 let body_str = std::str::from_utf8(&body)?;
424
425 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
426 ApiToken::try_from(parsed)
427 } else {
428 let mut body = Vec::new();
429 response.body_mut().read_to_end(&mut body).await?;
430
431 let body_str = std::str::from_utf8(&body)?;
432
433 Err(anyhow!("Failed to request API token: {}", body_str))
434 }
435}
436
437fn extract_oauth_token(contents: String) -> Option<String> {
438 serde_json::from_str::<serde_json::Value>(&contents)
439 .map(|v| {
440 v.as_object().and_then(|obj| {
441 obj.iter().find_map(|(key, value)| {
442 if key.starts_with("github.com") {
443 value["oauth_token"].as_str().map(|v| v.to_string())
444 } else {
445 None
446 }
447 })
448 })
449 })
450 .ok()
451 .flatten()
452}
453
454async fn stream_completion(
455 client: Arc<dyn HttpClient>,
456 api_key: String,
457 request: Request,
458) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
459 let request_builder = HttpRequest::builder()
460 .method(Method::POST)
461 .uri(COPILOT_CHAT_COMPLETION_URL)
462 .header(
463 "Editor-Version",
464 format!(
465 "Zed/{}",
466 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
467 ),
468 )
469 .header("Authorization", format!("Bearer {}", api_key))
470 .header("Content-Type", "application/json")
471 .header("Copilot-Integration-Id", "vscode-chat");
472
473 let is_streaming = request.stream;
474
475 let json = serde_json::to_string(&request)?;
476 let request = request_builder.body(AsyncBody::from(json))?;
477 let mut response = client.send(request).await?;
478
479 if !response.status().is_success() {
480 let mut body = Vec::new();
481 response.body_mut().read_to_end(&mut body).await?;
482 let body_str = std::str::from_utf8(&body)?;
483 return Err(anyhow!(
484 "Failed to connect to API: {} {}",
485 response.status(),
486 body_str
487 ));
488 }
489
490 if is_streaming {
491 let reader = BufReader::new(response.into_body());
492 Ok(reader
493 .lines()
494 .filter_map(|line| async move {
495 match line {
496 Ok(line) => {
497 let line = line.strip_prefix("data: ")?;
498 if line.starts_with("[DONE]") {
499 return None;
500 }
501
502 match serde_json::from_str::<ResponseEvent>(line) {
503 Ok(response) => {
504 if response.choices.is_empty() {
505 None
506 } else {
507 Some(Ok(response))
508 }
509 }
510 Err(error) => Some(Err(anyhow!(error))),
511 }
512 }
513 Err(error) => Some(Err(anyhow!(error))),
514 }
515 })
516 .boxed())
517 } else {
518 let mut body = Vec::new();
519 response.body_mut().read_to_end(&mut body).await?;
520 let body_str = std::str::from_utf8(&body)?;
521 let response: ResponseEvent = serde_json::from_str(body_str)?;
522
523 Ok(futures::stream::once(async move { Ok(response) }).boxed())
524 }
525}