task_store.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4};
  5
  6use anyhow::Context as _;
  7use collections::HashMap;
  8use fs::Fs;
  9use futures::StreamExt as _;
 10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
 11use language::{
 12    proto::{deserialize_anchor, serialize_anchor},
 13    ContextProvider as _, LanguageToolchainStore, Location,
 14};
 15use rpc::{proto, AnyProtoClient, TypedEnvelope};
 16use settings::{watch_config_file, SettingsLocation, TaskKind};
 17use task::{TaskContext, TaskVariables, VariableName};
 18use text::{BufferId, OffsetRangeExt};
 19use util::ResultExt;
 20
 21use crate::{
 22    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 23    ProjectEnvironment,
 24};
 25
/// Store that owns the task inventory and resolves task contexts.
///
/// `Functional` carries the real state. `Noop` is an empty store: it exposes no
/// inventory, yields no task contexts, ignores task file updates, and rejects
/// incoming task context requests.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by [`StoreState`].
    Functional(StoreState),
    /// An empty store that does nothing.
    Noop,
}
 31
/// State held by a [`TaskStore::Functional`] store.
pub struct StoreState {
    /// Whether task contexts are computed locally or requested from an upstream peer.
    mode: StoreMode,
    /// Inventory of known tasks, including those loaded from task files.
    task_inventory: Entity<Inventory>,
    /// Buffer store used to resolve buffers named in incoming requests (held weakly).
    buffer_store: WeakEntity<BufferStore>,
    /// Worktree store used to locate the worktree of a task's buffer.
    worktree_store: Entity<WorktreeStore>,
    /// Toolchain store passed through to context providers.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
    /// Watchers for the global tasks file and the global debug tasks file.
    /// They are stored here so they live as long as the store does
    /// (presumably dropping them stops the watching — confirm against gpui `Task`).
    _global_task_config_watchers: (Task<()>, Task<()>),
}
 40
/// How a functional [`TaskStore`] obtains task contexts.
enum StoreMode {
    /// Task contexts are computed on this machine.
    Local {
        /// Set by `TaskStore::shared` and cleared by `TaskStore::unshared`:
        /// the downstream client together with the remote project id.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Source of the project environment used in locally computed contexts.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from `upstream_client` for `project_id`.
    Remote {
        upstream_client: AnyProtoClient,
        project_id: u64,
    },
}
 51
// The store emits project events, e.g. a `Toast` when a global task file fails to load.
impl EventEmitter<crate::Event> for TaskStore {}
 53
/// Where a task definitions file lives.
#[derive(Debug)]
pub enum TaskSettingsLocation<'a> {
    /// A global (user-level) task file at the given path.
    Global(&'a Path),
    /// A task file that belongs to a worktree.
    Worktree(SettingsLocation<'a>),
}
 59
 60impl TaskStore {
 61    pub fn init(client: Option<&AnyProtoClient>) {
 62        if let Some(client) = client {
 63            client.add_entity_request_handler(Self::handle_task_context_for_location);
 64        }
 65    }
 66
 67    async fn handle_task_context_for_location(
 68        store: Entity<Self>,
 69        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 70        mut cx: AsyncApp,
 71    ) -> anyhow::Result<proto::TaskContext> {
 72        let location = envelope
 73            .payload
 74            .location
 75            .context("no location given for task context handling")?;
 76        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 77            Ok(match store {
 78                TaskStore::Functional(state) => (
 79                    state.buffer_store.clone(),
 80                    match &state.mode {
 81                        StoreMode::Local { .. } => false,
 82                        StoreMode::Remote { .. } => true,
 83                    },
 84                ),
 85                TaskStore::Noop => {
 86                    anyhow::bail!("empty task store cannot handle task context requests")
 87                }
 88            })
 89        })??;
 90        let buffer_store = buffer_store
 91            .upgrade()
 92            .context("no buffer store when handling task context request")?;
 93
 94        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 95            format!(
 96                "cannot handle task context request for invalid buffer id: {}",
 97                location.buffer_id
 98            )
 99        })?;
