1use std::{sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Result};
4use chrono::DateTime;
5use fs::Fs;
6use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
7use gpui::{AppContext, AsyncAppContext, Global};
8use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
9use isahc::config::Configurable;
10use serde::{Deserialize, Serialize};
11use settings::watch_config_file;
12use strum::EnumIter;
13use ui::Context;
14
/// Endpoint that receives chat completion requests (POST, JSON body).
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
/// Endpoint that exchanges a GitHub OAuth token for a Copilot API token (GET).
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
18
/// Author of a chat message.
///
/// Serialized in lowercase: `"user"`, `"assistant"`, `"system"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
26
27#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
28#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
29pub enum Model {
30 #[default]
31 #[serde(alias = "gpt-4", rename = "gpt-4")]
32 Gpt4,
33 #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
34 Gpt3_5Turbo,
35}
36
37impl Model {
38 pub fn from_id(id: &str) -> Result<Self> {
39 match id {
40 "gpt-4" => Ok(Self::Gpt4),
41 "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
42 _ => Err(anyhow!("Invalid model id: {}", id)),
43 }
44 }
45
46 pub fn id(&self) -> &'static str {
47 match self {
48 Self::Gpt3_5Turbo => "gpt-3.5-turbo",
49 Self::Gpt4 => "gpt-4",
50 }
51 }
52
53 pub fn display_name(&self) -> &'static str {
54 match self {
55 Self::Gpt3_5Turbo => "GPT-3.5",
56 Self::Gpt4 => "GPT-4",
57 }
58 }
59
60 pub fn max_token_count(&self) -> usize {
61 match self {
62 Self::Gpt4 => 8192,
63 Self::Gpt3_5Turbo => 16385,
64 }
65 }
66}
67
/// Body of a chat completion request.
///
/// Serialized to JSON and POSTed to `COPILOT_CHAT_COMPLETION_URL`.
#[derive(Serialize, Deserialize)]
pub struct Request {
    /// Copilot-specific flag; `Request::new` sets it to `true`.
    /// NOTE(review): its meaning is not visible here; confirm against the Copilot API.
    pub intent: bool,
    /// Presumably the number of completion choices (OpenAI-style); `Request::new` uses 1.
    pub n: usize,
    /// Asks for an incremental (streamed) response; `Request::new` sets it to `true`.
    pub stream: bool,
    /// Presumably the sampling temperature (OpenAI-style); `Request::new` uses 0.1.
    pub temperature: f32,
    /// Model to run, serialized through its serde rename (e.g. `"gpt-4"`).
    pub model: Model,
    /// The conversation to complete.
    pub messages: Vec<ChatMessage>,
}
77
78impl Request {
79 pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
80 Self {
81 intent: true,
82 n: 1,
83 stream: true,
84 temperature: 0.1,
85 model,
86 messages,
87 }
88 }
89}
90
/// A single message of the conversation carried in `Request::messages`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: Role,
    /// Message text.
    pub content: String,
}
96
/// One event of a streamed chat completion response, parsed from a `data:` line.
// NOTE(review): `tag`/`rename_all` are enum-style attributes applied to a
// Deserialize-only struct. They look copied from elsewhere; confirm whether
// they are needed.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    /// Per-choice deltas carried by this event; may be empty.
    pub choices: Vec<ResponseChoice>,
    /// Creation time as reported by the server (presumably Unix seconds; confirm).
    pub created: u64,
    /// Identifier of the completion, as reported by the server.
    pub id: String,
}
104
/// One choice within a [`ResponseEvent`].
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    /// Position of this choice among the requested choices.
    pub index: usize,
    /// Presumably set by the server when generation for this choice stops
    /// (e.g. `"stop"`); absent on intermediate events. Confirm against the API.
    pub finish_reason: Option<String>,
    /// Incremental content for this choice.
    pub delta: ResponseDelta,
}
111
/// Incremental fragment of an assistant message; either field may be absent.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    /// Text to append to the message, if any.
    pub content: Option<String>,
    /// Message role, when the server includes it in this event.
    pub role: Option<Role>,
}
117
/// JSON body returned by `COPILOT_CHAT_AUTH_URL` on success.
#[derive(Deserialize)]
struct ApiTokenResponse {
    /// The Copilot API token.
    token: String,
    /// Expiry as a Unix timestamp in whole seconds.
    expires_at: i64,
}
123
/// A Copilot API token together with its expiry time.
#[derive(Clone)]
struct ApiToken {
    /// Sent as the `Bearer` credential on completion requests.
    api_key: String,
    /// Instant (UTC) after which the token is no longer valid.
    expires_at: DateTime<chrono::Utc>,
}
129
130impl ApiToken {
131 pub fn remaining_seconds(&self) -> i64 {
132 self.expires_at
133 .timestamp()
134 .saturating_sub(chrono::Utc::now().timestamp())
135 }
136}
137
138impl TryFrom<ApiTokenResponse> for ApiToken {
139 type Error = anyhow::Error;
140
141 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
142 let expires_at = DateTime::from_timestamp(response.expires_at, 0)
143 .ok_or_else(|| anyhow!("invalid expires_at"))?;
144
145 Ok(Self {
146 api_key: response.token,
147 expires_at,
148 })
149 }
150}
151
/// gpui global wrapper holding the app-wide `CopilotChat` model (set by `init`).
struct GlobalCopilotChat(gpui::Model<CopilotChat>);

