1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
/// Endpoint that chat completion requests are `POST`ed to (see `stream_completion`).
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint queried with `GET` to exchange a GitHub OAuth token for a short-lived
/// Copilot API token (see `request_api_token`).
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 User,
24 Assistant,
25 System,
26}
27
// Chat models selectable through GitHub Copilot.
//
// Each variant's serde `rename` is the name written when serializing (e.g. into a
// `Request`); `alias` is an additional name accepted when deserializing. These names
// are NOT always the same as `Model::id()` (e.g. `Gpt4o` serializes as
// "gpt-4o-2024-05-13" but its id is "gpt-4o").
// Plain `//` comments are used here so the optional schemars schema is unaffected.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
    Gpt4o,
    #[serde(alias = "gpt-4", rename = "gpt-4")]
    Gpt4,
    #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
    Gpt3_5Turbo,
    #[serde(alias = "o1", rename = "o1")]
    O1,
    // NOTE(review): "o1-mini" is accepted as an alias and maps to o3-mini, presumably so
    // settings saved before o1-mini was replaced keep deserializing — confirm intent.
    #[serde(alias = "o1-mini", rename = "o3-mini")]
    O3Mini,
    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
    Claude3_5Sonnet,
    #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
    Claude3_7Sonnet,
    #[serde(
        alias = "claude-3.7-sonnet-thought",
        rename = "claude-3.7-sonnet-thought"
    )]
    Claude3_7SonnetThinking,
    #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
    Gemini20Flash,
}
54
55impl Model {
56 pub fn uses_streaming(&self) -> bool {
57 match self {
58 Self::Gpt4o
59 | Self::Gpt4
60 | Self::Gpt3_5Turbo
61 | Self::Claude3_5Sonnet
62 | Self::Claude3_7Sonnet
63 | Self::Claude3_7SonnetThinking => true,
64 Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
65 }
66 }
67
68 pub fn from_id(id: &str) -> Result<Self> {
69 match id {
70 "gpt-4o" => Ok(Self::Gpt4o),
71 "gpt-4" => Ok(Self::Gpt4),
72 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
73 "o1" => Ok(Self::O1),
74 "o3-mini" => Ok(Self::O3Mini),
75 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
76 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
77 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
78 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
79 _ => Err(anyhow!("Invalid model id: {}", id)),
80 }
81 }
82
83 pub fn id(&self) -> &'static str {
84 match self {
85 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
86 Self::Gpt4 => "gpt-4",
87 Self::Gpt4o => "gpt-4o",
88 Self::O3Mini => "o3-mini",
89 Self::O1 => "o1",
90 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
91 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
92 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
93 Self::Gemini20Flash => "gemini-2.0-flash-001",
94 }
95 }
96
97 pub fn display_name(&self) -> &'static str {
98 match self {
99 Self::Gpt3_5Turbo => "GPT-3.5",
100 Self::Gpt4 => "GPT-4",
101 Self::Gpt4o => "GPT-4o",
102 Self::O3Mini => "o3-mini",
103 Self::O1 => "o1",
104 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
105 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
106 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
107 Self::Gemini20Flash => "Gemini 2.0 Flash",
108 }
109 }
110
111 pub fn max_token_count(&self) -> usize {
112 match self {
113 Self::Gpt4o => 64_000,
114 Self::Gpt4 => 32_768,
115 Self::Gpt3_5Turbo => 12_288,
116 Self::O3Mini => 64_000,
117 Self::O1 => 20_000,
118 Self::Claude3_5Sonnet => 200_000,
119 Self::Claude3_7Sonnet => 90_000,
120 Self::Claude3_7SonnetThinking => 90_000,
121 Model::Gemini20Flash => 128_000,
122 }
123 }
124}
125
/// JSON body of a chat completion request sent to [`COPILOT_CHAT_COMPLETION_URL`].
#[derive(Serialize, Deserialize)]
pub struct Request {
    // NOTE(review): the meaning of `intent` and `n` is defined by the Copilot API, not here.
    pub intent: bool,
    pub n: usize,
    /// When `true` the response body is read as a server-sent-event stream, otherwise as
    /// a single JSON document (see `stream_completion`).
    pub stream: bool,
    pub temperature: f32,
    /// Serialized using the model's serde name (e.g. `"gpt-4o-2024-05-13"`), not `Model::id`.
    pub model: Model,
    pub messages: Vec<ChatMessage>,
    /// Tools offered to the model; the key is omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// The key is omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
139
/// A function the model may call, offered through [`Tool::Function`].
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // Presumably a JSON Schema object describing the arguments — confirm with callers.
    pub parameters: serde_json::Value,
}
146
/// A tool offered to the model, serialized as
/// `{"type": "function", "function": { ... }}`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}
152
/// How the model should choose among the offered tools.
///
/// Serialized internally tagged by `type`: `{"type": "auto"}`, `{"type": "any"}`, or
/// `{"type": "tool", "name": "..."}`.
///
/// NOTE(review): this is the Anthropic-style shape; OpenAI-compatible endpoints usually
/// expect `"auto"` / `"required"` or `{"type": "function", "function": {"name": ...}}`.
/// Verify against what the Copilot endpoint accepts before relying on it.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
160
/// A message in a chat completion request, internally tagged by `role`
/// (e.g. `{"role": "user", "content": "..."}`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// A prior assistant turn; `content` serializes as `null` when `None`.
    Assistant {
        content: Option<String>,
        /// Omitted from the JSON when there are no tool calls.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// The output of a tool call; `tool_call_id` presumably echoes the [`ToolCall::id`]
    /// being answered — confirm.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
180
/// A tool call made by the assistant. Because `content` is flattened, this serializes as
/// `{"id": "...", "type": "function", "function": {"name": ..., "arguments": ...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
187
/// The payload of a [`ToolCall`], tagged by `type` (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
193
/// Name and arguments of a function the assistant called.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Kept as raw text; presumably the JSON-encoded arguments produced by the model — confirm.
    pub arguments: String,
}
199
200#[derive(Deserialize, Debug)]
201#[serde(tag = "type", rename_all = "snake_case")]
202pub struct ResponseEvent {
203 pub choices: Vec<ResponseChoice>,
204 pub created: u64,
205 pub id: String,
206}
207
/// One choice inside a [`ResponseEvent`].
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    // Presumably `delta` is populated in streamed chunks and `message` in non-streaming
    // responses; both share the `ResponseDelta` shape — confirm against the API.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}
