1mod request;
2
3use anyhow::{anyhow, Result};
4use async_compression::futures::bufread::GzipDecoder;
5use client::Client;
6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
7use language::{Buffer, ToPointUtf16};
8use lsp::LanguageServer;
9use smol::{fs, io::BufReader, stream::StreamExt};
10use std::{
11 env::consts,
12 path::{Path, PathBuf},
13 sync::Arc,
14};
15use util::{
16 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
17};
18
19actions!(copilot, [SignIn, SignOut]);
20
21pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
22 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
23 cx.set_global(copilot);
24 cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
25 if let Some(copilot) = Copilot::global(cx) {
26 copilot
27 .update(cx, |copilot, cx| copilot.sign_in(cx))
28 .detach_and_log_err(cx);
29 }
30 });
31 cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
32 if let Some(copilot) = Copilot::global(cx) {
33 copilot
34 .update(cx, |copilot, cx| copilot.sign_out(cx))
35 .detach_and_log_err(cx);
36 }
37 });
38}
39
40enum CopilotServer {
41 Downloading,
42 Error(Arc<str>),
43 Started {
44 server: Arc<LanguageServer>,
45 status: SignInStatus,
46 },
47}
48
/// Internal sign-in state, mapped from the server's `request::SignInStatus` or reset to
/// `SignedOut` after a successful sign-out.
#[derive(Debug, Clone, Eq, PartialEq)]
enum SignInStatus {
    /// Signed in as `user`, and authorized.
    Authorized { user: String },
    /// Signed in as `user`, but the server reported the user as not authorized.
    Unauthorized { user: String },
    /// No user is signed in.
    SignedOut,
}
55
/// Events emitted by the `Copilot` model.
///
/// Derives the common value traits so subscribers can log, clone, and compare events.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Event {
    /// Emitted during sign-in when the server answers with a device-flow prompt; carries
    /// the `user_code` and `verification_uri` to present to the user.
    PromptUserDeviceFlow {
        user_code: String,
        verification_uri: String,
    },
}
62
/// Public, coarse-grained view of Copilot's state, as returned by `Copilot::status`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Status {
    /// The language server has not finished downloading / starting yet.
    Downloading,
    /// Starting the language server failed; carries the error message.
    Error(Arc<str>),
    /// The server is running but no user is signed in.
    SignedOut,
    /// A user is signed in, but the server reported them as not authorized.
    Unauthorized,
    /// A user is signed in and authorized.
    Authorized,
}