impl Global for GlobalCopilotChat {}
155
/// State for talking to GitHub Copilot Chat.
pub struct CopilotChat {
    /// GitHub OAuth token read from the Copilot config file; `None` until one is found.
    oauth_token: Option<String>,
    /// Cached Copilot API token, obtained by exchanging `oauth_token`.
    api_token: Option<ApiToken>,
    /// HTTP client used for both the token exchange and completion requests.
    client: Arc<dyn HttpClient>,
}
161
162pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut AppContext) {
163 let copilot_chat = cx.new_model(|cx| CopilotChat::new(fs, client, cx));
164 cx.set_global(GlobalCopilotChat(copilot_chat));
165}
166
167impl CopilotChat {
168 pub fn global(cx: &AppContext) -> Option<gpui::Model<Self>> {
169 cx.try_global::<GlobalCopilotChat>()
170 .map(|model| model.0.clone())
171 }
172
173 pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &AppContext) -> Self {
174 let mut config_file_rx = watch_config_file(
175 cx.background_executor(),
176 fs,
177 paths::copilot_chat_config_path().clone(),
178 );
179
180 cx.spawn(|cx| async move {
181 while let Some(contents) = config_file_rx.next().await {
182 let oauth_token = extract_oauth_token(contents);
183
184 cx.update(|cx| {
185 if let Some(this) = Self::global(cx).as_ref() {
186 this.update(cx, |this, cx| {
187 this.oauth_token = oauth_token;
188 cx.notify();
189 });
190 }
191 })?;
192 }
193 anyhow::Ok(())
194 })
195 .detach_and_log_err(cx);
196
197 Self {
198 oauth_token: None,
199 api_token: None,
200 client,
201 }
202 }
203
204 pub fn is_authenticated(&self) -> bool {
205 self.oauth_token.is_some()
206 }
207
208 pub async fn stream_completion(
209 request: Request,
210 low_speed_timeout: Option<Duration>,
211 mut cx: AsyncAppContext,
212 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
213 let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
214 return Err(anyhow!("Copilot chat is not enabled"));
215 };
216
217 let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
218 (
219 this.oauth_token.clone(),
220 this.api_token.clone(),
221 this.client.clone(),
222 )
223 })?;
224
225 let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;
226
227 let token = match api_token {
228 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
229 _ => {
230 let token =
231 request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
232 this.update(&mut cx, |this, cx| {
233 this.api_token = Some(token.clone());
234 cx.notify();
235 })?;
236 token
237 }
238 };
239
240 stream_completion(client.clone(), token.api_key, request, low_speed_timeout).await
241 }
242}
243
244async fn request_api_token(
245 oauth_token: &str,
246 client: Arc<dyn HttpClient>,
247 low_speed_timeout: Option<Duration>,
248) -> Result<ApiToken> {
249 let mut request_builder = HttpRequest::builder()
250 .method(Method::GET)
251 .uri(COPILOT_CHAT_AUTH_URL)
252 .header("Authorization", format!("token {}", oauth_token))
253 .header("Accept", "application/json");
254
255 if let Some(low_speed_timeout) = low_speed_timeout {
256 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
257 }
258
259 let request = request_builder.body(AsyncBody::empty())?;
260
261 let mut response = client.send(request).await?;
262
263 if response.status().is_success() {
264 let mut body = Vec::new();
265 response.body_mut().read_to_end(&mut body).await?;
266
267 let body_str = std::str::from_utf8(&body)?;
268
269 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
270 ApiToken::try_from(parsed)
271 } else {
272 let mut body = Vec::new();
273 response.body_mut().read_to_end(&mut body).await?;
274
275 let body_str = std::str::from_utf8(&body)?;
276
277 Err(anyhow!("Failed to request API token: {}", body_str))
278 }
279}
280
281fn extract_oauth_token(contents: String) -> Option<String> {
282 serde_json::from_str::<serde_json::Value>(&contents)
283 .map(|v| {
284 v["github.com"]["oauth_token"]
285 .as_str()
286 .map(|v| v.to_string())
287 })
288 .ok()
289 .flatten()
290}
291
292async fn stream_completion(
293 client: Arc<dyn HttpClient>,
294 api_key: String,
295 request: Request,
296 low_speed_timeout: Option<Duration>,
297) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
298 let mut request_builder = HttpRequest::builder()
299 .method(Method::POST)
300 .uri(COPILOT_CHAT_COMPLETION_URL)
301 .header(
302 "Editor-Version",
303 format!(
304 "Zed/{}",
305 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
306 ),
307 )
308 .header("Authorization", format!("Bearer {}", api_key))
309 .header("Content-Type", "application/json")
310 .header("Copilot-Integration-Id", "vscode-chat");
311
312 if let Some(low_speed_timeout) = low_speed_timeout {
313 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
314 }
315 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
316 let mut response = client.send(request).await?;
317 if response.status().is_success() {
318 let reader = BufReader::new(response.into_body());
319 Ok(reader
320 .lines()
321 .filter_map(|line| async move {
322 match line {
323 Ok(line) => {
324 let line = line.strip_prefix("data: ")?;
325 if line.starts_with("[DONE]") {
326 return None;
327 }
328
329 match serde_json::from_str::<ResponseEvent>(line) {
330 Ok(response) => {
331 if response.choices.first().is_none()
332 || response.choices.first().unwrap().finish_reason.is_some()
333 {
334 None
335 } else {
336 Some(Ok(response))
337 }
338 }
339 Err(error) => Some(Err(anyhow!(error))),
340 }
341 }
342 Err(error) => Some(Err(anyhow!(error))),
343 }
344 })
345 .boxed())
346 } else {
347 let mut body = Vec::new();
348 response.body_mut().read_to_end(&mut body).await?;
349
350 let body_str = std::str::from_utf8(&body)?;
351
352 match serde_json::from_str::<ResponseEvent>(body_str) {
353 Ok(_) => Err(anyhow!(
354 "Unexpected success response while expecting an error: {}",
355 body_str,
356 )),
357 Err(_) => Err(anyhow!(
358 "Failed to connect to API: {} {}",
359 response.status(),
360 body_str,
361 )),
362 }
363 }
364}