task_store.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::Context as _;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt as _;
  7use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  8use language::{
  9    proto::{deserialize_anchor, serialize_anchor},
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11};
 12use rpc::{proto, AnyProtoClient, TypedEnvelope};
 13use settings::{watch_config_file, SettingsLocation, TaskKind};
 14use task::{TaskContext, TaskVariables, VariableName};
 15use text::{BufferId, OffsetRangeExt};
 16use util::ResultExt;
 17
 18use crate::{
 19    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 20    ProjectEnvironment,
 21};
 22
 23#[allow(clippy::large_enum_variant)] // platform-dependent warning
 24pub enum TaskStore {
 25    Functional(StoreState),
 26    Noop,
 27}
 28
 29pub struct StoreState {
 30    mode: StoreMode,
 31    task_inventory: Entity<Inventory>,
 32    buffer_store: WeakEntity<BufferStore>,
 33    worktree_store: Entity<WorktreeStore>,
 34    toolchain_store: Arc<dyn LanguageToolchainStore>,
 35    _global_task_config_watchers: (Task<()>, Task<()>),
 36}
 37
 38enum StoreMode {
 39    Local {
 40        downstream_client: Option<(AnyProtoClient, u64)>,
 41        environment: Entity<ProjectEnvironment>,
 42    },
 43    Remote {
 44        upstream_client: AnyProtoClient,
 45        project_id: u64,
 46    },
 47}
 48
 49impl EventEmitter<crate::Event> for TaskStore {}
 50
 51impl TaskStore {
 52    pub fn init(client: Option<&AnyProtoClient>) {
 53        if let Some(client) = client {
 54            client.add_entity_request_handler(Self::handle_task_context_for_location);
 55        }
 56    }
 57
 58    async fn handle_task_context_for_location(
 59        store: Entity<Self>,
 60        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 61        mut cx: AsyncApp,
 62    ) -> anyhow::Result<proto::TaskContext> {
 63        let location = envelope
 64            .payload
 65            .location
 66            .context("no location given for task context handling")?;
 67        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 68            Ok(match store {
 69                TaskStore::Functional(state) => (
 70                    state.buffer_store.clone(),
 71                    match &state.mode {
 72                        StoreMode::Local { .. } => false,
 73                        StoreMode::Remote { .. } => true,
 74                    },
 75                ),
 76                TaskStore::Noop => {
 77                    anyhow::bail!("empty task store cannot handle task context requests")
 78                }
 79            })
 80        })??;
 81        let buffer_store = buffer_store
 82            .upgrade()
 83            .context("no buffer store when handling task context request")?;
 84
 85        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 86            format!(
 87                "cannot handle task context request for invalid buffer id: {}",
 88                location.buffer_id
 89            )
 90        })?;
 91
 92        let start = location
 93            .start
 94            .and_then(deserialize_anchor)
 95            .context("missing task context location start")?;
 96        let end = location
 97            .end
 98            .and_then(deserialize_anchor)
 99            .context("missing task context location end")?;
