task_store.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::Context as _;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt as _;
  7use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  8use language::{
  9    proto::{deserialize_anchor, serialize_anchor},
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11};
 12use rpc::{proto, AnyProtoClient, TypedEnvelope};
 13use settings::{watch_config_file, SettingsLocation};
 14use task::{TaskContext, TaskVariables, VariableName};
 15use text::{BufferId, OffsetRangeExt};
 16use util::ResultExt;
 17
 18use crate::{
 19    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 20    ProjectEnvironment,
 21};
 22
/// Owns the project's task inventory and computes task contexts for locations.
///
/// A [`TaskStore::Noop`] store is inert: it has no inventory, always yields
/// `None` for task contexts, and rejects task context requests from peers.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by [`StoreState`].
    Functional(StoreState),
    /// An inert store that provides no tasks or task contexts.
    Noop,
}
 28
/// State of a [`TaskStore::Functional`] store.
///
/// NOTE: field order is also drop order; keep it stable.
pub struct StoreState {
    /// Whether task contexts are computed locally or requested from upstream.
    mode: StoreMode,
    /// Inventory of known tasks, including those read from the global tasks file.
    task_inventory: Entity<Inventory>,
    /// Held weakly; upgraded when answering a peer's task context request
    /// (the request fails if the buffer store is gone).
    buffer_store: WeakEntity<BufferStore>,
    /// Used to find the worktree (and its root directory) owning a buffer.
    worktree_store: Entity<WorktreeStore>,
    /// Passed along to context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
    /// Never read; held so the store owns the global tasks file watcher task.
    _global_task_config_watcher: Task<()>,
}
 37
/// How a functional task store obtains task contexts.
enum StoreMode {
    /// Task contexts are computed on this machine.
    Local {
        /// Client and remote project id recorded by `TaskStore::shared` and
        /// cleared by `TaskStore::unshared`; starts out as `None`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Provides the project environment used when building task contexts.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from an upstream client.
    Remote {
        /// Client that receives `TaskContextForLocation` requests.
        upstream_client: AnyProtoClient,
        /// Project id sent with every upstream request.
        project_id: u64,
    },
}
 48
// Lets the store emit project events, e.g. the toast shown when the global
// tasks file fails to load.
impl EventEmitter<crate::Event> for TaskStore {}
 50
 51impl TaskStore {
 52    pub fn init(client: Option<&AnyProtoClient>) {
 53        if let Some(client) = client {
 54            client.add_entity_request_handler(Self::handle_task_context_for_location);
 55        }
 56    }
 57
 58    async fn handle_task_context_for_location(
 59        store: Entity<Self>,
 60        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 61        mut cx: AsyncApp,
 62    ) -> anyhow::Result<proto::TaskContext> {
 63        let location = envelope
 64            .payload
 65            .location
 66            .context("no location given for task context handling")?;
 67        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 68            Ok(match store {
 69                TaskStore::Functional(state) => (
 70                    state.buffer_store.clone(),
 71                    match &state.mode {
 72                        StoreMode::Local { .. } => false,
 73                        StoreMode::Remote { .. } => true,
 74                    },
 75                ),
 76                TaskStore::Noop => {
 77                    anyhow::bail!("empty task store cannot handle task context requests")
 78                }
 79            })
 80        })??;
 81        let buffer_store = buffer_store
 82            .upgrade()
 83            .context("no buffer store when handling task context request")?;
 84
 85        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 86            format!(
 87                "cannot handle task context request for invalid buffer id: {}",
 88                location.buffer_id
 89            )
 90        })?;
 91
 92        let start = location
 93            .start
 94            .and_then(deserialize_anchor)
 95            .context("missing task context location start")?;
 96        let end = location
 97            .end
 98            .and_then(deserialize_anchor)
 99            .context("missing task context location end")?;