100
101        let start = location
102            .start
103            .and_then(deserialize_anchor)
104            .context("missing task context location start")?;
105        let end = location
106            .end
107            .and_then(deserialize_anchor)
108            .context("missing task context location end")?;
109        let buffer = buffer_store
110            .update(&mut cx, |buffer_store, cx| {
111                if is_remote {
112                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
113                } else {
114                    Task::ready(
115                        buffer_store
116                            .get(buffer_id)
117                            .with_context(|| format!("no local buffer with id {buffer_id}")),
118                    )
119                }
120            })?
121            .await?;
122
123        let location = Location {
124            buffer,
125            range: start..end,
126        };
127        let context_task = store.update(&mut cx, |store, cx| {
128            let captured_variables = {
129                let mut variables = TaskVariables::from_iter(
130                    envelope
131                        .payload
132                        .task_variables
133                        .into_iter()
134                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
135                );
136
137                let snapshot = location.buffer.read(cx).snapshot();
138                let range = location.range.to_offset(&snapshot);
139
140                for range in snapshot.runnable_ranges(range) {
141                    for (capture_name, value) in range.extra_captures {
142                        variables.insert(VariableName::Custom(capture_name.into()), value);
143                    }
144                }
145                variables
146            };
147            store.task_context_for_location(captured_variables, location, cx)
148        })?;
149        let task_context = context_task.await.unwrap_or_default();
150        Ok(proto::TaskContext {
151            project_env: task_context.project_env.into_iter().collect(),
152            cwd: task_context
153                .cwd
154                .map(|cwd| cwd.to_string_lossy().to_string()),
155            task_variables: task_context
156                .task_variables
157                .into_iter()
158                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
159                .collect(),
160        })
161    }
162
163    pub fn local(
164        fs: Arc<dyn Fs>,
165        buffer_store: WeakEntity<BufferStore>,
166        worktree_store: Entity<WorktreeStore>,
167        toolchain_store: Arc<dyn LanguageToolchainStore>,
168        environment: Entity<ProjectEnvironment>,
169        cx: &mut Context<'_, Self>,
170    ) -> Self {
171        Self::Functional(StoreState {
172            mode: StoreMode::Local {
173                downstream_client: None,
174                environment,
175            },
176            task_inventory: Inventory::new(cx),
177            buffer_store,
178            toolchain_store,
179            worktree_store,
180            _global_task_config_watchers: (
181                Self::subscribe_to_global_task_file_changes(
182                    fs.clone(),
183                    TaskKind::Script,
184                    paths::tasks_file().clone(),
185                    cx,
186                ),
187                Self::subscribe_to_global_task_file_changes(
188                    fs.clone(),
189                    TaskKind::Debug,
190                    paths::debug_tasks_file().clone(),
191                    cx,
192                ),
193            ),
194        })
195    }
196
197    pub fn remote(
198        fs: Arc<dyn Fs>,
199        buffer_store: WeakEntity<BufferStore>,
200        worktree_store: Entity<WorktreeStore>,
201        toolchain_store: Arc<dyn LanguageToolchainStore>,
202        upstream_client: AnyProtoClient,
203        project_id: u64,
204        cx: &mut Context<'_, Self>,
205    ) -> Self {
206        Self::Functional(StoreState {
207            mode: StoreMode::Remote {
208                upstream_client,
209                project_id,
210            },
211            task_inventory: Inventory::new(cx),
212            buffer_store,
213            toolchain_store,
214            worktree_store,
215            _global_task_config_watchers: (
216                Self::subscribe_to_global_task_file_changes(
217                    fs.clone(),
218                    TaskKind::Script,
219                    paths::tasks_file().clone(),
220                    cx,
221                ),
222                Self::subscribe_to_global_task_file_changes(
223                    fs.clone(),
224                    TaskKind::Debug,
225                    paths::debug_tasks_file().clone(),
226                    cx,
227                ),
228            ),
229        })
230    }
231
232    pub fn task_context_for_location(
233        &self,
234        captured_variables: TaskVariables,
235        location: Location,
236        cx: &mut App,
237    ) -> Task<Option<TaskContext>> {
238        match self {
239            TaskStore::Functional(state) => match &state.mode {
240                StoreMode::Local { environment, .. } => local_task_context_for_location(
241                    state.worktree_store.clone(),
242                    state.toolchain_store.clone(),
243                    environment.clone(),
244                    captured_variables,
245                    location,
246                    cx,
247                ),
248                StoreMode::Remote {
249                    upstream_client,
250                    project_id,
251                } => remote_task_context_for_location(
252                    *project_id,
253                    upstream_client.clone(),
254                    state.worktree_store.clone(),
255                    captured_variables,
256                    location,
257                    state.toolchain_store.clone(),
258                    cx,
259                ),
260            },
261            TaskStore::Noop => Task::ready(None),
262        }
263    }
264
265    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
266        match self {
267            TaskStore::Functional(state) => Some(&state.task_inventory),
268            TaskStore::Noop => None,
269        }
270    }
271
272    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
273        if let Self::Functional(StoreState {
274            mode: StoreMode::Local {
275                downstream_client, ..
276            },
277            ..
278        }) = self
279        {
280            *downstream_client = Some((new_downstream_client, remote_id));
281        }
282    }
283
284    pub fn unshared(&mut self, _: &mut Context<Self>) {
285        if let Self::Functional(StoreState {
286            mode: StoreMode::Local {
287                downstream_client, ..
288            },
289            ..
290        }) = self
291        {
292            *downstream_client = None;
293        }
294    }
295
296    pub(super) fn update_user_tasks(
297        &self,
298        location: TaskSettingsLocation<'_>,
299        raw_tasks_json: Option<&str>,
300        task_type: TaskKind,
301        cx: &mut Context<'_, Self>,
302    ) -> anyhow::Result<()> {
303        let task_inventory = match self {
304            TaskStore::Functional(state) => &state.task_inventory,
305            TaskStore::Noop => return Ok(()),
306        };
307        let raw_tasks_json = raw_tasks_json
308            .map(|json| json.trim())
309            .filter(|json| !json.is_empty());
310
311        task_inventory.update(cx, |inventory, _| {
312            inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
313        })
314    }
315
316    fn subscribe_to_global_task_file_changes(
317        fs: Arc<dyn Fs>,
318        task_kind: TaskKind,
319        file_path: PathBuf,
320        cx: &mut Context<'_, Self>,
321    ) -> Task<()> {
322        let mut user_tasks_file_rx =
323            watch_config_file(&cx.background_executor(), fs, file_path.clone());
324        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
325        cx.spawn(async move |task_store, cx| {
326            if let Some(user_tasks_content) = user_tasks_content {
327                let Ok(_) = task_store.update(cx, |task_store, cx| {
328                    task_store
329                        .update_user_tasks(
330                            TaskSettingsLocation::Global(&file_path),
331                            Some(&user_tasks_content),
332                            task_kind,
333                            cx,
334                        )
335                        .log_err();
336                }) else {
337                    return;
338                };
339            }
340            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
341                let Ok(()) = task_store.update(cx, |task_store, cx| {
342                    let result = task_store.update_user_tasks(
343                        TaskSettingsLocation::Global(&file_path),
344                        Some(&user_tasks_content),
345                        task_kind,
346                        cx,
347                    );
348                    if let Err(err) = &result {
349                        log::error!("Failed to load user {:?} tasks: {err}", task_kind);
350                        cx.emit(crate::Event::Toast {
351                            notification_id: format!("load-user-{:?}-tasks", task_kind).into(),
352                            message: format!("Invalid global {:?} tasks file\n{err}", task_kind),
353                        });
354                    }
355                    cx.refresh_windows();
356                }) else {
357                    break; // App dropped
358                };
359            }
360        })
361    }
362}
363
364fn local_task_context_for_location(
365    worktree_store: Entity<WorktreeStore>,
366    toolchain_store: Arc<dyn LanguageToolchainStore>,
367    environment: Entity<ProjectEnvironment>,
368    captured_variables: TaskVariables,
369    location: Location,
370    cx: &App,
371) -> Task<Option<TaskContext>> {
372    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
373    let worktree_abs_path = worktree_id
374        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
375        .and_then(|worktree| worktree.read(cx).root_dir());
376
377    cx.spawn(async move |cx| {
378        let worktree_abs_path = worktree_abs_path.clone();
379        let project_env = environment
380            .update(cx, |environment, cx| {
381                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
382            })
383            .ok()?
384            .await;
385
386        let mut task_variables = cx
387            .update(|cx| {
388                combine_task_variables(
389                    captured_variables,
390                    location,
391                    project_env.clone(),
392                    BasicContextProvider::new(worktree_store),
393                    toolchain_store,
394                    cx,
395                )
396            })
397            .ok()?
398            .await
399            .log_err()?;
400        // Remove all custom entries starting with _, as they're not intended for use by the end user.
401        task_variables.sweep();
402
403        Some(TaskContext {
404            project_env: project_env.unwrap_or_default(),
405            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
406            task_variables,
407        })
408    })
409}
410
411fn remote_task_context_for_location(
412    project_id: u64,
413    upstream_client: AnyProtoClient,
414    worktree_store: Entity<WorktreeStore>,
415    captured_variables: TaskVariables,
416    location: Location,
417    toolchain_store: Arc<dyn LanguageToolchainStore>,
418    cx: &mut App,
419) -> Task<Option<TaskContext>> {
420    cx.spawn(async move |cx| {
421        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
422        let mut remote_context = cx
423            .update(|cx| {
424                BasicContextProvider::new(worktree_store).build_context(
425                    &TaskVariables::default(),
426                    &location,
427                    None,
428                    toolchain_store,
429                    cx,
430                )
431            })
432            .ok()?
433            .await
434            .log_err()
435            .unwrap_or_default();
436        remote_context.extend(captured_variables);
437
438        let buffer_id = cx
439            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
440            .ok()?;
441        let context_task = upstream_client.request(proto::TaskContextForLocation {
442            project_id,
443            location: Some(proto::Location {
444                buffer_id,
445                start: Some(serialize_anchor(&location.range.start)),
446                end: Some(serialize_anchor(&location.range.end)),
447            }),
448            task_variables: remote_context
449                .into_iter()
450                .map(|(k, v)| (k.to_string(), v))
451                .collect(),
452        });
453        let task_context = context_task.await.log_err()?;
454        Some(TaskContext {
455            cwd: task_context.cwd.map(PathBuf::from),
456            task_variables: task_context
457                .task_variables
458                .into_iter()
459                .filter_map(
460                    |(variable_name, variable_value)| match variable_name.parse() {
461                        Ok(variable_name) => Some((variable_name, variable_value)),
462                        Err(()) => {
463                            log::error!("Unknown variable name: {variable_name}");
464                            None
465                        }
466                    },
467                )
468                .collect(),
469            project_env: task_context.project_env.into_iter().collect(),
470        })
471    })
472}
473
474fn combine_task_variables(
475    mut captured_variables: TaskVariables,
476    location: Location,
477    project_env: Option<HashMap<String, String>>,
478    baseline: BasicContextProvider,
479    toolchain_store: Arc<dyn LanguageToolchainStore>,
480    cx: &mut App,
481) -> Task<anyhow::Result<TaskVariables>> {
482    let language_context_provider = location
483        .buffer
484        .read(cx)
485        .language()
486        .and_then(|language| language.context_provider());
487    cx.spawn(async move |cx| {
488        let baseline = cx
489            .update(|cx| {
490                baseline.build_context(
491                    &captured_variables,
492                    &location,
493                    project_env.clone(),
494                    toolchain_store.clone(),
495                    cx,
496                )
497            })?
498            .await
499            .context("building basic default context")?;
500        captured_variables.extend(baseline);
501        if let Some(provider) = language_context_provider {
502            captured_variables.extend(
503                cx.update(|cx| {
504                    provider.build_context(
505                        &captured_variables,
506                        &location,
507                        project_env,
508                        toolchain_store,
509                        cx,
510                    )
511                })?
512                .await
513                .context("building provider context")?,
514            );
515        }
516        Ok(captured_variables)
517    })
518}