1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::{Result, anyhow};
6use chrono::DateTime;
7use collections::HashSet;
8use fs::Fs;
9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
10use gpui::{App, AsyncApp, Global, prelude::*};
11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
12use paths::home_dir;
13use serde::{Deserialize, Serialize};
14use settings::watch_config_dir;
15use strum::EnumIter;
16
/// Endpoint that chat completion requests are POSTed to (see `stream_completion`).
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint that exchanges a GitHub OAuth token for an expiring Copilot API token
/// (see `request_api_token`).
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
19
/// Author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`) because of the
/// `rename_all` attribute below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
27
/// Chat models selectable through Copilot Chat. The default is [`Model::Gpt4o`].
///
/// Each variant's serde `rename` is the name written when the model is serialized
/// (for example into the JSON body of a [`Request`]). Its `alias` is an extra name
/// accepted when deserializing. These serde names are not always the same strings
/// that [`Model::id`] / [`Model::from_id`] use. For example, `Claude3_5Sonnet`
/// serializes as `"claude-3.5-sonnet"` but has id `"claude-3-5-sonnet"`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
    Gpt4o,
    #[serde(alias = "gpt-4", rename = "gpt-4")]
    Gpt4,
    #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
    Gpt3_5Turbo,
    #[serde(alias = "o1", rename = "o1")]
    O1,
    // NOTE(review): deserializing the legacy name "o1-mini" yields o3-mini, presumably
    // so previously saved settings keep working — confirm this mapping is intended.
    #[serde(alias = "o1-mini", rename = "o3-mini")]
    O3Mini,
    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
    Claude3_5Sonnet,
    #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
    Claude3_7Sonnet,
    #[serde(
        alias = "claude-3.7-sonnet-thought",
        rename = "claude-3.7-sonnet-thought"
    )]
    Claude3_7SonnetThinking,
    #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
    Gemini20Flash,
}
54
55impl Model {
56 pub fn uses_streaming(&self) -> bool {
57 match self {
58 Self::Gpt4o
59 | Self::Gpt4
60 | Self::Gpt3_5Turbo
61 | Self::Claude3_5Sonnet
62 | Self::Claude3_7Sonnet
63 | Self::Claude3_7SonnetThinking => true,
64 Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
65 }
66 }
67
68 pub fn from_id(id: &str) -> Result<Self> {
69 match id {
70 "gpt-4o" => Ok(Self::Gpt4o),
71 "gpt-4" => Ok(Self::Gpt4),
72 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
73 "o1" => Ok(Self::O1),
74 "o3-mini" => Ok(Self::O3Mini),
75 "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
76 "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
77 "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
78 "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
79 _ => Err(anyhow!("Invalid model id: {}", id)),
80 }
81 }
82
83 pub fn id(&self) -> &'static str {
84 match self {
85 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
86 Self::Gpt4 => "gpt-4",
87 Self::Gpt4o => "gpt-4o",
88 Self::O3Mini => "o3-mini",
89 Self::O1 => "o1",
90 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
91 Self::Claude3_7Sonnet => "claude-3-7-sonnet",
92 Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
93 Self::Gemini20Flash => "gemini-2.0-flash-001",
94 }
95 }
96
97 pub fn display_name(&self) -> &'static str {
98 match self {
99 Self::Gpt3_5Turbo => "GPT-3.5",
100 Self::Gpt4 => "GPT-4",
101 Self::Gpt4o => "GPT-4o",
102 Self::O3Mini => "o3-mini",
103 Self::O1 => "o1",
104 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
105 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
106 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
107 Self::Gemini20Flash => "Gemini 2.0 Flash",
108 }
109 }
110
111 pub fn max_token_count(&self) -> usize {
112 match self {
113 Self::Gpt4o => 64_000,
114 Self::Gpt4 => 32_768,
115 Self::Gpt3_5Turbo => 12_288,
116 Self::O3Mini => 64_000,
117 Self::O1 => 20_000,
118 Self::Claude3_5Sonnet => 200_000,
119 Self::Claude3_7Sonnet => 90_000,
120 Self::Claude3_7SonnetThinking => 90_000,
121 Model::Gemini20Flash => 128_000,
122 }
123 }
124}
125
126#[derive(Serialize, Deserialize)]
127pub struct Request {
128 pub intent: bool,
129 pub n: usize,
130 pub stream: bool,
131 pub temperature: f32,
132 pub model: Model,
133 pub messages: Vec<ChatMessage>,
134}
135
136impl Request {
137 pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
138 Self {
139 intent: true,
140 n: 1,
141 stream: model.uses_streaming(),
142 temperature: 0.1,
143 model,
144 messages,
145 }
146 }
147}
148
/// One message of the conversation carried in [`Request::messages`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: Role,
    /// Message text.
    pub content: String,
}
154
/// One parsed response from the completions endpoint.
///
/// When streaming, this is a single `data:` line; otherwise it is the whole
/// response body (see `stream_completion`).
// NOTE(review): `tag`/`rename_all` are unusual on a struct. Serde documents a struct
// `tag` as adding the struct name under that key when *serializing*, but this type
// only derives `Deserialize`. The field names are also already snake_case. Confirm
// whether these attributes are still needed.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub created: u64,
    pub id: String,
}
162
/// One entry of [`ResponseEvent::choices`].
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    /// Finish reason reported by the API, if any; `stream_completion` inspects it
    /// when filtering streamed chunks.
    pub finish_reason: Option<String>,
    // NOTE(review): presumably `delta` is populated for streamed chunks and `message`
    // for complete (non-streamed) responses — confirm against the Copilot API.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}