impl Status {
    /// Returns `true` only for [`Status::Authorized`].
    pub fn is_authorized(&self) -> bool {
        matches!(self, Status::Authorized)
    }
}
77
78struct Copilot {
79 server: CopilotServer,
80}
81
82impl Entity for Copilot {
83 type Event = Event;
84}
85
86impl Copilot {
87 fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
88 if cx.has_global::<ModelHandle<Self>>() {
89 let copilot = cx.global::<ModelHandle<Self>>().clone();
90 if copilot.read(cx).status().is_authorized() {
91 Some(copilot)
92 } else {
93 None
94 }
95 } else {
96 None
97 }
98 }
99
100 fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
101 cx.spawn(|this, mut cx| async move {
102 let start_language_server = async {
103 let server_path = get_lsp_binary(http).await?;
104 let server =
105 LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
106 let server = server.initialize(Default::default()).await?;
107 let status = server
108 .request::<request::CheckStatus>(request::CheckStatusParams {
109 local_checks_only: false,
110 })
111 .await?;
112 anyhow::Ok((server, status))
113 };
114
115 let server = start_language_server.await;
116 this.update(&mut cx, |this, cx| {
117 cx.notify();
118 match server {
119 Ok((server, status)) => {
120 this.server = CopilotServer::Started {
121 server,
122 status: SignInStatus::SignedOut,
123 };
124 this.update_sign_in_status(status, cx);
125 }
126 Err(error) => {
127 this.server = CopilotServer::Error(error.to_string().into());
128 }
129 }
130 })
131 })
132 .detach();
133 Self {
134 server: CopilotServer::Downloading,
135 }
136 }
137
138 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
139 if let CopilotServer::Started { server, .. } = &self.server {
140 let server = server.clone();
141 cx.spawn(|this, mut cx| async move {
142 let sign_in = server
143 .request::<request::SignInInitiate>(request::SignInInitiateParams {})
144 .await?;
145 if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
146 this.update(&mut cx, |_, cx| {
147 cx.emit(Event::PromptUserDeviceFlow {
148 user_code: flow.user_code.clone(),
149 verification_uri: flow.verification_uri,
150 });
151 });
152 let response = server
153 .request::<request::SignInConfirm>(request::SignInConfirmParams {
154 user_code: flow.user_code,
155 })
156 .await?;
157 this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
158 }
159 anyhow::Ok(())
160 })
161 } else {
162 Task::ready(Err(anyhow!("copilot hasn't started yet")))
163 }
164 }
165
166 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
167 if let CopilotServer::Started { server, .. } = &self.server {
168 let server = server.clone();
169 cx.spawn(|this, mut cx| async move {
170 server
171 .request::<request::SignOut>(request::SignOutParams {})
172 .await?;
173 this.update(&mut cx, |this, cx| {
174 if let CopilotServer::Started { status, .. } = &mut this.server {
175 *status = SignInStatus::SignedOut;
176 cx.notify();
177 }
178 });
179
180 anyhow::Ok(())
181 })
182 } else {
183 Task::ready(Err(anyhow!("copilot hasn't started yet")))
184 }
185 }
186
187 pub fn completions<T>(
188 &self,
189 buffer: &ModelHandle<Buffer>,
190 position: T,
191 cx: &mut ModelContext<Self>,
192 ) -> Task<Result<()>>
193 where
194 T: ToPointUtf16,
195 {
196 let server = match self.authenticated_server() {
197 Ok(server) => server,
198 Err(error) => return Task::ready(Err(error)),
199 };
200
201 cx.spawn(|this, cx| async move { anyhow::Ok(()) })
202 }
203
204 pub fn status(&self) -> Status {
205 match &self.server {
206 CopilotServer::Downloading => Status::Downloading,
207 CopilotServer::Error(error) => Status::Error(error.clone()),
208 CopilotServer::Started { status, .. } => match status {
209 SignInStatus::Authorized { .. } => Status::Authorized,
210 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
211 SignInStatus::SignedOut => Status::SignedOut,
212 },
213 }
214 }
215
216 fn update_sign_in_status(
217 &mut self,
218 lsp_status: request::SignInStatus,
219 cx: &mut ModelContext<Self>,
220 ) {
221 if let CopilotServer::Started { status, .. } = &mut self.server {
222 *status = match lsp_status {
223 request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
224 SignInStatus::Authorized { user }
225 }
226 request::SignInStatus::NotAuthorized { user } => {
227 SignInStatus::Unauthorized { user }
228 }
229 _ => SignInStatus::SignedOut,
230 };
231 cx.notify();
232 }
233 }
234
235 fn authenticated_server(&self) -> Result<Arc<LanguageServer>> {
236 match &self.server {
237 CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")),
238 CopilotServer::Error(error) => Err(anyhow!(
239 "copilot was not started because of an error: {}",
240 error
241 )),
242 CopilotServer::Started { server, status } => {
243 if matches!(status, SignInStatus::Authorized { .. }) {
244 Ok(server.clone())
245 } else {
246 Err(anyhow!("must sign in before using copilot"))
247 }
248 }
249 }
250 }
251}
252
253async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
254 ///Check for the latest copilot language server and download it if we haven't already
255 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
256 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
257 let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
258 let asset = release
259 .assets
260 .iter()
261 .find(|asset| asset.name == asset_name)
262 .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
263
264 fs::create_dir_all(&*paths::COPILOT_DIR).await?;
265 let destination_path =
266 paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
267
268 if fs::metadata(&destination_path).await.is_err() {
269 let mut response = http
270 .get(&asset.browser_download_url, Default::default(), true)
271 .await
272 .map_err(|err| anyhow!("error downloading release: {}", err))?;
273 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
274 let mut file = fs::File::create(&destination_path).await?;
275 futures::io::copy(decompressed_bytes, &mut file).await?;
276 fs::set_permissions(
277 &destination_path,
278 <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
279 )
280 .await?;
281
282 remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
283 }
284
285 Ok(destination_path)
286 }
287
288 match fetch_latest(http).await {
289 ok @ Result::Ok(..) => ok,
290 e @ Err(..) => {
291 e.log_err();
292 // Fetch a cached binary, if it exists
293 (|| async move {
294 let mut last = None;
295 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
296 while let Some(entry) = entries.next().await {
297 last = Some(entry?.path());
298 }
299 last.ok_or_else(|| anyhow!("no cached binary"))
300 })()
301 .await
302 }
303 }
304}