215
/// Message content (or a fragment of it, when streaming) returned by the model.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    /// Empty when the key is absent from the payload.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}
223
/// A possibly partial tool call in a response.
///
/// All fields except `index` are optional. Presumably `index` identifies which tool call
/// a streamed fragment belongs to, so fragments with the same index must be merged — confirm.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
230
/// A possibly partial function name/arguments fragment inside a [`ToolCallChunk`].
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
236
/// Raw JSON body returned by [`COPILOT_CHAT_AUTH_URL`]; converted into an [`ApiToken`].
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp in seconds (it is passed to `DateTime::from_timestamp(_, 0)`).
    expires_at: i64,
}
242
243#[derive(Clone)]
244struct ApiToken {
245 api_key: String,
246 expires_at: DateTime<chrono::Utc>,
247}
248
249impl ApiToken {
250 pub fn remaining_seconds(&self) -> i64 {
251 self.expires_at
252 .timestamp()
253 .saturating_sub(chrono::Utc::now().timestamp())
254 }
255}
256
257impl TryFrom<ApiTokenResponse> for ApiToken {
258 type Error = anyhow::Error;
259
260 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
261 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
262 .ok_or_else(|| anyhow!("invalid expires_at"))?;
263
264 Ok(Self {
265 api_key: response.token,
266 expires_at,
267 })
268 }
269}
270
/// Global slot holding the app-wide [`CopilotChat`] entity; set by [`init`] and read by
/// [`CopilotChat::global`].
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
274
/// Client state for GitHub Copilot Chat.
pub struct CopilotChat {
    /// GitHub OAuth token read from the Copilot config files; `None` when none was found.
    oauth_token: Option<String>,
    /// Cached short-lived API token obtained with `oauth_token`; refreshed on demand.
    api_token: Option<ApiToken>,
    /// HTTP client used for both the token exchange and completion requests.
    client: Arc<dyn HttpClient>,
}
280
281pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
282 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
283 cx.set_global(GlobalCopilotChat(copilot_chat));
284}
285
286pub fn copilot_chat_config_dir() -> &'static PathBuf {
287 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
288
289 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
290 if cfg!(target_os = "windows") {
291 home_dir().join("AppData").join("Local")
292 } else {
293 home_dir().join(".config")
294 }
295 .join("github-copilot")
296 })
297}
298
299fn copilot_chat_config_paths() -> [PathBuf; 2] {
300 let base_dir = copilot_chat_config_dir();
301 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
302}
303
304impl CopilotChat {
305 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
306 cx.try_global::<GlobalCopilotChat>()
307 .map(|model| model.0.clone())
308 }
309
310 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
311 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
312 let dir_path = copilot_chat_config_dir();
313
314 cx.spawn(async move |cx| {
315 let mut parent_watch_rx = watch_config_dir(
316 cx.background_executor(),
317 fs.clone(),
318 dir_path.clone(),
319 config_paths,
320 );
321 while let Some(contents) = parent_watch_rx.next().await {
322 let oauth_token = extract_oauth_token(contents);
323 cx.update(|cx| {
324 if let Some(this) = Self::global(cx).as_ref() {
325 this.update(cx, |this, cx| {
326 this.oauth_token = oauth_token;
327 cx.notify();
328 });
329 }
330 })?;
331 }
332 anyhow::Ok(())
333 })
334 .detach_and_log_err(cx);
335
336 Self {
337 oauth_token: None,
338 api_token: None,
339 client,
340 }
341 }
342
343 pub fn is_authenticated(&self) -> bool {
344 self.oauth_token.is_some()
345 }
346
347 pub async fn stream_completion(
348 request: Request,
349 mut cx: AsyncApp,
350 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
351 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
352 return Err(anyhow!("Copilot chat is not enabled"));
353 };
354
355 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
356 (
357 this.oauth_token.clone(),
358 this.api_token.clone(),
359 this.client.clone(),
360 )
361 })?;
362
363 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
364
365 let token = match api_token {
366 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
367 _ => {
368 let token = request_api_token(&oauth_token, client.clone()).await?;
369 this.update(&mut cx, |this, cx| {
370 this.api_token = Some(token.clone());
371 cx.notify();
372 })?;
373 token
374 }
375 };
376
377 stream_completion(client.clone(), token.api_key, request).await
378 }
379}
380
381async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
382 let request_builder = HttpRequest::builder()
383 .method(Method::GET)
384 .uri(COPILOT_CHAT_AUTH_URL)
385 .header("Authorization", format!("token {}", oauth_token))
386 .header("Accept", "application/json");
387
388 let request = request_builder.body(AsyncBody::empty())?;
389
390 let mut response = client.send(request).await?;
391
392 if response.status().is_success() {
393 let mut body = Vec::new();
394 response.body_mut().read_to_end(&mut body).await?;
395
396 let body_str = std::str::from_utf8(&body)?;
397
398 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
399 ApiToken::try_from(parsed)
400 } else {
401 let mut body = Vec::new();
402 response.body_mut().read_to_end(&mut body).await?;
403
404 let body_str = std::str::from_utf8(&body)?;
405
406 Err(anyhow!("Failed to request API token: {}", body_str))
407 }
408}
409
410fn extract_oauth_token(contents: String) -> Option<String> {
411 serde_json::from_str::<serde_json::Value>(&contents)
412 .map(|v| {
413 v.as_object().and_then(|obj| {
414 obj.iter().find_map(|(key, value)| {
415 if key.starts_with("github.com") {
416 value["oauth_token"].as_str().map(|v| v.to_string())
417 } else {
418 None
419 }
420 })
421 })
422 })
423 .ok()
424 .flatten()
425}
426
427async fn stream_completion(
428 client: Arc<dyn HttpClient>,
429 api_key: String,
430 request: Request,
431) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
432 let request_builder = HttpRequest::builder()
433 .method(Method::POST)
434 .uri(COPILOT_CHAT_COMPLETION_URL)
435 .header(
436 "Editor-Version",
437 format!(
438 "Zed/{}",
439 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
440 ),
441 )
442 .header("Authorization", format!("Bearer {}", api_key))
443 .header("Content-Type", "application/json")
444 .header("Copilot-Integration-Id", "vscode-chat");
445
446 let is_streaming = request.stream;
447
448 let json = serde_json::to_string(&request)?;
449 let request = request_builder.body(AsyncBody::from(json))?;
450 let mut response = client.send(request).await?;
451
452 if !response.status().is_success() {
453 let mut body = Vec::new();
454 response.body_mut().read_to_end(&mut body).await?;
455 let body_str = std::str::from_utf8(&body)?;
456 return Err(anyhow!(
457 "Failed to connect to API: {} {}",
458 response.status(),
459 body_str
460 ));
461 }
462
463 if is_streaming {
464 let reader = BufReader::new(response.into_body());
465 Ok(reader
466 .lines()
467 .filter_map(|line| async move {
468 match line {
469 Ok(line) => {
470 let line = line.strip_prefix("data: ")?;
471 if line.starts_with("[DONE]") {
472 return None;
473 }
474
475 match serde_json::from_str::<ResponseEvent>(line) {
476 Ok(response) => {
477 if response.choices.is_empty() {
478 None
479 } else {
480 Some(Ok(response))
481 }
482 }
483 Err(error) => Some(Err(anyhow!(error))),
484 }
485 }
486 Err(error) => Some(Err(anyhow!(error))),
487 }
488 })
489 .boxed())
490 } else {
491 let mut body = Vec::new();
492 response.body_mut().read_to_end(&mut body).await?;
493 let body_str = std::str::from_utf8(&body)?;
494 let response: ResponseEvent = serde_json::from_str(body_str)?;
495
496 Ok(futures::stream::once(async move { Ok(response) }).boxed())
497 }
498}