100        let buffer = buffer_store
101            .update(&mut cx, |buffer_store, cx| {
102                if is_remote {
103                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
104                } else {
105                    Task::ready(
106                        buffer_store
107                            .get(buffer_id)
108                            .with_context(|| format!("no local buffer with id {buffer_id}")),
109                    )
110                }
111            })?
112            .await?;
113
114        let location = Location {
115            buffer,
116            range: start..end,
117        };
118        let context_task = store.update(&mut cx, |store, cx| {
119            let captured_variables = {
120                let mut variables = TaskVariables::from_iter(
121                    envelope
122                        .payload
123                        .task_variables
124                        .into_iter()
125                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126                );
127
128                let snapshot = location.buffer.read(cx).snapshot();
129                let range = location.range.to_offset(&snapshot);
130
131                for range in snapshot.runnable_ranges(range) {
132                    for (capture_name, value) in range.extra_captures {
133                        variables.insert(VariableName::Custom(capture_name.into()), value);
134                    }
135                }
136                variables
137            };
138            store.task_context_for_location(captured_variables, location, cx)
139        })?;
140        let task_context = context_task.await.unwrap_or_default();
141        Ok(proto::TaskContext {
142            project_env: task_context.project_env.into_iter().collect(),
143            cwd: task_context
144                .cwd
145                .map(|cwd| cwd.to_string_lossy().to_string()),
146            task_variables: task_context
147                .task_variables
148                .into_iter()
149                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150                .collect(),
151        })
152    }
153
154    pub fn local(
155        fs: Arc<dyn Fs>,
156        buffer_store: WeakEntity<BufferStore>,
157        worktree_store: Entity<WorktreeStore>,
158        toolchain_store: Arc<dyn LanguageToolchainStore>,
159        environment: Entity<ProjectEnvironment>,
160        cx: &mut Context<'_, Self>,
161    ) -> Self {
162        Self::Functional(StoreState {
163            mode: StoreMode::Local {
164                downstream_client: None,
165                environment,
166            },
167            task_inventory: Inventory::new(cx),
168            buffer_store,
169            toolchain_store,
170            worktree_store,
171            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
172        })
173    }
174
175    pub fn remote(
176        fs: Arc<dyn Fs>,
177        buffer_store: WeakEntity<BufferStore>,
178        worktree_store: Entity<WorktreeStore>,
179        toolchain_store: Arc<dyn LanguageToolchainStore>,
180        upstream_client: AnyProtoClient,
181        project_id: u64,
182        cx: &mut Context<'_, Self>,
183    ) -> Self {
184        Self::Functional(StoreState {
185            mode: StoreMode::Remote {
186                upstream_client,
187                project_id,
188            },
189            task_inventory: Inventory::new(cx),
190            buffer_store,
191            toolchain_store,
192            worktree_store,
193            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
194        })
195    }
196
197    pub fn task_context_for_location(
198        &self,
199        captured_variables: TaskVariables,
200        location: Location,
201        cx: &mut App,
202    ) -> Task<Option<TaskContext>> {
203        match self {
204            TaskStore::Functional(state) => match &state.mode {
205                StoreMode::Local { environment, .. } => local_task_context_for_location(
206                    state.worktree_store.clone(),
207                    state.toolchain_store.clone(),
208                    environment.clone(),
209                    captured_variables,
210                    location,
211                    cx,
212                ),
213                StoreMode::Remote {
214                    upstream_client,
215                    project_id,
216                } => remote_task_context_for_location(
217                    *project_id,
218                    upstream_client.clone(),
219                    state.worktree_store.clone(),
220                    captured_variables,
221                    location,
222                    state.toolchain_store.clone(),
223                    cx,
224                ),
225            },
226            TaskStore::Noop => Task::ready(None),
227        }
228    }
229
230    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
231        match self {
232            TaskStore::Functional(state) => Some(&state.task_inventory),
233            TaskStore::Noop => None,
234        }
235    }
236
237    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
238        if let Self::Functional(StoreState {
239            mode: StoreMode::Local {
240                downstream_client, ..
241            },
242            ..
243        }) = self
244        {
245            *downstream_client = Some((new_downstream_client, remote_id));
246        }
247    }
248
249    pub fn unshared(&mut self, _: &mut Context<Self>) {
250        if let Self::Functional(StoreState {
251            mode: StoreMode::Local {
252                downstream_client, ..
253            },
254            ..
255        }) = self
256        {
257            *downstream_client = None;
258        }
259    }
260
261    pub(super) fn update_user_tasks(
262        &self,
263        location: Option<SettingsLocation<'_>>,
264        raw_tasks_json: Option<&str>,
265        cx: &mut Context<'_, Self>,
266    ) -> anyhow::Result<()> {
267        let task_inventory = match self {
268            TaskStore::Functional(state) => &state.task_inventory,
269            TaskStore::Noop => return Ok(()),
270        };
271        let raw_tasks_json = raw_tasks_json
272            .map(|json| json.trim())
273            .filter(|json| !json.is_empty());
274
275        task_inventory.update(cx, |inventory, _| {
276            inventory.update_file_based_tasks(location, raw_tasks_json)
277        })
278    }
279
280    fn subscribe_to_global_task_file_changes(
281        fs: Arc<dyn Fs>,
282        cx: &mut Context<'_, Self>,
283    ) -> Task<()> {
284        let mut user_tasks_file_rx =
285            watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
286        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
287        cx.spawn(move |task_store, mut cx| async move {
288            if let Some(user_tasks_content) = user_tasks_content {
289                let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
290                    task_store
291                        .update_user_tasks(None, Some(&user_tasks_content), cx)
292                        .log_err();
293                }) else {
294                    return;
295                };
296            }
297            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
298                let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
299                    let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
300                    if let Err(err) = &result {
301                        log::error!("Failed to load user tasks: {err}");
302                        cx.emit(crate::Event::Toast {
303                            notification_id: "load-user-tasks".into(),
304                            message: format!("Invalid global tasks file\n{err}"),
305                        });
306                    }
307                    cx.refresh_windows();
308                }) else {
309                    break; // App dropped
310                };
311            }
312        })
313    }
314}
315
316fn local_task_context_for_location(
317    worktree_store: Entity<WorktreeStore>,
318    toolchain_store: Arc<dyn LanguageToolchainStore>,
319    environment: Entity<ProjectEnvironment>,
320    captured_variables: TaskVariables,
321    location: Location,
322    cx: &App,
323) -> Task<Option<TaskContext>> {
324    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
325    let worktree_abs_path = worktree_id
326        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
327        .and_then(|worktree| worktree.read(cx).root_dir());
328
329    cx.spawn(|mut cx| async move {
330        let worktree_abs_path = worktree_abs_path.clone();
331        let project_env = environment
332            .update(&mut cx, |environment, cx| {
333                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
334            })
335            .ok()?
336            .await;
337
338        let mut task_variables = cx
339            .update(|cx| {
340                combine_task_variables(
341                    captured_variables,
342                    location,
343                    project_env.clone(),
344                    BasicContextProvider::new(worktree_store),
345                    toolchain_store,
346                    cx,
347                )
348            })
349            .ok()?
350            .await
351            .log_err()?;
352        // Remove all custom entries starting with _, as they're not intended for use by the end user.
353        task_variables.sweep();
354
355        Some(TaskContext {
356            project_env: project_env.unwrap_or_default(),
357            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
358            task_variables,
359        })
360    })
361}
362
363fn remote_task_context_for_location(
364    project_id: u64,
365    upstream_client: AnyProtoClient,
366    worktree_store: Entity<WorktreeStore>,
367    captured_variables: TaskVariables,
368    location: Location,
369    toolchain_store: Arc<dyn LanguageToolchainStore>,
370    cx: &mut App,
371) -> Task<Option<TaskContext>> {
372    cx.spawn(|cx| async move {
373        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
374        let mut remote_context = cx
375            .update(|cx| {
376                BasicContextProvider::new(worktree_store).build_context(
377                    &TaskVariables::default(),
378                    &location,
379                    None,
380                    toolchain_store,
381                    cx,
382                )
383            })
384            .ok()?
385            .await
386            .log_err()
387            .unwrap_or_default();
388        remote_context.extend(captured_variables);
389
390        let buffer_id = cx
391            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
392            .ok()?;
393        let context_task = upstream_client.request(proto::TaskContextForLocation {
394            project_id,
395            location: Some(proto::Location {
396                buffer_id,
397                start: Some(serialize_anchor(&location.range.start)),
398                end: Some(serialize_anchor(&location.range.end)),
399            }),
400            task_variables: remote_context
401                .into_iter()
402                .map(|(k, v)| (k.to_string(), v))
403                .collect(),
404        });
405        let task_context = context_task.await.log_err()?;
406        Some(TaskContext {
407            cwd: task_context.cwd.map(PathBuf::from),
408            task_variables: task_context
409                .task_variables
410                .into_iter()
411                .filter_map(
412                    |(variable_name, variable_value)| match variable_name.parse() {
413                        Ok(variable_name) => Some((variable_name, variable_value)),
414                        Err(()) => {
415                            log::error!("Unknown variable name: {variable_name}");
416                            None
417                        }
418                    },
419                )
420                .collect(),
421            project_env: task_context.project_env.into_iter().collect(),
422        })
423    })
424}
425
426fn combine_task_variables(
427    mut captured_variables: TaskVariables,
428    location: Location,
429    project_env: Option<HashMap<String, String>>,
430    baseline: BasicContextProvider,
431    toolchain_store: Arc<dyn LanguageToolchainStore>,
432    cx: &mut App,
433) -> Task<anyhow::Result<TaskVariables>> {
434    let language_context_provider = location
435        .buffer
436        .read(cx)
437        .language()
438        .and_then(|language| language.context_provider());
439    cx.spawn(move |cx| async move {
440        let baseline = cx
441            .update(|cx| {
442                baseline.build_context(
443                    &captured_variables,
444                    &location,
445                    project_env.clone(),
446                    toolchain_store.clone(),
447                    cx,
448                )
449            })?
450            .await
451            .context("building basic default context")?;
452        captured_variables.extend(baseline);
453        if let Some(provider) = language_context_provider {
454            captured_variables.extend(
455                cx.update(|cx| {
456                    provider.build_context(
457                        &captured_variables,
458                        &location,
459                        project_env,
460                        toolchain_store,
461                        cx,
462                    )
463                })?
464                .await
465                .context("building provider context")?,
466            );
467        }
468        Ok(captured_variables)
469    })
470}