task_store.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4};
  5
  6use anyhow::Context as _;
  7use collections::HashMap;
  8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  9use language::{
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11    proto::{deserialize_anchor, serialize_anchor},
 12};
 13use rpc::{AnyProtoClient, TypedEnvelope, proto};
 14use settings::{InvalidSettingsError, SettingsLocation, TaskKind};
 15use task::{TaskContext, TaskVariables, VariableName};
 16use text::{BufferId, OffsetRangeExt};
 17use util::ResultExt;
 18
 19use crate::{
 20    BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
 21    worktree_store::WorktreeStore,
 22};
 23
/// Owns a project's task inventory and resolves task contexts for buffer locations.
///
/// A `Noop` store does nothing:
/// - it exposes no inventory,
/// - it resolves every task context request to `None`,
/// - it ignores user task updates,
/// - it rejects task context requests coming from peers.
#[allow(clippy::large_enum_variant)] // platform-dependent warning: `StoreState` is far larger than the unit `Noop` variant
pub enum TaskStore {
    /// A working store; whether contexts are computed locally or upstream depends on its mode.
    Functional(StoreState),
    /// An inert store with no state.
    Noop,
}
 29
/// State shared by local and remote [`TaskStore`]s.
pub struct StoreState {
    /// Whether task contexts are built in this process or requested from the upstream host.
    mode: StoreMode,
    /// Task inventory handed out by `TaskStore::task_inventory` and fed by user task updates.
    task_inventory: Entity<Inventory>,
    /// Held weakly; upgraded on demand to resolve the buffers named in peer requests.
    buffer_store: WeakEntity<BufferStore>,
    /// Used to find a buffer's worktree root and to build the basic task context.
    worktree_store: Entity<WorktreeStore>,
    /// Passed through to the context providers that build task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
}
 37
/// Where a functional [`TaskStore`] computes task contexts.
enum StoreMode {
    /// Task contexts are built in this process.
    Local {
        /// `(client, remote id)` set by `TaskStore::shared` and cleared by `TaskStore::unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Source of the project environment used for local task contexts.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from the upstream host over RPC.
    Remote {
        /// Client the `TaskContextForLocation` requests are sent through.
        upstream_client: AnyProtoClient,
        /// Project id sent with every upstream request.
        project_id: u64,
    },
}
 48
// Declares that `TaskStore` entities may emit project-level `crate::Event`s.
// NOTE(review): no event is emitted by the code visible in this file; confirm whether it is still needed.
impl EventEmitter<crate::Event> for TaskStore {}
 50
 51#[derive(Debug)]
 52pub enum TaskSettingsLocation<'a> {
 53    Global(&'a Path),
 54    Worktree(SettingsLocation<'a>),
 55}
 56
 57impl TaskStore {
 58    pub fn init(client: Option<&AnyProtoClient>) {
 59        if let Some(client) = client {
 60            client.add_entity_request_handler(Self::handle_task_context_for_location);
 61        }
 62    }
 63
 64    async fn handle_task_context_for_location(
 65        store: Entity<Self>,
 66        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 67        mut cx: AsyncApp,
 68    ) -> anyhow::Result<proto::TaskContext> {
 69        let location = envelope
 70            .payload
 71            .location
 72            .context("no location given for task context handling")?;
 73        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 74            Ok(match store {
 75                TaskStore::Functional(state) => (
 76                    state.buffer_store.clone(),
 77                    match &state.mode {
 78                        StoreMode::Local { .. } => false,
 79                        StoreMode::Remote { .. } => true,
 80                    },
 81                ),
 82                TaskStore::Noop => {
 83                    anyhow::bail!("empty task store cannot handle task context requests")
 84                }
 85            })
 86        })??;
 87        let buffer_store = buffer_store
 88            .upgrade()
 89            .context("no buffer store when handling task context request")?;
 90
 91        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 92            format!(
 93                "cannot handle task context request for invalid buffer id: {}",
 94                location.buffer_id
 95            )
 96        })?;
 97
 98        let start = location
 99            .start
