1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4};
5
6use anyhow::Context as _;
7use collections::HashMap;
8use fs::Fs;
9use futures::StreamExt as _;
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
11use language::{
12 proto::{deserialize_anchor, serialize_anchor},
13 ContextProvider as _, LanguageToolchainStore, Location,
14};
15use rpc::{proto, AnyProtoClient, TypedEnvelope};
16use settings::{watch_config_file, SettingsLocation, TaskKind};
17use task::{TaskContext, TaskVariables, VariableName};
18use text::{BufferId, OffsetRangeExt};
19use util::ResultExt;
20
21use crate::{
22 buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
23 ProjectEnvironment,
24};
25
26#[allow(clippy::large_enum_variant)] // platform-dependent warning
27pub enum TaskStore {
28 Functional(StoreState),
29 Noop,
30}
31
32pub struct StoreState {
33 mode: StoreMode,
34 task_inventory: Entity<Inventory>,
35 buffer_store: WeakEntity<BufferStore>,
36 worktree_store: Entity<WorktreeStore>,
37 toolchain_store: Arc<dyn LanguageToolchainStore>,
38 _global_task_config_watchers: (Task<()>, Task<()>),
39}
40
41enum StoreMode {
42 Local {
43 downstream_client: Option<(AnyProtoClient, u64)>,
44 environment: Entity<ProjectEnvironment>,
45 },
46 Remote {
47 upstream_client: AnyProtoClient,
48 project_id: u64,
49 },
50}
51
52impl EventEmitter<crate::Event> for TaskStore {}
53
54#[derive(Debug)]
55pub enum TaskSettingsLocation<'a> {
56 Global(&'a Path),
57 Worktree(SettingsLocation<'a>),
58}
59
60impl TaskStore {
61 pub fn init(client: Option<&AnyProtoClient>) {
62 if let Some(client) = client {
63 client.add_entity_request_handler(Self::handle_task_context_for_location);
64 }
65 }
66
67 async fn handle_task_context_for_location(
68 store: Entity<Self>,
69 envelope: TypedEnvelope<proto::TaskContextForLocation>,
70 mut cx: AsyncApp,
71 ) -> anyhow::Result<proto::TaskContext> {
72 let location = envelope
73 .payload
74 .location
75 .context("no location given for task context handling")?;
76 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
77 Ok(match store {
78 TaskStore::Functional(state) => (
79 state.buffer_store.clone(),
80 match &state.mode {
81 StoreMode::Local { .. } => false,
82 StoreMode::Remote { .. } => true,
83 },
84 ),
85 TaskStore::Noop => {
86 anyhow::bail!("empty task store cannot handle task context requests")
87 }
88 })
89 })??;
90 let buffer_store = buffer_store
91 .upgrade()
92 .context("no buffer store when handling task context request")?;
93
94 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
95 format!(
96 "cannot handle task context request for invalid buffer id: {}",
97 location.buffer_id
98 )
99 })?;
100
101 let start = location
102 .start
103 .and_then(deserialize_anchor)
104 .context("missing task context location start")?;
105 let end = location
106 .end
107 .and_then(deserialize_anchor)
108 .context("missing task context location end")?;
109 let buffer = buffer_store
110 .update(&mut cx, |buffer_store, cx| {
111 if is_remote {
112 buffer_store.wait_for_remote_buffer(buffer_id, cx)
113 } else {
114 Task::ready(
115 buffer_store
116 .get(buffer_id)
117 .with_context(|| format!("no local buffer with id {buffer_id}")),
118 )
119 }
120 })?
121 .await?;
122
123 let location = Location {
124 buffer,
125 range: start..end,
126 };
127 let context_task = store.update(&mut cx, |store, cx| {
128 let captured_variables = {
129 let mut variables = TaskVariables::from_iter(
130 envelope
131 .payload
132 .task_variables
133 .into_iter()
134 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
135 );
136
137 let snapshot = location.buffer.read(cx).snapshot();
138 let range = location.range.to_offset(&snapshot);
139
140 for range in snapshot.runnable_ranges(range) {
141 for (capture_name, value) in range.extra_captures {
142 variables.insert(VariableName::Custom(capture_name.into()), value);
143 }
144 }
145 variables
146 };
147 store.task_context_for_location(captured_variables, location, cx)
148 })?;
149 let task_context = context_task.await.unwrap_or_default();
150 Ok(proto::TaskContext {
151 project_env: task_context.project_env.into_iter().collect(),
152 cwd: task_context
153 .cwd
154 .map(|cwd| cwd.to_string_lossy().to_string()),
155 task_variables: task_context
156 .task_variables
157 .into_iter()
158 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
159 .collect(),
160 })
161 }
162
163 pub fn local(
164 fs: Arc<dyn Fs>,
165 buffer_store: WeakEntity<BufferStore>,
166 worktree_store: Entity<WorktreeStore>,
167 toolchain_store: Arc<dyn LanguageToolchainStore>,
168 environment: Entity<ProjectEnvironment>,
169 cx: &mut Context<'_, Self>,
170 ) -> Self {
171 Self::Functional(StoreState {
172 mode: StoreMode::Local {
173 downstream_client: None,
174 environment,
175 },
176 task_inventory: Inventory::new(cx),
177 buffer_store,
178 toolchain_store,
179 worktree_store,
180 _global_task_config_watchers: (
181 Self::subscribe_to_global_task_file_changes(
182 fs.clone(),
183 TaskKind::Script,
184 paths::tasks_file().clone(),
185 cx,
186 ),
187 Self::subscribe_to_global_task_file_changes(
188 fs.clone(),
189 TaskKind::Debug,
190 paths::debug_tasks_file().clone(),
191 cx,
192 ),
193 ),
194 })
195 }
196
197 pub fn remote(
198 fs: Arc<dyn Fs>,
199 buffer_store: WeakEntity<BufferStore>,
200 worktree_store: Entity<WorktreeStore>,
201 toolchain_store: Arc<dyn LanguageToolchainStore>,
202 upstream_client: AnyProtoClient,
203 project_id: u64,
204 cx: &mut Context<'_, Self>,
205 ) -> Self {
206 Self::Functional(StoreState {
207 mode: StoreMode::Remote {
208 upstream_client,
209 project_id,
210 },
211 task_inventory: Inventory::new(cx),
212 buffer_store,
213 toolchain_store,
214 worktree_store,
215 _global_task_config_watchers: (
216 Self::subscribe_to_global_task_file_changes(
217 fs.clone(),
218 TaskKind::Script,
219 paths::tasks_file().clone(),
220 cx,
221 ),
222 Self::subscribe_to_global_task_file_changes(
223 fs.clone(),
224 TaskKind::Debug,
225 paths::debug_tasks_file().clone(),
226 cx,
227 ),
228 ),
229 })
230 }
231
232 pub fn task_context_for_location(
233 &self,
234 captured_variables: TaskVariables,
235 location: Location,
236 cx: &mut App,
237 ) -> Task<Option<TaskContext>> {
238 match self {
239 TaskStore::Functional(state) => match &state.mode {
240 StoreMode::Local { environment, .. } => local_task_context_for_location(
241 state.worktree_store.clone(),
242 state.toolchain_store.clone(),
243 environment.clone(),
244 captured_variables,
245 location,
246 cx,
247 ),
248 StoreMode::Remote {
249 upstream_client,
250 project_id,
251 } => remote_task_context_for_location(
252 *project_id,
253 upstream_client.clone(),
254 state.worktree_store.clone(),
255 captured_variables,
256 location,
257 state.toolchain_store.clone(),
258 cx,
259 ),
260 },
261 TaskStore::Noop => Task::ready(None),
262 }
263 }
264
265 pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
266 match self {
267 TaskStore::Functional(state) => Some(&state.task_inventory),
268 TaskStore::Noop => None,
269 }
270 }
271
272 pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
273 if let Self::Functional(StoreState {
274 mode: StoreMode::Local {
275 downstream_client, ..
276 },
277 ..
278 }) = self
279 {
280 *downstream_client = Some((new_downstream_client, remote_id));
281 }
282 }
283
284 pub fn unshared(&mut self, _: &mut Context<Self>) {
285 if let Self::Functional(StoreState {
286 mode: StoreMode::Local {
287 downstream_client, ..
288 },
289 ..
290 }) = self
291 {
292 *downstream_client = None;
293 }
294 }
295
296 pub(super) fn update_user_tasks(
297 &self,
298 location: TaskSettingsLocation<'_>,
299 raw_tasks_json: Option<&str>,
300 task_type: TaskKind,
301 cx: &mut Context<'_, Self>,
302 ) -> anyhow::Result<()> {
303 let task_inventory = match self {
304 TaskStore::Functional(state) => &state.task_inventory,
305 TaskStore::Noop => return Ok(()),
306 };
307 let raw_tasks_json = raw_tasks_json
308 .map(|json| json.trim())
309 .filter(|json| !json.is_empty());
310
311 task_inventory.update(cx, |inventory, _| {
312 inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
313 })
314 }
315
316 fn subscribe_to_global_task_file_changes(
317 fs: Arc<dyn Fs>,
318 task_kind: TaskKind,
319 file_path: PathBuf,
320 cx: &mut Context<'_, Self>,
321 ) -> Task<()> {
322 let mut user_tasks_file_rx =
323 watch_config_file(&cx.background_executor(), fs, file_path.clone());
324 let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
325 cx.spawn(async move |task_store, cx| {
326 if let Some(user_tasks_content) = user_tasks_content {
327 let Ok(_) = task_store.update(cx, |task_store, cx| {
328 task_store
329 .update_user_tasks(
330 TaskSettingsLocation::Global(&file_path),
331 Some(&user_tasks_content),
332 task_kind,
333 cx,
334 )
335 .log_err();
336 }) else {
337 return;
338 };
339 }
340 while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
341 let Ok(()) = task_store.update(cx, |task_store, cx| {
342 let result = task_store.update_user_tasks(
343 TaskSettingsLocation::Global(&file_path),
344 Some(&user_tasks_content),
345 task_kind,
346 cx,
347 );
348 if let Err(err) = &result {
349 log::error!("Failed to load user {:?} tasks: {err}", task_kind);
350 cx.emit(crate::Event::Toast {
351 notification_id: format!("load-user-{:?}-tasks", task_kind).into(),
352 message: format!("Invalid global {:?} tasks file\n{err}", task_kind),
353 });
354 }
355 cx.refresh_windows();
356 }) else {
357 break; // App dropped
358 };
359 }
360 })
361 }
362}
363
364fn local_task_context_for_location(
365 worktree_store: Entity<WorktreeStore>,
366 toolchain_store: Arc<dyn LanguageToolchainStore>,
367 environment: Entity<ProjectEnvironment>,
368 captured_variables: TaskVariables,
369 location: Location,
370 cx: &App,
371) -> Task<Option<TaskContext>> {
372 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
373 let worktree_abs_path = worktree_id
374 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
375 .and_then(|worktree| worktree.read(cx).root_dir());
376
377 cx.spawn(async move |cx| {
378 let worktree_abs_path = worktree_abs_path.clone();
379 let project_env = environment
380 .update(cx, |environment, cx| {
381 environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
382 })
383 .ok()?
384 .await;
385
386 let mut task_variables = cx
387 .update(|cx| {
388 combine_task_variables(
389 captured_variables,
390 location,
391 project_env.clone(),
392 BasicContextProvider::new(worktree_store),
393 toolchain_store,
394 cx,
395 )
396 })
397 .ok()?
398 .await
399 .log_err()?;
400 // Remove all custom entries starting with _, as they're not intended for use by the end user.
401 task_variables.sweep();
402
403 Some(TaskContext {
404 project_env: project_env.unwrap_or_default(),
405 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
406 task_variables,
407 })
408 })
409}
410
411fn remote_task_context_for_location(
412 project_id: u64,
413 upstream_client: AnyProtoClient,
414 worktree_store: Entity<WorktreeStore>,
415 captured_variables: TaskVariables,
416 location: Location,
417 toolchain_store: Arc<dyn LanguageToolchainStore>,
418 cx: &mut App,
419) -> Task<Option<TaskContext>> {
420 cx.spawn(async move |cx| {
421 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
422 let mut remote_context = cx
423 .update(|cx| {
424 BasicContextProvider::new(worktree_store).build_context(
425 &TaskVariables::default(),
426 &location,
427 None,
428 toolchain_store,
429 cx,
430 )
431 })
432 .ok()?
433 .await
434 .log_err()
435 .unwrap_or_default();
436 remote_context.extend(captured_variables);
437
438 let buffer_id = cx
439 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
440 .ok()?;
441 let context_task = upstream_client.request(proto::TaskContextForLocation {
442 project_id,
443 location: Some(proto::Location {
444 buffer_id,
445 start: Some(serialize_anchor(&location.range.start)),
446 end: Some(serialize_anchor(&location.range.end)),
447 }),
448 task_variables: remote_context
449 .into_iter()
450 .map(|(k, v)| (k.to_string(), v))
451 .collect(),
452 });
453 let task_context = context_task.await.log_err()?;
454 Some(TaskContext {
455 cwd: task_context.cwd.map(PathBuf::from),
456 task_variables: task_context
457 .task_variables
458 .into_iter()
459 .filter_map(
460 |(variable_name, variable_value)| match variable_name.parse() {
461 Ok(variable_name) => Some((variable_name, variable_value)),
462 Err(()) => {
463 log::error!("Unknown variable name: {variable_name}");
464 None
465 }
466 },
467 )
468 .collect(),
469 project_env: task_context.project_env.into_iter().collect(),
470 })
471 })
472}
473
474fn combine_task_variables(
475 mut captured_variables: TaskVariables,
476 location: Location,
477 project_env: Option<HashMap<String, String>>,
478 baseline: BasicContextProvider,
479 toolchain_store: Arc<dyn LanguageToolchainStore>,
480 cx: &mut App,
481) -> Task<anyhow::Result<TaskVariables>> {
482 let language_context_provider = location
483 .buffer
484 .read(cx)
485 .language()
486 .and_then(|language| language.context_provider());
487 cx.spawn(async move |cx| {
488 let baseline = cx
489 .update(|cx| {
490 baseline.build_context(
491 &captured_variables,
492 &location,
493 project_env.clone(),
494 toolchain_store.clone(),
495 cx,
496 )
497 })?
498 .await
499 .context("building basic default context")?;
500 captured_variables.extend(baseline);
501 if let Some(provider) = language_context_provider {
502 captured_variables.extend(
503 cx.update(|cx| {
504 provider.build_context(
505 &captured_variables,
506 &location,
507 project_env,
508 toolchain_store,
509 cx,
510 )
511 })?
512 .await
513 .context("building provider context")?,
514 );
515 }
516 Ok(captured_variables)
517 })
518}