170
/// Content and/or role carried by a [`ResponseChoice`]; both fields are optional.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
}
176
/// JSON body returned by [`COPILOT_CHAT_AUTH_URL`] on success.
#[derive(Deserialize)]
struct ApiTokenResponse {
    /// The Copilot API token itself.
    token: String,
    /// Expiry, interpreted as Unix seconds by the conversion into [`ApiToken`].
    expires_at: i64,
}
182
/// A Copilot API token together with its expiry time.
#[derive(Clone)]
struct ApiToken {
    /// Sent as the `Bearer` credential of completion requests.
    api_key: String,
    /// Instant after which the token is no longer valid.
    expires_at: DateTime<chrono::Utc>,
}
188
189impl ApiToken {
190 pub fn remaining_seconds(&self) -> i64 {
191 self.expires_at
192 .timestamp()
193 .saturating_sub(chrono::Utc::now().timestamp())
194 }
195}
196
197impl TryFrom<ApiTokenResponse> for ApiToken {
198 type Error = anyhow::Error;
199
200 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
201 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
202 .ok_or_else(|| anyhow!("invalid expires_at"))?;
203
204 Ok(Self {
205 api_key: response.token,
206 expires_at,
207 })
208 }
209}
210
/// gpui global holding the shared [`CopilotChat`] entity.
///
/// Installed by [`init`] and read back through [`CopilotChat::global`].
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
214
/// Copilot Chat client state.
pub struct CopilotChat {
    /// GitHub OAuth token extracted from the Copilot config files; `None` until
    /// one is found.
    oauth_token: Option<String>,
    /// Cached Copilot API token obtained with `oauth_token`. `stream_completion`
    /// refreshes it when it is missing or close to expiry.
    api_token: Option<ApiToken>,
    /// HTTP client used for both the token exchange and completion requests.
    client: Arc<dyn HttpClient>,
}
220
221pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
222 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
223 cx.set_global(GlobalCopilotChat(copilot_chat));
224}
225
226pub fn copilot_chat_config_dir() -> &'static PathBuf {
227 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
228
229 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
230 if cfg!(target_os = "windows") {
231 home_dir().join("AppData").join("Local")
232 } else {
233 home_dir().join(".config")
234 }
235 .join("github-copilot")
236 })
237}
238
239fn copilot_chat_config_paths() -> [PathBuf; 2] {
240 let base_dir = copilot_chat_config_dir();
241 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
242}
243
244impl CopilotChat {
245 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
246 cx.try_global::<GlobalCopilotChat>()
247 .map(|model| model.0.clone())
248 }
249
250 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
251 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
252 let dir_path = copilot_chat_config_dir();
253
254 cx.spawn(async move |cx| {
255 let mut parent_watch_rx = watch_config_dir(
256 cx.background_executor(),
257 fs.clone(),
258 dir_path.clone(),
259 config_paths,
260 );
261 while let Some(contents) = parent_watch_rx.next().await {
262 let oauth_token = extract_oauth_token(contents);
263 cx.update(|cx| {
264 if let Some(this) = Self::global(cx).as_ref() {
265 this.update(cx, |this, cx| {
266 this.oauth_token = oauth_token;
267 cx.notify();
268 });
269 }
270 })?;
271 }
272 anyhow::Ok(())
273 })
274 .detach_and_log_err(cx);
275
276 Self {
277 oauth_token: None,
278 api_token: None,
279 client,
280 }
281 }
282
283 pub fn is_authenticated(&self) -> bool {
284 self.oauth_token.is_some()
285 }
286
287 pub async fn stream_completion(
288 request: Request,
289 mut cx: AsyncApp,
290 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
291 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
292 return Err(anyhow!("Copilot chat is not enabled"));
293 };
294
295 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
296 (
297 this.oauth_token.clone(),
298 this.api_token.clone(),
299 this.client.clone(),
300 )
301 })?;
302
303 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
304
305 let token = match api_token {
306 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
307 _ => {
308 let token = request_api_token(&oauth_token, client.clone()).await?;
309 this.update(&mut cx, |this, cx| {
310 this.api_token = Some(token.clone());
311 cx.notify();
312 })?;
313 token
314 }
315 };
316
317 stream_completion(client.clone(), token.api_key, request).await
318 }
319}
320
321async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
322 let request_builder = HttpRequest::builder()
323 .method(Method::GET)
324 .uri(COPILOT_CHAT_AUTH_URL)
325 .header("Authorization", format!("token {}", oauth_token))
326 .header("Accept", "application/json");
327
328 let request = request_builder.body(AsyncBody::empty())?;
329
330 let mut response = client.send(request).await?;
331
332 if response.status().is_success() {
333 let mut body = Vec::new();
334 response.body_mut().read_to_end(&mut body).await?;
335
336 let body_str = std::str::from_utf8(&body)?;
337
338 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
339 ApiToken::try_from(parsed)
340 } else {
341 let mut body = Vec::new();
342 response.body_mut().read_to_end(&mut body).await?;
343
344 let body_str = std::str::from_utf8(&body)?;
345
346 Err(anyhow!("Failed to request API token: {}", body_str))
347 }
348}
349
350fn extract_oauth_token(contents: String) -> Option<String> {
351 serde_json::from_str::<serde_json::Value>(&contents)
352 .map(|v| {
353 v.as_object().and_then(|obj| {
354 obj.iter().find_map(|(key, value)| {
355 if key.starts_with("github.com") {
356 value["oauth_token"].as_str().map(|v| v.to_string())
357 } else {
358 None
359 }
360 })
361 })
362 })
363 .ok()
364 .flatten()
365}
366
367async fn stream_completion(
368 client: Arc<dyn HttpClient>,
369 api_key: String,
370 request: Request,
371) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
372 let request_builder = HttpRequest::builder()
373 .method(Method::POST)
374 .uri(COPILOT_CHAT_COMPLETION_URL)
375 .header(
376 "Editor-Version",
377 format!(
378 "Zed/{}",
379 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
380 ),
381 )
382 .header("Authorization", format!("Bearer {}", api_key))
383 .header("Content-Type", "application/json")
384 .header("Copilot-Integration-Id", "vscode-chat");
385
386 let is_streaming = request.stream;
387
388 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
389 let mut response = client.send(request).await?;
390
391 if !response.status().is_success() {
392 let mut body = Vec::new();
393 response.body_mut().read_to_end(&mut body).await?;
394 let body_str = std::str::from_utf8(&body)?;
395 return Err(anyhow!(
396 "Failed to connect to API: {} {}",
397 response.status(),
398 body_str
399 ));
400 }
401
402 if is_streaming {
403 let reader = BufReader::new(response.into_body());
404 Ok(reader
405 .lines()
406 .filter_map(|line| async move {
407 match line {
408 Ok(line) => {
409 let line = line.strip_prefix("data: ")?;
410 if line.starts_with("[DONE]") {
411 return None;
412 }
413
414 match serde_json::from_str::<ResponseEvent>(line) {
415 Ok(response) => {
416 if response.choices.is_empty()
417 || response.choices.first().unwrap().finish_reason.is_some()
418 {
419 None
420 } else {
421 Some(Ok(response))
422 }
423 }
424 Err(error) => Some(Err(anyhow!(error))),
425 }
426 }
427 Err(error) => Some(Err(anyhow!(error))),
428 }
429 })
430 .boxed())
431 } else {
432 let mut body = Vec::new();
433 response.body_mut().read_to_end(&mut body).await?;
434 let body_str = std::str::from_utf8(&body)?;
435 let response: ResponseEvent = serde_json::from_str(body_str)?;
436
437 Ok(futures::stream::once(async move { Ok(response) }).boxed())
438 }
439}