1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
17pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
18pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26}
27
28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
30pub enum Model {
31 #[default]
32 #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
33 Gpt4o,
34 #[serde(alias = "gpt-4", rename = "gpt-4")]
35 Gpt4,
36 #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
37 Gpt3_5Turbo,
38 #[serde(alias = "o1", rename = "o1")]
39 O1,
40 #[serde(alias = "o1-mini", rename = "o3-mini")]
41 O3Mini,
42 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
43 Claude3_5Sonnet,
44 #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
45 Claude3_7Sonnet,
46 #[serde(
47 alias = "claude-3.7-sonnet-thought",
48 rename = "claude-3.7-sonnet-thought"
49 )]
50 Claude3_7SonnetThinking,
51 #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
52 Gemini20Flash,
53 #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
54 Gemini25Pro,
55}
56
57impl Model {
58 pub fn uses_streaming(&self) -> bool {
59 match self {
60 Self::Gpt4o
61 | Self::Gpt4
62 | Self::Gpt3_5Turbo
63 | Self::Claude3_5Sonnet
64 | Self::Claude3_7Sonnet
65 | Self::Claude3_7SonnetThinking => true,
66 Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
67 }
68 }
69
70 pub fn from_id(id: &str) -> Result<Self> {
71 match id {
72 "gpt-4o" => Ok(Self::Gpt4o),
73 "gpt-4" => Ok(Self::Gpt4),
74 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
75 "o1" => Ok(Self::O1),
76 "o3-mini" => Ok(Self::O3Mini),
77 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
78 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
79 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
80 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
81 "gemini-2.5-pro" => Ok(Self::Gemini25Pro),
82 _ => Err(anyhow!("Invalid model id: {}", id)),
83 }
84 }
85
86 pub fn id(&self) -> &'static str {
87 match self {
88 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
89 Self::Gpt4 => "gpt-4",
90 Self::Gpt4o => "gpt-4o",
91 Self::O3Mini => "o3-mini",
92 Self::O1 => "o1",
93 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
94 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
95 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
96 Self::Gemini20Flash => "gemini-2.0-flash-001",
97 Self::Gemini25Pro => "gemini-2.5-pro",
98 }
99 }
100
101 pub fn display_name(&self) -> &'static str {
102 match self {
103 Self::Gpt3_5Turbo => "GPT-3.5",
104 Self::Gpt4 => "GPT-4",
105 Self::Gpt4o => "GPT-4o",
106 Self::O3Mini => "o3-mini",
107 Self::O1 => "o1",
108 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
109 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
110 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
111 Self::Gemini20Flash => "Gemini 2.0 Flash",
112 Self::Gemini25Pro => "Gemini 2.5 Pro",
113 }
114 }
115
116 pub fn max_token_count(&self) -> usize {
117 match self {
118 Self::Gpt4o => 64_000,
119 Self::Gpt4 => 32_768,
120 Self::Gpt3_5Turbo => 12_288,
121 Self::O3Mini => 64_000,
122 Self::O1 => 20_000,
123 Self::Claude3_5Sonnet => 200_000,
124 Self::Claude3_7Sonnet => 90_000,
125 Self::Claude3_7SonnetThinking => 90_000,
126 Self::Gemini20Flash => 128_000,
127 Self::Gemini25Pro => 128_000,
128 }
129 }
130}
131
132#[derive(Serialize, Deserialize)]
133pub struct Request {
134 pub intent: bool,
135 pub n: usize,
136 pub stream: bool,
137 pub temperature: f32,
138 pub model: Model,
139 pub messages: Vec<ChatMessage>,
140 #[serde(default, skip_serializing_if = "Vec::is_empty")]
141 pub tools: Vec<Tool>,
142 #[serde(default, skip_serializing_if = "Option::is_none")]
143 pub tool_choice: Option<ToolChoice>,
144}
145
146#[derive(Serialize, Deserialize)]
147pub struct Function {
148 pub name: String,
149 pub description: String,
150 pub parameters: serde_json::Value,
151}
152
153#[derive(Serialize, Deserialize)]
154#[serde(tag = "type", rename_all = "snake_case")]
155pub enum Tool {
156 Function { function: Function },
157}
158
159#[derive(Serialize, Deserialize)]
160#[serde(tag = "type", rename_all = "lowercase")]
161pub enum ToolChoice {
162 Auto,
163 Any,
164 Tool { name: String },
165}
166
167#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
168#[serde(tag = "role", rename_all = "lowercase")]
169pub enum ChatMessage {
170 Assistant {
171 content: Option<String>,
172 #[serde(default, skip_serializing_if = "Vec::is_empty")]
173 tool_calls: Vec<ToolCall>,
174 },
175 User {
176 content: String,
177 },
178 System {
179 content: String,
180 },
181 Tool {
182 content: String,
183 tool_call_id: String,
184 },
185}
186
187#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
188pub struct ToolCall {
189 pub id: String,
190 #[serde(flatten)]
191 pub content: ToolCallContent,
192}
193
194#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
195#[serde(tag = "type", rename_all = "lowercase")]
196pub enum ToolCallContent {
197 Function { function: FunctionContent },
198}
199
200#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
201pub struct FunctionContent {
202 pub name: String,
203 pub arguments: String,
204}
205
206#[derive(Deserialize, Debug)]
207#[serde(tag = "type", rename_all = "snake_case")]
208pub struct ResponseEvent {
209 pub choices: Vec<ResponseChoice>,
210 pub created: u64,
211 pub id: String,
212}
213
214#[derive(Debug, Deserialize)]
215pub struct ResponseChoice {
216 pub index: usize,
217 pub finish_reason: Option<String>,
218 pub delta: Option<ResponseDelta>,
219 pub message: Option<ResponseDelta>,
220}
221
222#[derive(Debug, Deserialize)]
223pub struct ResponseDelta {
224 pub content: Option<String>,
225 pub role: Option<Role>,
226 #[serde(default)]
227 pub tool_calls: Vec<ToolCallChunk>,
228}
229
230#[derive(Deserialize, Debug, Eq, PartialEq)]
231pub struct ToolCallChunk {
232 pub index: usize,
233 pub id: Option<String>,
234 pub function: Option<FunctionChunk>,
235}
236
237#[derive(Deserialize, Debug, Eq, PartialEq)]
238pub struct FunctionChunk {
239 pub name: Option<String>,
240 pub arguments: Option<String>,
241}
242
243#[derive(Deserialize)]
244struct ApiTokenResponse {
245 token: String,
246 expires_at: i64,
247}
248
249#[derive(Clone)]
250struct ApiToken {
251 api_key: String,
252 expires_at: DateTime<chrono::Utc>,
253}
254
255impl ApiToken {
256 pub fn remaining_seconds(&self) -> i64 {
257 self.expires_at
258 .timestamp()
259 .saturating_sub(chrono::Utc::now().timestamp())
260 }
261}
262
263impl TryFrom<ApiTokenResponse> for ApiToken {
264 type Error = anyhow::Error;
265
266 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
267 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
268 .ok_or_else(|| anyhow!("invalid expires_at"))?;
269
270 Ok(Self {
271 api_key: response.token,
272 expires_at,
273 })
274 }
275}
276
277struct GlobalCopilotChat(gpui::Entity<CopilotChat>);
278
279impl Global for GlobalCopilotChat {}
280
281pub struct CopilotChat {
282 oauth_token: Option<String>,
283 api_token: Option<ApiToken>,
284 client: Arc<dyn HttpClient>,
285}
286
287pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
288 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
289 cx.set_global(GlobalCopilotChat(copilot_chat));
290}
291
292pub fn copilot_chat_config_dir() -> &'static PathBuf {
293 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
294
295 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
296 if cfg!(target_os = "windows") {
297 home_dir().join("AppData").join("Local")
298 } else {
299 home_dir().join(".config")
300 }
301 .join("github-copilot")
302 })
303}
304
305fn copilot_chat_config_paths() -> [PathBuf; 2] {
306 let base_dir = copilot_chat_config_dir();
307 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
308}
309
310impl CopilotChat {
311 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
312 cx.try_global::<GlobalCopilotChat>()
313 .map(|model| model.0.clone())
314 }
315
316 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
317 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
318 let dir_path = copilot_chat_config_dir();
319
320 cx.spawn(async move |cx| {
321 let mut parent_watch_rx = watch_config_dir(
322 cx.background_executor(),
323 fs.clone(),
324 dir_path.clone(),
325 config_paths,
326 );
327 while let Some(contents) = parent_watch_rx.next().await {
328 let oauth_token = extract_oauth_token(contents);
329 cx.update(|cx| {
330 if let Some(this) = Self::global(cx).as_ref() {
331 this.update(cx, |this, cx| {
332 this.oauth_token = oauth_token;
333 cx.notify();
334 });
335 }
336 })?;
337 }
338 anyhow::Ok(())
339 })
340 .detach_and_log_err(cx);
341
342 Self {
343 oauth_token: None,
344 api_token: None,
345 client,
346 }
347 }
348
349 pub fn is_authenticated(&self) -> bool {
350 self.oauth_token.is_some()
351 }
352
353 pub async fn stream_completion(
354 request: Request,
355 mut cx: AsyncApp,
356 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
357 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
358 return Err(anyhow!("Copilot chat is not enabled"));
359 };
360
361 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
362 (
363 this.oauth_token.clone(),
364 this.api_token.clone(),
365 this.client.clone(),
366 )
367 })?;
368
369 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
370
371 let token = match api_token {
372 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
373 _ => {
374 let token = request_api_token(&oauth_token, client.clone()).await?;
375 this.update(&mut cx, |this, cx| {
376 this.api_token = Some(token.clone());
377 cx.notify();
378 })?;
379 token
380 }
381 };
382
383 stream_completion(client.clone(), token.api_key, request).await
384 }
385}
386
387async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
388 let request_builder = HttpRequest::builder()
389 .method(Method::GET)
390 .uri(COPILOT_CHAT_AUTH_URL)
391 .header("Authorization", format!("token {}", oauth_token))
392 .header("Accept", "application/json");
393
394 let request = request_builder.body(AsyncBody::empty())?;
395
396 let mut response = client.send(request).await?;
397
398 if response.status().is_success() {
399 let mut body = Vec::new();
400 response.body_mut().read_to_end(&mut body).await?;
401
402 let body_str = std::str::from_utf8(&body)?;
403
404 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
405 ApiToken::try_from(parsed)
406 } else {
407 let mut body = Vec::new();
408 response.body_mut().read_to_end(&mut body).await?;
409
410 let body_str = std::str::from_utf8(&body)?;
411
412 Err(anyhow!("Failed to request API token: {}", body_str))
413 }
414}
415
416fn extract_oauth_token(contents: String) -> Option<String> {
417 serde_json::from_str::<serde_json::Value>(&contents)
418 .map(|v| {
419 v.as_object().and_then(|obj| {
420 obj.iter().find_map(|(key, value)| {
421 if key.starts_with("github.com") {
422 value["oauth_token"].as_str().map(|v| v.to_string())
423 } else {
424 None
425 }
426 })
427 })
428 })
429 .ok()
430 .flatten()
431}
432
433async fn stream_completion(
434 client: Arc<dyn HttpClient>,
435 api_key: String,
436 request: Request,
437) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
438 let request_builder = HttpRequest::builder()
439 .method(Method::POST)
440 .uri(COPILOT_CHAT_COMPLETION_URL)
441 .header(
442 "Editor-Version",
443 format!(
444 "Zed/{}",
445 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
446 ),
447 )
448 .header("Authorization", format!("Bearer {}", api_key))
449 .header("Content-Type", "application/json")
450 .header("Copilot-Integration-Id", "vscode-chat");
451
452 let is_streaming = request.stream;
453
454 let json = serde_json::to_string(&request)?;
455 let request = request_builder.body(AsyncBody::from(json))?;
456 let mut response = client.send(request).await?;
457
458 if !response.status().is_success() {
459 let mut body = Vec::new();
460 response.body_mut().read_to_end(&mut body).await?;
461 let body_str = std::str::from_utf8(&body)?;
462 return Err(anyhow!(
463 "Failed to connect to API: {} {}",
464 response.status(),
465 body_str
466 ));
467 }
468
469 if is_streaming {
470 let reader = BufReader::new(response.into_body());
471 Ok(reader
472 .lines()
473 .filter_map(|line| async move {
474 match line {
475 Ok(line) => {
476 let line = line.strip_prefix("data: ")?;
477 if line.starts_with("[DONE]") {
478 return None;
479 }
480
481 match serde_json::from_str::<ResponseEvent>(line) {
482 Ok(response) => {
483 if response.choices.is_empty() {
484 None
485 } else {
486 Some(Ok(response))
487 }
488 }
489 Err(error) => Some(Err(anyhow!(error))),
490 }
491 }
492 Err(error) => Some(Err(anyhow!(error))),
493 }
494 })
495 .boxed())
496 } else {
497 let mut body = Vec::new();
498 response.body_mut().read_to_end(&mut body).await?;
499 let body_str = std::str::from_utf8(&body)?;
500 let response: ResponseEvent = serde_json::from_str(body_str)?;
501
502 Ok(futures::stream::once(async move { Ok(response) }).boxed())
503 }
504}