100            .and_then(deserialize_anchor)
101            .context("missing task context location start")?;
102        let end = location
103            .end
104            .and_then(deserialize_anchor)
105            .context("missing task context location end")?;
106        let buffer = buffer_store
107            .update(&mut cx, |buffer_store, cx| {
108                if is_remote {
109                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
110                } else {
111                    Task::ready(
112                        buffer_store
113                            .get(buffer_id)
114                            .with_context(|| format!("no local buffer with id {buffer_id}")),
115                    )
116                }
117            })?
118            .await?;
119
120        let location = Location {
121            buffer,
122            range: start..end,
123        };
124        let context_task = store.update(&mut cx, |store, cx| {
125            let captured_variables = {
126                let mut variables = TaskVariables::from_iter(
127                    envelope
128                        .payload
129                        .task_variables
130                        .into_iter()
131                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132                );
133
134                let snapshot = location.buffer.read(cx).snapshot();
135                let range = location.range.to_offset(&snapshot);
136
137                for range in snapshot.runnable_ranges(range) {
138                    for (capture_name, value) in range.extra_captures {
139                        variables.insert(VariableName::Custom(capture_name.into()), value);
140                    }
141                }
142                variables
143            };
144            store.task_context_for_location(captured_variables, location, cx)
145        })?;
146        let task_context = context_task.await.unwrap_or_default();
147        Ok(proto::TaskContext {
148            project_env: task_context.project_env.into_iter().collect(),
149            cwd: task_context
150                .cwd
151                .map(|cwd| cwd.to_string_lossy().to_string()),
152            task_variables: task_context
153                .task_variables
154                .into_iter()
155                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156                .collect(),
157        })
158    }
159
160    pub fn local(
161        buffer_store: WeakEntity<BufferStore>,
162        worktree_store: Entity<WorktreeStore>,
163        toolchain_store: Arc<dyn LanguageToolchainStore>,
164        environment: Entity<ProjectEnvironment>,
165        cx: &mut Context<Self>,
166    ) -> Self {
167        Self::Functional(StoreState {
168            mode: StoreMode::Local {
169                downstream_client: None,
170                environment,
171            },
172            task_inventory: Inventory::new(cx),
173            buffer_store,
174            toolchain_store,
175            worktree_store,
176        })
177    }
178
179    pub fn remote(
180        buffer_store: WeakEntity<BufferStore>,
181        worktree_store: Entity<WorktreeStore>,
182        toolchain_store: Arc<dyn LanguageToolchainStore>,
183        upstream_client: AnyProtoClient,
184        project_id: u64,
185        cx: &mut Context<Self>,
186    ) -> Self {
187        Self::Functional(StoreState {
188            mode: StoreMode::Remote {
189                upstream_client,
190                project_id,
191            },
192            task_inventory: Inventory::new(cx),
193            buffer_store,
194            toolchain_store,
195            worktree_store,
196        })
197    }
198
199    pub fn task_context_for_location(
200        &self,
201        captured_variables: TaskVariables,
202        location: Location,
203        cx: &mut App,
204    ) -> Task<Option<TaskContext>> {
205        match self {
206            TaskStore::Functional(state) => match &state.mode {
207                StoreMode::Local { environment, .. } => local_task_context_for_location(
208                    state.worktree_store.clone(),
209                    state.toolchain_store.clone(),
210                    environment.clone(),
211                    captured_variables,
212                    location,
213                    cx,
214                ),
215                StoreMode::Remote {
216                    upstream_client,
217                    project_id,
218                } => remote_task_context_for_location(
219                    *project_id,
220                    upstream_client.clone(),
221                    state.worktree_store.clone(),
222                    captured_variables,
223                    location,
224                    state.toolchain_store.clone(),
225                    cx,
226                ),
227            },
228            TaskStore::Noop => Task::ready(None),
229        }
230    }
231
232    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233        match self {
234            TaskStore::Functional(state) => Some(&state.task_inventory),
235            TaskStore::Noop => None,
236        }
237    }
238
239    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240        if let Self::Functional(StoreState {
241            mode: StoreMode::Local {
242                downstream_client, ..
243            },
244            ..
245        }) = self
246        {
247            *downstream_client = Some((new_downstream_client, remote_id));
248        }
249    }
250
251    pub fn unshared(&mut self, _: &mut Context<Self>) {
252        if let Self::Functional(StoreState {
253            mode: StoreMode::Local {
254                downstream_client, ..
255            },
256            ..
257        }) = self
258        {
259            *downstream_client = None;
260        }
261    }
262
263    pub(super) fn update_user_tasks(
264        &self,
265        location: TaskSettingsLocation<'_>,
266        raw_tasks_json: Option<&str>,
267        task_type: TaskKind,
268        cx: &mut Context<Self>,
269    ) -> Result<(), InvalidSettingsError> {
270        let task_inventory = match self {
271            TaskStore::Functional(state) => &state.task_inventory,
272            TaskStore::Noop => return Ok(()),
273        };
274        let raw_tasks_json = raw_tasks_json
275            .map(|json| json.trim())
276            .filter(|json| !json.is_empty());
277
278        task_inventory.update(cx, |inventory, _| {
279            inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
280        })
281    }
282}
283
284fn local_task_context_for_location(
285    worktree_store: Entity<WorktreeStore>,
286    toolchain_store: Arc<dyn LanguageToolchainStore>,
287    environment: Entity<ProjectEnvironment>,
288    captured_variables: TaskVariables,
289    location: Location,
290    cx: &App,
291) -> Task<Option<TaskContext>> {
292    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
293    let worktree_abs_path = worktree_id
294        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
295        .and_then(|worktree| worktree.read(cx).root_dir());
296
297    cx.spawn(async move |cx| {
298        let worktree_abs_path = worktree_abs_path.clone();
299        let project_env = environment
300            .update(cx, |environment, cx| {
301                environment.get_environment(worktree_abs_path.clone(), cx)
302            })
303            .ok()?
304            .await;
305
306        let mut task_variables = cx
307            .update(|cx| {
308                combine_task_variables(
309                    captured_variables,
310                    location,
311                    project_env.clone(),
312                    BasicContextProvider::new(worktree_store),
313                    toolchain_store,
314                    cx,
315                )
316            })
317            .ok()?
318            .await
319            .log_err()?;
320        // Remove all custom entries starting with _, as they're not intended for use by the end user.
321        task_variables.sweep();
322
323        Some(TaskContext {
324            project_env: project_env.unwrap_or_default(),
325            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
326            task_variables,
327        })
328    })
329}
330
331fn remote_task_context_for_location(
332    project_id: u64,
333    upstream_client: AnyProtoClient,
334    worktree_store: Entity<WorktreeStore>,
335    captured_variables: TaskVariables,
336    location: Location,
337    toolchain_store: Arc<dyn LanguageToolchainStore>,
338    cx: &mut App,
339) -> Task<Option<TaskContext>> {
340    cx.spawn(async move |cx| {
341        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
342        let mut remote_context = cx
343            .update(|cx| {
344                BasicContextProvider::new(worktree_store).build_context(
345                    &TaskVariables::default(),
346                    &location,
347                    None,
348                    toolchain_store,
349                    cx,
350                )
351            })
352            .ok()?
353            .await
354            .log_err()
355            .unwrap_or_default();
356        remote_context.extend(captured_variables);
357
358        let buffer_id = cx
359            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
360            .ok()?;
361        let context_task = upstream_client.request(proto::TaskContextForLocation {
362            project_id,
363            location: Some(proto::Location {
364                buffer_id,
365                start: Some(serialize_anchor(&location.range.start)),
366                end: Some(serialize_anchor(&location.range.end)),
367            }),
368            task_variables: remote_context
369                .into_iter()
370                .map(|(k, v)| (k.to_string(), v))
371                .collect(),
372        });
373        let task_context = context_task.await.log_err()?;
374        Some(TaskContext {
375            cwd: task_context.cwd.map(PathBuf::from),
376            task_variables: task_context
377                .task_variables
378                .into_iter()
379                .filter_map(
380                    |(variable_name, variable_value)| match variable_name.parse() {
381                        Ok(variable_name) => Some((variable_name, variable_value)),
382                        Err(()) => {
383                            log::error!("Unknown variable name: {variable_name}");
384                            None
385                        }
386                    },
387                )
388                .collect(),
389            project_env: task_context.project_env.into_iter().collect(),
390        })
391    })
392}
393
394fn combine_task_variables(
395    mut captured_variables: TaskVariables,
396    location: Location,
397    project_env: Option<HashMap<String, String>>,
398    baseline: BasicContextProvider,
399    toolchain_store: Arc<dyn LanguageToolchainStore>,
400    cx: &mut App,
401) -> Task<anyhow::Result<TaskVariables>> {
402    let language_context_provider = location
403        .buffer
404        .read(cx)
405        .language()
406        .and_then(|language| language.context_provider());
407    cx.spawn(async move |cx| {
408        let baseline = cx
409            .update(|cx| {
410                baseline.build_context(
411                    &captured_variables,
412                    &location,
413                    project_env.clone(),
414                    toolchain_store.clone(),
415                    cx,
416                )
417            })?
418            .await
419            .context("building basic default context")?;
420        captured_variables.extend(baseline);
421        if let Some(provider) = language_context_provider {
422            captured_variables.extend(
423                cx.update(|cx| {
424                    provider.build_context(
425                        &captured_variables,
426                        &location,
427                        project_env,
428                        toolchain_store,
429                        cx,
430                    )
431                })?
432                .await
433                .context("building provider context")?,
434            );
435        }
436        Ok(captured_variables)
437    })
438}