1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
/// Endpoint that chat completion requests are POSTed to.
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint queried with the OAuth token to obtain a short-lived Copilot API token.
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26}
27
28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
30pub enum Model {
31 #[default]
32 #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
33 Gpt4o,
34 #[serde(alias = "gpt-4", rename = "gpt-4")]
35 Gpt4,
36 #[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
37 Gpt4_1,
38 #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
39 Gpt3_5Turbo,
40 #[serde(alias = "o1", rename = "o1")]
41 O1,
42 #[serde(alias = "o1-mini", rename = "o3-mini")]
43 O3Mini,
44 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
45 Claude3_5Sonnet,
46 #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
47 Claude3_7Sonnet,
48 #[serde(
49 alias = "claude-3.7-sonnet-thought",
50 rename = "claude-3.7-sonnet-thought"
51 )]
52 Claude3_7SonnetThinking,
53 #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
54 Gemini20Flash,
55}
56
57impl Model {
58 pub fn uses_streaming(&self) -> bool {
59 match self {
60 Self::Gpt4o
61 | Self::Gpt4
62 | Self::Gpt4_1
63 | Self::Gpt3_5Turbo
64 | Self::Claude3_5Sonnet
65 | Self::Claude3_7Sonnet
66 | Self::Claude3_7SonnetThinking => true,
67 Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
68 }
69 }
70
71 pub fn from_id(id: &str) -> Result<Self> {
72 match id {
73 "gpt-4o" => Ok(Self::Gpt4o),
74 "gpt-4" => Ok(Self::Gpt4),
75 "gpt-4.1" => Ok(Self::Gpt4_1),
76 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
77 "o1" => Ok(Self::O1),
78 "o3-mini" => Ok(Self::O3Mini),
79 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
80 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
81 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
82 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
83 _ => Err(anyhow!("Invalid model id: {}", id)),
84 }
85 }
86
87 pub fn id(&self) -> &'static str {
88 match self {
89 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
90 Self::Gpt4 => "gpt-4",
91 Self::Gpt4_1 => "gpt-4.1",
92 Self::Gpt4o => "gpt-4o",
93 Self::O3Mini => "o3-mini",
94 Self::O1 => "o1",
95 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
96 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
97 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
98 Self::Gemini20Flash => "gemini-2.0-flash-001",
99 }
100 }
101
102 pub fn display_name(&self) -> &'static str {
103 match self {
104 Self::Gpt3_5Turbo => "GPT-3.5",
105 Self::Gpt4 => "GPT-4",
106 Self::Gpt4_1 => "GPT-4.1",
107 Self::Gpt4o => "GPT-4o",
108 Self::O3Mini => "o3-mini",
109 Self::O1 => "o1",
110 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
111 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
112 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
113 Self::Gemini20Flash => "Gemini 2.0 Flash",
114 }
115 }
116
117 pub fn max_token_count(&self) -> usize {
118 match self {
119 Self::Gpt4o => 64_000,
120 Self::Gpt4 => 32_768,
121 Self::Gpt4_1 => 1_047_576,
122 Self::Gpt3_5Turbo => 12_288,
123 Self::O3Mini => 64_000,
124 Self::O1 => 20_000,
125 Self::Claude3_5Sonnet => 200_000,
126 Self::Claude3_7Sonnet => 90_000,
127 Self::Claude3_7SonnetThinking => 90_000,
128 Model::Gemini20Flash => 128_000,
129 }
130 }
131}
132
133#[derive(Serialize, Deserialize)]
134pub struct Request {
135 pub intent: bool,
136 pub n: usize,
137 pub stream: bool,
138 pub temperature: f32,
139 pub model: Model,
140 pub messages: Vec<ChatMessage>,
141 #[serde(default, skip_serializing_if = "Vec::is_empty")]
142 pub tools: Vec<Tool>,
143 #[serde(default, skip_serializing_if = "Option::is_none")]
144 pub tool_choice: Option<ToolChoice>,
145}
146
147#[derive(Serialize, Deserialize)]
148pub struct Function {
149 pub name: String,
150 pub description: String,
151 pub parameters: serde_json::Value,
152}
153
154#[derive(Serialize, Deserialize)]
155#[serde(tag = "type", rename_all = "snake_case")]
156pub enum Tool {
157 Function { function: Function },
158}
159
160#[derive(Serialize, Deserialize)]
161#[serde(tag = "type", rename_all = "lowercase")]
162pub enum ToolChoice {
163 Auto,
164 Any,
165 Tool { name: String },
166}
167
168#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
169#[serde(tag = "role", rename_all = "lowercase")]
170pub enum ChatMessage {
171 Assistant {
172 content: Option<String>,
173 #[serde(default, skip_serializing_if = "Vec::is_empty")]
174 tool_calls: Vec<ToolCall>,
175 },
176 User {
177 content: String,
178 },
179 System {
180 content: String,
181 },
182 Tool {
183 content: String,
184 tool_call_id: String,
185 },
186}
187
188#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
189pub struct ToolCall {
190 pub id: String,
191 #[serde(flatten)]
192 pub content: ToolCallContent,
193}
194
195#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
196#[serde(tag = "type", rename_all = "lowercase")]
197pub enum ToolCallContent {
198 Function { function: FunctionContent },
199}
200
201#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
202pub struct FunctionContent {
203 pub name: String,
204 pub arguments: String,
205}
206
207#[derive(Deserialize, Debug)]
208#[serde(tag = "type", rename_all = "snake_case")]
209pub struct ResponseEvent {
210 pub choices: Vec<ResponseChoice>,
211 pub created: u64,
212 pub id: String,
213}
214
215#[derive(Debug, Deserialize)]
216pub struct ResponseChoice {
217 pub index: usize,
218 pub finish_reason: Option<String>,
219 pub delta: Option<ResponseDelta>,
220 pub message: Option<ResponseDelta>,
221}
222
223#[derive(Debug, Deserialize)]
224pub struct ResponseDelta {
225 pub content: Option<String>,
226 pub role: Option<Role>,
227 #[serde(default)]
228 pub tool_calls: Vec<ToolCallChunk>,
229}
230
231#[derive(Deserialize, Debug, Eq, PartialEq)]
232pub struct ToolCallChunk {
233 pub index: usize,
234 pub id: Option<String>,
235 pub function: Option<FunctionChunk>,
236}
237
238#[derive(Deserialize, Debug, Eq, PartialEq)]
239pub struct FunctionChunk {
240 pub name: Option<String>,
241 pub arguments: Option<String>,
242}
243
244#[derive(Deserialize)]
245struct ApiTokenResponse {
246 token: String,
247 expires_at: i64,
248}
249
250#[derive(Clone)]
251struct ApiToken {
252 api_key: String,
253 expires_at: DateTime<chrono::Utc>,
254}
255
256impl ApiToken {
257 pub fn remaining_seconds(&self) -> i64 {
258 self.expires_at
259 .timestamp()
260 .saturating_sub(chrono::Utc::now().timestamp())
261 }
262}
263
264impl TryFrom<ApiTokenResponse> for ApiToken {
265 type Error = anyhow::Error;
266
267 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
268 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
269 .ok_or_else(|| anyhow!("invalid expires_at"))?;
270
271 Ok(Self {
272 api_key: response.token,
273 expires_at,
274 })
275 }
276}
277
278struct GlobalCopilotChat(gpui::Entity<CopilotChat>);
279
280impl Global for GlobalCopilotChat {}
281
282pub struct CopilotChat {
283 oauth_token: Option<String>,
284 api_token: Option<ApiToken>,
285 client: Arc<dyn HttpClient>,
286}
287
288pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
289 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
290 cx.set_global(GlobalCopilotChat(copilot_chat));
291}
292
293pub fn copilot_chat_config_dir() -> &'static PathBuf {
294 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
295
296 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
297 if cfg!(target_os = "windows") {
298 home_dir().join("AppData").join("Local")
299 } else {
300 home_dir().join(".config")
301 }
302 .join("github-copilot")
303 })
304}
305
306fn copilot_chat_config_paths() -> [PathBuf; 2] {
307 let base_dir = copilot_chat_config_dir();
308 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
309}
310
311impl CopilotChat {
312 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
313 cx.try_global::<GlobalCopilotChat>()
314 .map(|model| model.0.clone())
315 }
316
317 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
318 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
319 let dir_path = copilot_chat_config_dir();
320
321 cx.spawn(async move |cx| {
322 let mut parent_watch_rx = watch_config_dir(
323 cx.background_executor(),
324 fs.clone(),
325 dir_path.clone(),
326 config_paths,
327 );
328 while let Some(contents) = parent_watch_rx.next().await {
329 let oauth_token = extract_oauth_token(contents);
330 cx.update(|cx| {
331 if let Some(this) = Self::global(cx).as_ref() {
332 this.update(cx, |this, cx| {
333 this.oauth_token = oauth_token;
334 cx.notify();
335 });
336 }
337 })?;
338 }
339 anyhow::Ok(())
340 })
341 .detach_and_log_err(cx);
342
343 Self {
344 oauth_token: None,
345 api_token: None,
346 client,
347 }
348 }
349
350 pub fn is_authenticated(&self) -> bool {
351 self.oauth_token.is_some()
352 }
353
354 pub async fn stream_completion(
355 request: Request,
356 mut cx: AsyncApp,
357 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
358 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
359 return Err(anyhow!("Copilot chat is not enabled"));
360 };
361
362 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
363 (
364 this.oauth_token.clone(),
365 this.api_token.clone(),
366 this.client.clone(),
367 )
368 })?;
369
370 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
371
372 let token = match api_token {
373 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
374 _ => {
375 let token = request_api_token(&oauth_token, client.clone()).await?;
376 this.update(&mut cx, |this, cx| {
377 this.api_token = Some(token.clone());
378 cx.notify();
379 })?;
380 token
381 }
382 };
383
384 stream_completion(client.clone(), token.api_key, request).await
385 }
386}
387
388async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
389 let request_builder = HttpRequest::builder()
390 .method(Method::GET)
391 .uri(COPILOT_CHAT_AUTH_URL)
392 .header("Authorization", format!("token {}", oauth_token))
393 .header("Accept", "application/json");
394
395 let request = request_builder.body(AsyncBody::empty())?;
396
397 let mut response = client.send(request).await?;
398
399 if response.status().is_success() {
400 let mut body = Vec::new();
401 response.body_mut().read_to_end(&mut body).await?;
402
403 let body_str = std::str::from_utf8(&body)?;
404
405 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
406 ApiToken::try_from(parsed)
407 } else {
408 let mut body = Vec::new();
409 response.body_mut().read_to_end(&mut body).await?;
410
411 let body_str = std::str::from_utf8(&body)?;
412
413 Err(anyhow!("Failed to request API token: {}", body_str))
414 }
415}
416
417fn extract_oauth_token(contents: String) -> Option<String> {
418 serde_json::from_str::<serde_json::Value>(&contents)
419 .map(|v| {
420 v.as_object().and_then(|obj| {
421 obj.iter().find_map(|(key, value)| {
422 if key.starts_with("github.com") {
423 value["oauth_token"].as_str().map(|v| v.to_string())
424 } else {
425 None
426 }
427 })
428 })
429 })
430 .ok()
431 .flatten()
432}
433
434async fn stream_completion(
435 client: Arc<dyn HttpClient>,
436 api_key: String,
437 request: Request,
438) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
439 let request_builder = HttpRequest::builder()
440 .method(Method::POST)
441 .uri(COPILOT_CHAT_COMPLETION_URL)
442 .header(
443 "Editor-Version",
444 format!(
445 "Zed/{}",
446 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
447 ),
448 )
449 .header("Authorization", format!("Bearer {}", api_key))
450 .header("Content-Type", "application/json")
451 .header("Copilot-Integration-Id", "vscode-chat");
452
453 let is_streaming = request.stream;
454
455 let json = serde_json::to_string(&request)?;
456 let request = request_builder.body(AsyncBody::from(json))?;
457 let mut response = client.send(request).await?;
458
459 if !response.status().is_success() {
460 let mut body = Vec::new();
461 response.body_mut().read_to_end(&mut body).await?;
462 let body_str = std::str::from_utf8(&body)?;
463 return Err(anyhow!(
464 "Failed to connect to API: {} {}",
465 response.status(),
466 body_str
467 ));
468 }
469
470 if is_streaming {
471 let reader = BufReader::new(response.into_body());
472 Ok(reader
473 .lines()
474 .filter_map(|line| async move {
475 match line {
476 Ok(line) => {
477 let line = line.strip_prefix("data: ")?;
478 if line.starts_with("[DONE]") {
479 return None;
480 }
481
482 match serde_json::from_str::<ResponseEvent>(line) {
483 Ok(response) => {
484 if response.choices.is_empty() {
485 None
486 } else {
487 Some(Ok(response))
488 }
489 }
490 Err(error) => Some(Err(anyhow!(error))),
491 }
492 }
493 Err(error) => Some(Err(anyhow!(error))),
494 }
495 })
496 .boxed())
497 } else {
498 let mut body = Vec::new();
499 response.body_mut().read_to_end(&mut body).await?;
500 let body_str = std::str::from_utf8(&body)?;
501 let response: ResponseEvent = serde_json::from_str(body_str)?;
502
503 Ok(futures::stream::once(async move { Ok(response) }).boxed())
504 }
505}