copilot_chat.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use anyhow::{anyhow, Result};
  4use chrono::DateTime;
  5use fs::Fs;
  6use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  7use gpui::{AppContext, AsyncAppContext, Global};
  8use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  9use isahc::config::Configurable;
 10use serde::{Deserialize, Serialize};
 11use settings::watch_config_file;
 12use strum::EnumIter;
 13use ui::Context;
 14
 15pub const COPILOT_CHAT_COMPLETION_URL: &'static str =
 16    "https://api.githubcopilot.com/chat/completions";
 17pub const COPILOT_CHAT_AUTH_URL: &'static str = "https://api.github.com/copilot_internal/v2/token";
 18
 19#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 20#[serde(rename_all = "lowercase")]
 21pub enum Role {
 22    User,
 23    Assistant,
 24    System,
 25}
 26
 27#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 28#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 29pub enum Model {
 30    #[default]
 31    #[serde(alias = "gpt-4", rename = "gpt-4")]
 32    Gpt4,
 33    #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
 34    Gpt3_5Turbo,
 35}
 36
 37impl Model {
 38    pub fn from_id(id: &str) -> Result<Self> {
 39        match id {
 40            "gpt-4" => Ok(Self::Gpt4),
 41            "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
 42            _ => Err(anyhow!("Invalid model id: {}", id)),
 43        }
 44    }
 45
 46    pub fn id(&self) -> &'static str {
 47        match self {
 48            Self::Gpt3_5Turbo => "gpt-3.5-turbo",
 49            Self::Gpt4 => "gpt-4",
 50        }
 51    }
 52
 53    pub fn display_name(&self) -> &'static str {
 54        match self {
 55            Self::Gpt3_5Turbo => "GPT-3.5",
 56            Self::Gpt4 => "GPT-4",
 57        }
 58    }
 59
 60    pub fn max_token_count(&self) -> usize {
 61        match self {
 62            Self::Gpt4 => 8192,
 63            Self::Gpt3_5Turbo => 16385,
 64        }
 65    }
 66}
 67
 68#[derive(Serialize, Deserialize)]
 69pub struct Request {
 70    pub intent: bool,
 71    pub n: usize,
 72    pub stream: bool,
 73    pub temperature: f32,
 74    pub model: Model,
 75    pub messages: Vec<ChatMessage>,
 76}
 77
 78impl Request {
 79    pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
 80        Self {
 81            intent: true,
 82            n: 1,
 83            stream: true,
 84            temperature: 0.1,
 85            model,
 86            messages,
 87        }
 88    }
 89}
 90
 91#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 92pub struct ChatMessage {
 93    pub role: Role,
 94    pub content: String,
 95}
 96
 97#[derive(Deserialize, Debug)]
 98#[serde(tag = "type", rename_all = "snake_case")]
 99pub struct ResponseEvent {
100    pub choices: Vec<ResponseChoice>,
101    pub created: u64,
102    pub id: String,
103}
104
105#[derive(Debug, Deserialize)]
106pub struct ResponseChoice {
107    pub index: usize,
108    pub finish_reason: Option<String>,
109    pub delta: ResponseDelta,
110}
111
112#[derive(Debug, Deserialize)]
113pub struct ResponseDelta {
114    pub content: Option<String>,
115    pub role: Option<Role>,
116}
117
118#[derive(Deserialize)]
119struct ApiTokenResponse {
120    token: String,
121    expires_at: i64,
122}
123
124#[derive(Clone)]
125struct ApiToken {
126    api_key: String,
127    expires_at: DateTime<chrono::Utc>,
128}
129
130impl ApiToken {
131    pub fn remaining_seconds(&self) -> i64 {
132        self.expires_at
133            .timestamp()
134            .saturating_sub(chrono::Utc::now().timestamp())
135    }
136}
137
138impl TryFrom<ApiTokenResponse> for ApiToken {
139    type Error = anyhow::Error;
140
141    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
142        let expires_at = DateTime::from_timestamp(response.expires_at, 0)
143            .ok_or_else(|| anyhow!("invalid expires_at"))?;
144
145        Ok(Self {
146            api_key: response.token,
147            expires_at,
148        })
149    }
150}
151
152struct GlobalCopilotChat(gpui::Model<CopilotChat>);
153
154impl Global for GlobalCopilotChat {}
155
156pub struct CopilotChat {
157    oauth_token: Option<String>,
158    api_token: Option<ApiToken>,
159    client: Arc<dyn HttpClient>,
160}
161
162pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut AppContext) {
163    let copilot_chat = cx.new_model(|cx| CopilotChat::new(fs, client, cx));
164    cx.set_global(GlobalCopilotChat(copilot_chat));
165}
166
167impl CopilotChat {
168    pub fn global(cx: &AppContext) -> Option<gpui::Model<Self>> {
169        cx.try_global::<GlobalCopilotChat>()
170            .map(|model| model.0.clone())
171    }
172
173    pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &AppContext) -> Self {
174        let mut config_file_rx = watch_config_file(
175            cx.background_executor(),
176            fs,
177            paths::copilot_chat_config_path().clone(),
178        );
179
180        cx.spawn(|cx| async move {
181            while let Some(contents) = config_file_rx.next().await {
182                let oauth_token = extract_oauth_token(contents);
183
184                cx.update(|cx| {
185                    if let Some(this) = Self::global(cx).as_ref() {
186                        this.update(cx, |this, cx| {
187                            this.oauth_token = oauth_token;
188                            cx.notify();
189                        });
190                    }
191                })?;
192            }
193            anyhow::Ok(())
194        })
195        .detach_and_log_err(cx);
196
197        Self {
198            oauth_token: None,
199            api_token: None,
200            client,
201        }
202    }
203
204    pub fn is_authenticated(&self) -> bool {
205        self.oauth_token.is_some()
206    }
207
208    pub async fn stream_completion(
209        request: Request,
210        low_speed_timeout: Option<Duration>,
211        mut cx: AsyncAppContext,
212    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
213        let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
214            return Err(anyhow!("Copilot chat is not enabled"));
215        };
216
217        let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
218            (
219                this.oauth_token.clone(),
220                this.api_token.clone(),
221                this.client.clone(),
222            )
223        })?;
224
225        let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
226
227        let token = match api_token {
228            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
229            _ => {
230                let token =
231                    request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
232                this.update(&mut cx, |this, cx| {
233                    this.api_token = Some(token.clone());
234                    cx.notify();
235                })?;
236                token
237            }
238        };
239
240        stream_completion(client.clone(), token.api_key, request, low_speed_timeout).await
241    }
242}
243
244async fn request_api_token(
245    oauth_token: &str,
246    client: Arc<dyn HttpClient>,
247    low_speed_timeout: Option<Duration>,
248) -> Result<ApiToken> {
249    let mut request_builder = HttpRequest::builder()
250        .method(Method::GET)
251        .uri(COPILOT_CHAT_AUTH_URL)
252        .header("Authorization", format!("token {}", oauth_token))
253        .header("Accept", "application/json");
254
255    if let Some(low_speed_timeout) = low_speed_timeout {
256        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
257    }
258
259    let request = request_builder.body(AsyncBody::empty())?;
260
261    let mut response = client.send(request).await?;
262
263    if response.status().is_success() {
264        let mut body = Vec::new();
265        response.body_mut().read_to_end(&mut body).await?;
266
267        let body_str = std::str::from_utf8(&body)?;
268
269        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
270        ApiToken::try_from(parsed)
271    } else {
272        let mut body = Vec::new();
273        response.body_mut().read_to_end(&mut body).await?;
274
275        let body_str = std::str::from_utf8(&body)?;
276
277        Err(anyhow!("Failed to request API token: {}", body_str))
278    }
279}
280
281fn extract_oauth_token(contents: String) -> Option<String> {
282    serde_json::from_str::<serde_json::Value>(&contents)
283        .map(|v| {
284            v["github.com"]["oauth_token"]
285                .as_str()
286                .map(|v| v.to_string())
287        })
288        .ok()
289        .flatten()
290}
291
292async fn stream_completion(
293    client: Arc<dyn HttpClient>,
294    api_key: String,
295    request: Request,
296    low_speed_timeout: Option<Duration>,
297) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
298    let mut request_builder = HttpRequest::builder()
299        .method(Method::POST)
300        .uri(COPILOT_CHAT_COMPLETION_URL)
301        .header(
302            "Editor-Version",
303            format!(
304                "Zed/{}",
305                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
306            ),
307        )
308        .header("Authorization", format!("Bearer {}", api_key))
309        .header("Content-Type", "application/json")
310        .header("Copilot-Integration-Id", "vscode-chat");
311
312    if let Some(low_speed_timeout) = low_speed_timeout {
313        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
314    }
315    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
316    let mut response = client.send(request).await?;
317    if response.status().is_success() {
318        let reader = BufReader::new(response.into_body());
319        Ok(reader
320            .lines()
321            .filter_map(|line| async move {
322                match line {
323                    Ok(line) => {
324                        let line = line.strip_prefix("data: ")?;
325                        if line.starts_with("[DONE]") {
326                            return None;
327                        }
328
329                        match serde_json::from_str::<ResponseEvent>(line) {
330                            Ok(response) => {
331                                if response.choices.first().is_none()
332                                    || response.choices.first().unwrap().finish_reason.is_some()
333                                {
334                                    None
335                                } else {
336                                    Some(Ok(response))
337                                }
338                            }
339                            Err(error) => Some(Err(anyhow!(error))),
340                        }
341                    }
342                    Err(error) => Some(Err(anyhow!(error))),
343                }
344            })
345            .boxed())
346    } else {
347        let mut body = Vec::new();
348        response.body_mut().read_to_end(&mut body).await?;
349
350        let body_str = std::str::from_utf8(&body)?;
351
352        match serde_json::from_str::<ResponseEvent>(body_str) {
353            Ok(_) => Err(anyhow!(
354                "Unexpected success response while expecting an error: {}",
355                body_str,
356            )),
357            Err(_) => Err(anyhow!(
358                "Failed to connect to API: {} {}",
359                response.status(),
360                body_str,
361            )),
362        }
363    }
364}