100        let buffer = buffer_store
101            .update(&mut cx, |buffer_store, cx| {
102                if is_remote {
103                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
104                } else {
105                    Task::ready(
106                        buffer_store
107                            .get(buffer_id)
108                            .with_context(|| format!("no local buffer with id {buffer_id}")),
109                    )
110                }
111            })?
112            .await?;
113
114        let location = Location {
115            buffer,
116            range: start..end,
117        };
118        let context_task = store.update(&mut cx, |store, cx| {
119            let captured_variables = {
120                let mut variables = TaskVariables::from_iter(
121                    envelope
122                        .payload
123                        .task_variables
124                        .into_iter()
125                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126                );
127
128                let snapshot = location.buffer.read(cx).snapshot();
129                let range = location.range.to_offset(&snapshot);
130
131                for range in snapshot.runnable_ranges(range) {
132                    for (capture_name, value) in range.extra_captures {
133                        variables.insert(VariableName::Custom(capture_name.into()), value);
134                    }
135                }
136                variables
137            };
138            store.task_context_for_location(captured_variables, location, cx)
139        })?;
140        let task_context = context_task.await.unwrap_or_default();
141        Ok(proto::TaskContext {
142            project_env: task_context.project_env.into_iter().collect(),
143            cwd: task_context
144                .cwd
145                .map(|cwd| cwd.to_string_lossy().to_string()),
146            task_variables: task_context
147                .task_variables
148                .into_iter()
149                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150                .collect(),
151        })
152    }
153
154    pub fn local(
155        fs: Arc<dyn Fs>,
156        buffer_store: WeakEntity<BufferStore>,
157        worktree_store: Entity<WorktreeStore>,
158        toolchain_store: Arc<dyn LanguageToolchainStore>,
159        environment: Entity<ProjectEnvironment>,
160        cx: &mut Context<'_, Self>,
161    ) -> Self {
162        Self::Functional(StoreState {
163            mode: StoreMode::Local {
164                downstream_client: None,
165                environment,
166            },
167            task_inventory: Inventory::new(cx),
168            buffer_store,
169            toolchain_store,
170            worktree_store,
171            _global_task_config_watchers: (
172                Self::subscribe_to_global_task_file_changes(
173                    fs.clone(),
174                    TaskKind::Script,
175                    paths::tasks_file().clone(),
176                    cx,
177                ),
178                Self::subscribe_to_global_task_file_changes(
179                    fs.clone(),
180                    TaskKind::Debug,
181                    paths::debug_tasks_file().clone(),
182                    cx,
183                ),
184            ),
185        })
186    }
187
188    pub fn remote(
189        fs: Arc<dyn Fs>,
190        buffer_store: WeakEntity<BufferStore>,
191        worktree_store: Entity<WorktreeStore>,
192        toolchain_store: Arc<dyn LanguageToolchainStore>,
193        upstream_client: AnyProtoClient,
194        project_id: u64,
195        cx: &mut Context<'_, Self>,
196    ) -> Self {
197        Self::Functional(StoreState {
198            mode: StoreMode::Remote {
199                upstream_client,
200                project_id,
201            },
202            task_inventory: Inventory::new(cx),
203            buffer_store,
204            toolchain_store,
205            worktree_store,
206            _global_task_config_watchers: (
207                Self::subscribe_to_global_task_file_changes(
208                    fs.clone(),
209                    TaskKind::Script,
210                    paths::tasks_file().clone(),
211                    cx,
212                ),
213                Self::subscribe_to_global_task_file_changes(
214                    fs.clone(),
215                    TaskKind::Debug,
216                    paths::debug_tasks_file().clone(),
217                    cx,
218                ),
219            ),
220        })
221    }
222
223    pub fn task_context_for_location(
224        &self,
225        captured_variables: TaskVariables,
226        location: Location,
227        cx: &mut App,
228    ) -> Task<Option<TaskContext>> {
229        match self {
230            TaskStore::Functional(state) => match &state.mode {
231                StoreMode::Local { environment, .. } => local_task_context_for_location(
232                    state.worktree_store.clone(),
233                    state.toolchain_store.clone(),
234                    environment.clone(),
235                    captured_variables,
236                    location,
237                    cx,
238                ),
239                StoreMode::Remote {
240                    upstream_client,
241                    project_id,
242                } => remote_task_context_for_location(
243                    *project_id,
244                    upstream_client.clone(),
245                    state.worktree_store.clone(),
246                    captured_variables,
247                    location,
248                    state.toolchain_store.clone(),
249                    cx,
250                ),
251            },
252            TaskStore::Noop => Task::ready(None),
253        }
254    }
255
256    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
257        match self {
258            TaskStore::Functional(state) => Some(&state.task_inventory),
259            TaskStore::Noop => None,
260        }
261    }
262
263    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
264        if let Self::Functional(StoreState {
265            mode: StoreMode::Local {
266                downstream_client, ..
267            },
268            ..
269        }) = self
270        {
271            *downstream_client = Some((new_downstream_client, remote_id));
272        }
273    }
274
275    pub fn unshared(&mut self, _: &mut Context<Self>) {
276        if let Self::Functional(StoreState {
277            mode: StoreMode::Local {
278                downstream_client, ..
279            },
280            ..
281        }) = self
282        {
283            *downstream_client = None;
284        }
285    }
286
287    pub(super) fn update_user_tasks(
288        &self,
289        location: Option<SettingsLocation<'_>>,
290        raw_tasks_json: Option<&str>,
291        task_type: TaskKind,
292        cx: &mut Context<'_, Self>,
293    ) -> anyhow::Result<()> {
294        let task_inventory = match self {
295            TaskStore::Functional(state) => &state.task_inventory,
296            TaskStore::Noop => return Ok(()),
297        };
298        let raw_tasks_json = raw_tasks_json
299            .map(|json| json.trim())
300            .filter(|json| !json.is_empty());
301
302        task_inventory.update(cx, |inventory, _| {
303            inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
304        })
305    }
306
307    fn subscribe_to_global_task_file_changes(
308        fs: Arc<dyn Fs>,
309        task_kind: TaskKind,
310        file_path: PathBuf,
311        cx: &mut Context<'_, Self>,
312    ) -> Task<()> {
313        let mut user_tasks_file_rx = watch_config_file(&cx.background_executor(), fs, file_path);
314        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
315        cx.spawn(move |task_store, mut cx| async move {
316            if let Some(user_tasks_content) = user_tasks_content {
317                let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
318                    task_store
319                        .update_user_tasks(None, Some(&user_tasks_content), task_kind, cx)
320                        .log_err();
321                }) else {
322                    return;
323                };
324            }
325            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
326                let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
327                    let result = task_store.update_user_tasks(
328                        None,
329                        Some(&user_tasks_content),
330                        task_kind,
331                        cx,
332                    );
333                    if let Err(err) = &result {
334                        log::error!("Failed to load user {:?} tasks: {err}", task_kind);
335                        cx.emit(crate::Event::Toast {
336                            notification_id: format!("load-user-{:?}-tasks", task_kind).into(),
337                            message: format!("Invalid global {:?} tasks file\n{err}", task_kind),
338                        });
339                    }
340                    cx.refresh_windows();
341                }) else {
342                    break; // App dropped
343                };
344            }
345        })
346    }
347}
348
349fn local_task_context_for_location(
350    worktree_store: Entity<WorktreeStore>,
351    toolchain_store: Arc<dyn LanguageToolchainStore>,
352    environment: Entity<ProjectEnvironment>,
353    captured_variables: TaskVariables,
354    location: Location,
355    cx: &App,
356) -> Task<Option<TaskContext>> {
357    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
358    let worktree_abs_path = worktree_id
359        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
360        .and_then(|worktree| worktree.read(cx).root_dir());
361
362    cx.spawn(|mut cx| async move {
363        let worktree_abs_path = worktree_abs_path.clone();
364        let project_env = environment
365            .update(&mut cx, |environment, cx| {
366                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
367            })
368            .ok()?
369            .await;
370
371        let mut task_variables = cx
372            .update(|cx| {
373                combine_task_variables(
374                    captured_variables,
375                    location,
376                    project_env.clone(),
377                    BasicContextProvider::new(worktree_store),
378                    toolchain_store,
379                    cx,
380                )
381            })
382            .ok()?
383            .await
384            .log_err()?;
385        // Remove all custom entries starting with _, as they're not intended for use by the end user.
386        task_variables.sweep();
387
388        Some(TaskContext {
389            project_env: project_env.unwrap_or_default(),
390            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
391            task_variables,
392        })
393    })
394}
395
396fn remote_task_context_for_location(
397    project_id: u64,
398    upstream_client: AnyProtoClient,
399    worktree_store: Entity<WorktreeStore>,
400    captured_variables: TaskVariables,
401    location: Location,
402    toolchain_store: Arc<dyn LanguageToolchainStore>,
403    cx: &mut App,
404) -> Task<Option<TaskContext>> {
405    cx.spawn(|cx| async move {
406        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
407        let mut remote_context = cx
408            .update(|cx| {
409                BasicContextProvider::new(worktree_store).build_context(
410                    &TaskVariables::default(),
411                    &location,
412                    None,
413                    toolchain_store,
414                    cx,
415                )
416            })
417            .ok()?
418            .await
419            .log_err()
420            .unwrap_or_default();
421        remote_context.extend(captured_variables);
422
423        let buffer_id = cx
424            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
425            .ok()?;
426        let context_task = upstream_client.request(proto::TaskContextForLocation {
427            project_id,
428            location: Some(proto::Location {
429                buffer_id,
430                start: Some(serialize_anchor(&location.range.start)),
431                end: Some(serialize_anchor(&location.range.end)),
432            }),
433            task_variables: remote_context
434                .into_iter()
435                .map(|(k, v)| (k.to_string(), v))
436                .collect(),
437        });
438        let task_context = context_task.await.log_err()?;
439        Some(TaskContext {
440            cwd: task_context.cwd.map(PathBuf::from),
441            task_variables: task_context
442                .task_variables
443                .into_iter()
444                .filter_map(
445                    |(variable_name, variable_value)| match variable_name.parse() {
446                        Ok(variable_name) => Some((variable_name, variable_value)),
447                        Err(()) => {
448                            log::error!("Unknown variable name: {variable_name}");
449                            None
450                        }
451                    },
452                )
453                .collect(),
454            project_env: task_context.project_env.into_iter().collect(),
455        })
456    })
457}
458
459fn combine_task_variables(
460    mut captured_variables: TaskVariables,
461    location: Location,
462    project_env: Option<HashMap<String, String>>,
463    baseline: BasicContextProvider,
464    toolchain_store: Arc<dyn LanguageToolchainStore>,
465    cx: &mut App,
466) -> Task<anyhow::Result<TaskVariables>> {
467    let language_context_provider = location
468        .buffer
469        .read(cx)
470        .language()
471        .and_then(|language| language.context_provider());
472    cx.spawn(move |cx| async move {
473        let baseline = cx
474            .update(|cx| {
475                baseline.build_context(
476                    &captured_variables,
477                    &location,
478                    project_env.clone(),
479                    toolchain_store.clone(),
480                    cx,
481                )
482            })?
483            .await
484            .context("building basic default context")?;
485        captured_variables.extend(baseline);
486        if let Some(provider) = language_context_provider {
487            captured_variables.extend(
488                cx.update(|cx| {
489                    provider.build_context(
490                        &captured_variables,
491                        &location,
492                        project_env,
493                        toolchain_store,
494                        cx,
495                    )
496                })?
497                .await
498                .context("building provider context")?,
499            );
500        }
501        Ok(captured_variables)
502    })
503}