copilot_chat.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3use std::sync::OnceLock;
  4
  5use anyhow::{Result, anyhow};
  6use chrono::DateTime;
  7use collections::HashSet;
  8use fs::Fs;
  9use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 10use gpui::{App, AsyncApp, Global, prelude::*};
 11use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 12use paths::home_dir;
 13use serde::{Deserialize, Serialize};
 14use settings::watch_config_dir;
 15use strum::EnumIter;
 16
/// Endpoint that chat completion requests are POSTed to (see `stream_completion`).
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint that exchanges a GitHub OAuth token for a Copilot API token
/// (see `request_api_token`).
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
 19
/// Author role attached to a response delta; (de)serialized in lowercase
/// (`"user"`, `"assistant"`, `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 27
// Chat models selectable through Copilot.
//
// The serde `rename` is the name used on the wire (e.g. in `Request::model`);
// `alias` is an additional name accepted when deserializing. Note that the wire
// name can differ from `Model::id()` (e.g. `Gpt4o` is sent as the dated snapshot).
// Variant order is the order yielded by the derived `EnumIter`.
// Plain `//` comments are used here on purpose: `///` docs would change the
// generated JSON schema when the `schemars` feature is enabled.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // Default model; serialized as the dated snapshot "gpt-4o-2024-05-13".
    #[default]
    #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
    Gpt4o,
    #[serde(alias = "gpt-4", rename = "gpt-4")]
    Gpt4,
    #[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
    Gpt4_1,
    #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
    Gpt3_5Turbo,
    #[serde(alias = "o1", rename = "o1")]
    O1,
    // NOTE(review): the "o1-mini" alias deserializes to o3-mini — presumably so
    // settings saved under the old o1-mini name keep working; confirm.
    #[serde(alias = "o1-mini", rename = "o3-mini")]
    O3Mini,
    #[serde(alias = "o3", rename = "o3")]
    O3,
    #[serde(alias = "o4-mini", rename = "o4-mini")]
    O4Mini,
    // Wire name uses dots ("claude-3.5-sonnet"); the dashed form is the alias.
    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
    Claude3_5Sonnet,
    #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
    Claude3_7Sonnet,
    // The alias is identical to the rename, so it adds nothing.
    #[serde(
        alias = "claude-3.7-sonnet-thought",
        rename = "claude-3.7-sonnet-thought"
    )]
    Claude3_7SonnetThinking,
    #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
    Gemini20Flash,
    #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
    Gemini25Pro,
}
 62
 63impl Model {
 64    pub fn default_fast() -> Self {
 65        Self::Claude3_7Sonnet
 66    }
 67
 68    pub fn uses_streaming(&self) -> bool {
 69        match self {
 70            Self::Gpt4o
 71            | Self::Gpt4
 72            | Self::Gpt4_1
 73            | Self::Gpt3_5Turbo
 74            | Self::O3
 75            | Self::O4Mini
 76            | Self::Claude3_5Sonnet
 77            | Self::Claude3_7Sonnet
 78            | Self::Claude3_7SonnetThinking => true,
 79            Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
 80        }
 81    }
 82
 83    pub fn from_id(id: &str) -> Result<Self> {
 84        match id {
 85            "gpt-4o" => Ok(Self::Gpt4o),
 86            "gpt-4" => Ok(Self::Gpt4),
 87            "gpt-4.1" => Ok(Self::Gpt4_1),
 88            "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
 89            "o1" => Ok(Self::O1),
 90            "o3-mini" => Ok(Self::O3Mini),
 91            "o3" => Ok(Self::O3),
 92            "o4-mini" => Ok(Self::O4Mini),
 93            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
 94            "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
 95            "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
 96            "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
 97            "gemini-2.5-pro" => Ok(Self::Gemini25Pro),
 98            _ => Err(anyhow!("Invalid model id: {}", id)),
 99        }
100    }
101
102    pub fn id(&self) -> &'static str {
103        match self {
104            Self::Gpt3_5Turbo => "gpt-3.5-turbo",
105            Self::Gpt4 => "gpt-4",
106            Self::Gpt4_1 => "gpt-4.1",
107            Self::Gpt4o => "gpt-4o",
108            Self::O3Mini => "o3-mini",
109            Self::O1 => "o1",
110            Self::O3 => "o3",
111            Self::O4Mini => "o4-mini",
112            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
113            Self::Claude3_7Sonnet => "claude-3-7-sonnet",
114            Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
115            Self::Gemini20Flash => "gemini-2.0-flash-001",
116            Self::Gemini25Pro => "gemini-2.5-pro",
117        }
118    }
119
120    pub fn display_name(&self) -> &'static str {
121        match self {
122            Self::Gpt3_5Turbo => "GPT-3.5",
123            Self::Gpt4 => "GPT-4",
124            Self::Gpt4_1 => "GPT-4.1",
125            Self::Gpt4o => "GPT-4o",
126            Self::O3Mini => "o3-mini",
127            Self::O1 => "o1",
128            Self::O3 => "o3",
129            Self::O4Mini => "o4-mini",
130            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
131            Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
132            Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
133            Self::Gemini20Flash => "Gemini 2.0 Flash",
134            Self::Gemini25Pro => "Gemini 2.5 Pro",
135        }
136    }
137
138    pub fn max_token_count(&self) -> usize {
139        match self {
140            Self::Gpt4o => 64_000,
141            Self::Gpt4 => 32_768,
142            Self::Gpt4_1 => 128_000,
143            Self::Gpt3_5Turbo => 12_288,
144            Self::O3Mini => 64_000,
145            Self::O1 => 20_000,
146            Self::O3 => 128_000,
147            Self::O4Mini => 128_000,
148            Self::Claude3_5Sonnet => 200_000,
149            Self::Claude3_7Sonnet => 90_000,
150            Self::Claude3_7SonnetThinking => 90_000,
151            Self::Gemini20Flash => 128_000,
152            Self::Gemini25Pro => 128_000,
153        }
154    }
155}
156
/// Body of a chat completions request; serialized to JSON by `stream_completion`.
#[derive(Serialize, Deserialize)]
pub struct Request {
    /// NOTE(review): Copilot-specific flag; its meaning is not visible in this file.
    pub intent: bool,
    /// Presumably the number of choices to generate (OpenAI-style `n`) — confirm.
    pub n: usize,
    /// Whether the response is requested as a server-sent event stream; this
    /// also selects how `stream_completion` parses the response body.
    pub stream: bool,
    /// Sampling temperature.
    pub temperature: f32,
    /// Model to use, serialized by its serde wire name.
    pub model: Model,
    /// Conversation messages, in order.
    pub messages: Vec<ChatMessage>,
    /// Tools the model may call; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Tool selection policy; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
170
/// Definition of a callable function tool.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    /// Arbitrary JSON; presumably a JSON Schema describing the arguments — confirm.
    pub parameters: serde_json::Value,
}
177
/// A tool offered to the model, serialized as
/// `{"type": "function", "function": { ... }}`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}
183
/// Tool selection policy, serialized in lowercase (`"auto"`, `"any"`, `"none"`).
///
/// NOTE(review): OpenAI-style APIs spell "must call a tool" as `"required"`;
/// confirm the Copilot endpoint accepts `"any"`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
191
/// A conversation message, internally tagged by a lowercase `role` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// Assistant output. `content` may be null (e.g. when only tool calls are
    /// present); `tool_calls` is omitted when empty and defaults to empty.
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// Result of a tool call, linked to it by `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
211
/// A tool call made by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened so its `type`/`function` keys sit next to `id` in the JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
218
/// Payload of a tool call, tagged by `type`; only `"function"` exists.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
224
/// Name and arguments of a function call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Kept as a raw string; presumably JSON-encoded arguments — confirm.
    pub arguments: String,
}
230
231#[derive(Deserialize, Debug)]
232#[serde(tag = "type", rename_all = "snake_case")]
233pub struct ResponseEvent {
234    pub choices: Vec<ResponseChoice>,
235    pub created: u64,
236    pub id: String,
237}
238
/// One choice within a [`ResponseEvent`].
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    /// Presumably populated in streaming responses (OpenAI convention) — confirm.
    pub delta: Option<ResponseDelta>,
    /// Presumably populated in non-streaming responses (OpenAI convention) — confirm.
    pub message: Option<ResponseDelta>,
}
246
/// Assistant output carried by a choice's `delta` or `message`.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    /// Defaults to empty when the field is absent.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}
254
/// A (possibly partial) tool call from a response.
///
/// NOTE(review): in streaming mode, chunks sharing `index` presumably belong to
/// the same call and must be concatenated by the consumer — confirm.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
261
/// Possibly-partial function name and arguments within a [`ToolCallChunk`].
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
267
/// Raw JSON body returned by [`COPILOT_CHAT_AUTH_URL`].
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    /// Seconds since the Unix epoch (converted with `DateTime::from_timestamp`).
    expires_at: i64,
}
273
/// A Copilot API token and its expiry; sent as the bearer key for chat requests.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
}
279
280impl ApiToken {
281    pub fn remaining_seconds(&self) -> i64 {
282        self.expires_at
283            .timestamp()
284            .saturating_sub(chrono::Utc::now().timestamp())
285    }
286}
287
288impl TryFrom<ApiTokenResponse> for ApiToken {
289    type Error = anyhow::Error;
290
291    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
292        let expires_at = DateTime::from_timestamp(response.expires_at, 0)
293            .ok_or_else(|| anyhow!("invalid expires_at"))?;
294
295        Ok(Self {
296            api_key: response.token,
297            expires_at,
298        })
299    }
300}
301
/// gpui global holding the shared [`CopilotChat`] entity; registered by `init`.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
305
/// Shared Copilot chat state: the OAuth token, the cached API token and the HTTP client.
pub struct CopilotChat {
    /// GitHub OAuth token extracted from the Copilot config files; `None` until found.
    oauth_token: Option<String>,
    /// Cached API token, refreshed on demand by `stream_completion`.
    api_token: Option<ApiToken>,
    client: Arc<dyn HttpClient>,
}
311
312pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
313    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
314    cx.set_global(GlobalCopilotChat(copilot_chat));
315}
316
317pub fn copilot_chat_config_dir() -> &'static PathBuf {
318    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
319
320    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
321        if cfg!(target_os = "windows") {
322            home_dir().join("AppData").join("Local")
323        } else {
324            home_dir().join(".config")
325        }
326        .join("github-copilot")
327    })
328}
329
330fn copilot_chat_config_paths() -> [PathBuf; 2] {
331    let base_dir = copilot_chat_config_dir();
332    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
333}
334
335impl CopilotChat {
336    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
337        cx.try_global::<GlobalCopilotChat>()
338            .map(|model| model.0.clone())
339    }
340
341    pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
342        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
343        let dir_path = copilot_chat_config_dir();
344
345        cx.spawn(async move |cx| {
346            let mut parent_watch_rx = watch_config_dir(
347                cx.background_executor(),
348                fs.clone(),
349                dir_path.clone(),
350                config_paths,
351            );
352            while let Some(contents) = parent_watch_rx.next().await {
353                let oauth_token = extract_oauth_token(contents);
354                cx.update(|cx| {
355                    if let Some(this) = Self::global(cx).as_ref() {
356                        this.update(cx, |this, cx| {
357                            this.oauth_token = oauth_token;
358                            cx.notify();
359                        });
360                    }
361                })?;
362            }
363            anyhow::Ok(())
364        })
365        .detach_and_log_err(cx);
366
367        Self {
368            oauth_token: None,
369            api_token: None,
370            client,
371        }
372    }
373
374    pub fn is_authenticated(&self) -> bool {
375        self.oauth_token.is_some()
376    }
377
378    pub async fn stream_completion(
379        request: Request,
380        mut cx: AsyncApp,
381    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
382        let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
383            return Err(anyhow!("Copilot chat is not enabled"));
384        };
385
386        let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
387            (
388                this.oauth_token.clone(),
389                this.api_token.clone(),
390                this.client.clone(),
391            )
392        })?;
393
394        let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
395
396        let token = match api_token {
397            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
398            _ => {
399                let token = request_api_token(&oauth_token, client.clone()).await?;
400                this.update(&mut cx, |this, cx| {
401                    this.api_token = Some(token.clone());
402                    cx.notify();
403                })?;
404                token
405            }
406        };
407
408        stream_completion(client.clone(), token.api_key, request).await
409    }
410}
411
412async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
413    let request_builder = HttpRequest::builder()
414        .method(Method::GET)
415        .uri(COPILOT_CHAT_AUTH_URL)
416        .header("Authorization", format!("token {}", oauth_token))
417        .header("Accept", "application/json");
418
419    let request = request_builder.body(AsyncBody::empty())?;
420
421    let mut response = client.send(request).await?;
422
423    if response.status().is_success() {
424        let mut body = Vec::new();
425        response.body_mut().read_to_end(&mut body).await?;
426
427        let body_str = std::str::from_utf8(&body)?;
428
429        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
430        ApiToken::try_from(parsed)
431    } else {
432        let mut body = Vec::new();
433        response.body_mut().read_to_end(&mut body).await?;
434
435        let body_str = std::str::from_utf8(&body)?;
436
437        Err(anyhow!("Failed to request API token: {}", body_str))
438    }
439}
440
441fn extract_oauth_token(contents: String) -> Option<String> {
442    serde_json::from_str::<serde_json::Value>(&contents)
443        .map(|v| {
444            v.as_object().and_then(|obj| {
445                obj.iter().find_map(|(key, value)| {
446                    if key.starts_with("github.com") {
447                        value["oauth_token"].as_str().map(|v| v.to_string())
448                    } else {
449                        None
450                    }
451                })
452            })
453        })
454        .ok()
455        .flatten()
456}
457
458async fn stream_completion(
459    client: Arc<dyn HttpClient>,
460    api_key: String,
461    request: Request,
462) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
463    let request_builder = HttpRequest::builder()
464        .method(Method::POST)
465        .uri(COPILOT_CHAT_COMPLETION_URL)
466        .header(
467            "Editor-Version",
468            format!(
469                "Zed/{}",
470                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
471            ),
472        )
473        .header("Authorization", format!("Bearer {}", api_key))
474        .header("Content-Type", "application/json")
475        .header("Copilot-Integration-Id", "vscode-chat");
476
477    let is_streaming = request.stream;
478
479    let json = serde_json::to_string(&request)?;
480    let request = request_builder.body(AsyncBody::from(json))?;
481    let mut response = client.send(request).await?;
482
483    if !response.status().is_success() {
484        let mut body = Vec::new();
485        response.body_mut().read_to_end(&mut body).await?;
486        let body_str = std::str::from_utf8(&body)?;
487        return Err(anyhow!(
488            "Failed to connect to API: {} {}",
489            response.status(),
490            body_str
491        ));
492    }
493
494    if is_streaming {
495        let reader = BufReader::new(response.into_body());
496        Ok(reader
497            .lines()
498            .filter_map(|line| async move {
499                match line {
500                    Ok(line) => {
501                        let line = line.strip_prefix("data: ")?;
502                        if line.starts_with("[DONE]") {
503                            return None;
504                        }
505
506                        match serde_json::from_str::<ResponseEvent>(line) {
507                            Ok(response) => {
508                                if response.choices.is_empty() {
509                                    None
510                                } else {
511                                    Some(Ok(response))
512                                }
513                            }
514                            Err(error) => Some(Err(anyhow!(error))),
515                        }
516                    }
517                    Err(error) => Some(Err(anyhow!(error))),
518                }
519            })
520            .boxed())
521    } else {
522        let mut body = Vec::new();
523        response.body_mut().read_to_end(&mut body).await?;
524        let body_str = std::str::from_utf8(&body)?;
525        let response: ResponseEvent = serde_json::from_str(body_str)?;
526
527        Ok(futures::stream::once(async move { Ok(response) }).boxed())
528    }
529}