task_store.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4};
  5
  6use anyhow::Context as _;
  7use collections::HashMap;
  8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  9use language::{
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11    proto::{deserialize_anchor, serialize_anchor},
 12};
 13use rpc::{AnyProtoClient, TypedEnvelope, proto};
 14use settings::{InvalidSettingsError, SettingsLocation, TaskKind};
 15use task::{TaskContext, TaskVariables, VariableName};
 16use text::{BufferId, OffsetRangeExt};
 17use util::ResultExt;
 18
 19use crate::{
 20    BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
 21    worktree_store::WorktreeStore,
 22};
 23
/// Entry point for task-related project state.
///
/// A `Functional` store holds real state (see [`StoreState`]). A `Noop` store is
/// an inert placeholder:
/// - it has no task inventory;
/// - it resolves task contexts to `None`;
/// - it accepts user task updates without doing anything;
/// - it rejects task-context requests coming from peers.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by [`StoreState`].
    Functional(StoreState),
    /// A store that does nothing.
    Noop,
}
 29
/// State of a functional [`TaskStore`], shared by its local and remote modes.
pub struct StoreState {
    /// Whether task contexts are computed on this machine or requested from an upstream host.
    mode: StoreMode,
    /// Task inventory; file-based tasks are pushed into it by `update_user_tasks`.
    task_inventory: Entity<Inventory>,
    /// Used to resolve buffers for incoming task-context requests. Held weakly, so
    /// request handling fails with an error once the buffer store is gone.
    buffer_store: WeakEntity<BufferStore>,
    /// Used to locate a buffer's worktree and to build the basic context provider.
    worktree_store: Entity<WorktreeStore>,
    /// Passed through to context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
}
 37
/// How a functional [`TaskStore`] obtains task contexts.
enum StoreMode {
    /// The project is on this machine and task contexts are built here.
    Local {
        /// Set by `shared` and cleared by `unshared`.
        /// Holds the downstream client together with the project's remote id.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Provides the per-buffer project environment used in task contexts.
        environment: Entity<ProjectEnvironment>,
    },
    /// The project lives on an upstream host, and task contexts are requested from it.
    Remote {
        /// Client used to send `TaskContextForLocation` requests upstream.
        upstream_client: AnyProtoClient,
        /// Id of the project on the upstream host, included in each request.
        project_id: u64,
    },
}
 48
// Declares that `TaskStore` entities may emit project-level `crate::Event`s.
impl EventEmitter<crate::Event> for TaskStore {}
 50
/// Where a set of user-defined tasks passed to `TaskStore::update_user_tasks` came from.
#[derive(Debug)]
pub enum TaskSettingsLocation<'a> {
    /// Global tasks, identified by a path.
    Global(&'a Path),
    /// Tasks belonging to a worktree, identified by their settings location.
    Worktree(SettingsLocation<'a>),
}
 56
 57impl TaskStore {
 58    pub fn init(client: Option<&AnyProtoClient>) {
 59        if let Some(client) = client {
 60            client.add_entity_request_handler(Self::handle_task_context_for_location);
 61        }
 62    }
 63
 64    async fn handle_task_context_for_location(
 65        store: Entity<Self>,
 66        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 67        mut cx: AsyncApp,
 68    ) -> anyhow::Result<proto::TaskContext> {
 69        let location = envelope
 70            .payload
 71            .location
 72            .context("no location given for task context handling")?;
 73        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 74            Ok(match store {
 75                TaskStore::Functional(state) => (
 76                    state.buffer_store.clone(),
 77                    match &state.mode {
 78                        StoreMode::Local { .. } => false,
 79                        StoreMode::Remote { .. } => true,
 80                    },
 81                ),
 82                TaskStore::Noop => {
 83                    anyhow::bail!("empty task store cannot handle task context requests")
 84                }
 85            })
 86        })??;
 87        let buffer_store = buffer_store
 88            .upgrade()
 89            .context("no buffer store when handling task context request")?;
 90
 91        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 92            format!(
 93                "cannot handle task context request for invalid buffer id: {}",
 94                location.buffer_id
 95            )
 96        })?;
 97
 98        let start = location
 99            .start
100            .and_then(deserialize_anchor)
101            .context("missing task context location start")?;
102        let end = location
103            .end
104            .and_then(deserialize_anchor)
105            .context("missing task context location end")?;
106        let buffer = buffer_store
107            .update(&mut cx, |buffer_store, cx| {
108                if is_remote {
109                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
110                } else {
111                    Task::ready(
112                        buffer_store
113                            .get(buffer_id)
114                            .with_context(|| format!("no local buffer with id {buffer_id}")),
115                    )
116                }
117            })?
118            .await?;
119
120        let location = Location {
121            buffer,
122            range: start..end,
123        };
124        let context_task = store.update(&mut cx, |store, cx| {
125            let captured_variables = {
126                let mut variables = TaskVariables::from_iter(
127                    envelope
128                        .payload
129                        .task_variables
130                        .into_iter()
131                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132                );
133
134                let snapshot = location.buffer.read(cx).snapshot();
135                let range = location.range.to_offset(&snapshot);
136
137                for range in snapshot.runnable_ranges(range) {
138                    for (capture_name, value) in range.extra_captures {
139                        variables.insert(VariableName::Custom(capture_name.into()), value);
140                    }
141                }
142                variables
143            };
144            store.task_context_for_location(captured_variables, location, cx)
145        })?;
146        let task_context = context_task.await.unwrap_or_default();
147        Ok(proto::TaskContext {
148            project_env: task_context.project_env.into_iter().collect(),
149            cwd: task_context
150                .cwd
151                .map(|cwd| cwd.to_string_lossy().to_string()),
152            task_variables: task_context
153                .task_variables
154                .into_iter()
155                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156                .collect(),
157        })
158    }
159
160    pub fn local(
161        buffer_store: WeakEntity<BufferStore>,
162        worktree_store: Entity<WorktreeStore>,
163        toolchain_store: Arc<dyn LanguageToolchainStore>,
164        environment: Entity<ProjectEnvironment>,
165        cx: &mut Context<Self>,
166    ) -> Self {
167        Self::Functional(StoreState {
168            mode: StoreMode::Local {
169                downstream_client: None,
170                environment,
171            },
172            task_inventory: Inventory::new(cx),
173            buffer_store,
174            toolchain_store,
175            worktree_store,
176        })
177    }
178
179    pub fn remote(
180        buffer_store: WeakEntity<BufferStore>,
181        worktree_store: Entity<WorktreeStore>,
182        toolchain_store: Arc<dyn LanguageToolchainStore>,
183        upstream_client: AnyProtoClient,
184        project_id: u64,
185        cx: &mut Context<Self>,
186    ) -> Self {
187        Self::Functional(StoreState {
188            mode: StoreMode::Remote {
189                upstream_client,
190                project_id,
191            },
192            task_inventory: Inventory::new(cx),
193            buffer_store,
194            toolchain_store,
195            worktree_store,
196        })
197    }
198
199    pub fn task_context_for_location(
200        &self,
201        captured_variables: TaskVariables,
202        location: Location,
203        cx: &mut App,
204    ) -> Task<Option<TaskContext>> {
205        match self {
206            TaskStore::Functional(state) => match &state.mode {
207                StoreMode::Local { environment, .. } => local_task_context_for_location(
208                    state.worktree_store.clone(),
209                    state.toolchain_store.clone(),
210                    environment.clone(),
211                    captured_variables,
212                    location,
213                    cx,
214                ),
215                StoreMode::Remote {
216                    upstream_client,
217                    project_id,
218                } => remote_task_context_for_location(
219                    *project_id,
220                    upstream_client.clone(),
221                    state.worktree_store.clone(),
222                    captured_variables,
223                    location,
224                    state.toolchain_store.clone(),
225                    cx,
226                ),
227            },
228            TaskStore::Noop => Task::ready(None),
229        }
230    }
231
232    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233        match self {
234            TaskStore::Functional(state) => Some(&state.task_inventory),
235            TaskStore::Noop => None,
236        }
237    }
238
239    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240        if let Self::Functional(StoreState {
241            mode: StoreMode::Local {
242                downstream_client, ..
243            },
244            ..
245        }) = self
246        {
247            *downstream_client = Some((new_downstream_client, remote_id));
248        }
249    }
250
251    pub fn unshared(&mut self, _: &mut Context<Self>) {
252        if let Self::Functional(StoreState {
253            mode: StoreMode::Local {
254                downstream_client, ..
255            },
256            ..
257        }) = self
258        {
259            *downstream_client = None;
260        }
261    }
262
263    pub(super) fn update_user_tasks(
264        &self,
265        location: TaskSettingsLocation<'_>,
266        raw_tasks_json: Option<&str>,
267        task_type: TaskKind,
268        cx: &mut Context<Self>,
269    ) -> Result<(), InvalidSettingsError> {
270        let task_inventory = match self {
271            TaskStore::Functional(state) => &state.task_inventory,
272            TaskStore::Noop => return Ok(()),
273        };
274        let raw_tasks_json = raw_tasks_json
275            .map(|json| json.trim())
276            .filter(|json| !json.is_empty());
277
278        task_inventory.update(cx, |inventory, _| {
279            inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
280        })
281    }
282}
283
284fn local_task_context_for_location(
285    worktree_store: Entity<WorktreeStore>,
286    toolchain_store: Arc<dyn LanguageToolchainStore>,
287    environment: Entity<ProjectEnvironment>,
288    captured_variables: TaskVariables,
289    location: Location,
290    cx: &App,
291) -> Task<Option<TaskContext>> {
292    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
293    let worktree_abs_path = worktree_id
294        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
295        .and_then(|worktree| worktree.read(cx).root_dir());
296
297    cx.spawn(async move |cx| {
298        let project_env = environment
299            .update(cx, |environment, cx| {
300                environment.get_buffer_environment(
301                    location.buffer.clone(),
302                    worktree_store.clone(),
303                    cx,
304                )
305            })
306            .ok()?
307            .await;
308
309        let mut task_variables = cx
310            .update(|cx| {
311                combine_task_variables(
312                    captured_variables,
313                    location,
314                    project_env.clone(),
315                    BasicContextProvider::new(worktree_store),
316                    toolchain_store,
317                    cx,
318                )
319            })
320            .ok()?
321            .await
322            .log_err()?;
323        // Remove all custom entries starting with _, as they're not intended for use by the end user.
324        task_variables.sweep();
325
326        Some(TaskContext {
327            project_env: project_env.unwrap_or_default(),
328            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
329            task_variables,
330        })
331    })
332}
333
334fn remote_task_context_for_location(
335    project_id: u64,
336    upstream_client: AnyProtoClient,
337    worktree_store: Entity<WorktreeStore>,
338    captured_variables: TaskVariables,
339    location: Location,
340    toolchain_store: Arc<dyn LanguageToolchainStore>,
341    cx: &mut App,
342) -> Task<Option<TaskContext>> {
343    cx.spawn(async move |cx| {
344        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
345        let mut remote_context = cx
346            .update(|cx| {
347                BasicContextProvider::new(worktree_store).build_context(
348                    &TaskVariables::default(),
349                    &location,
350                    None,
351                    toolchain_store,
352                    cx,
353                )
354            })
355            .ok()?
356            .await
357            .log_err()
358            .unwrap_or_default();
359        remote_context.extend(captured_variables);
360
361        let buffer_id = cx
362            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
363            .ok()?;
364        let context_task = upstream_client.request(proto::TaskContextForLocation {
365            project_id,
366            location: Some(proto::Location {
367                buffer_id,
368                start: Some(serialize_anchor(&location.range.start)),
369                end: Some(serialize_anchor(&location.range.end)),
370            }),
371            task_variables: remote_context
372                .into_iter()
373                .map(|(k, v)| (k.to_string(), v))
374                .collect(),
375        });
376        let task_context = context_task.await.log_err()?;
377        Some(TaskContext {
378            cwd: task_context.cwd.map(PathBuf::from),
379            task_variables: task_context
380                .task_variables
381                .into_iter()
382                .filter_map(
383                    |(variable_name, variable_value)| match variable_name.parse() {
384                        Ok(variable_name) => Some((variable_name, variable_value)),
385                        Err(()) => {
386                            log::error!("Unknown variable name: {variable_name}");
387                            None
388                        }
389                    },
390                )
391                .collect(),
392            project_env: task_context.project_env.into_iter().collect(),
393        })
394    })
395}
396
397fn combine_task_variables(
398    mut captured_variables: TaskVariables,
399    location: Location,
400    project_env: Option<HashMap<String, String>>,
401    baseline: BasicContextProvider,
402    toolchain_store: Arc<dyn LanguageToolchainStore>,
403    cx: &mut App,
404) -> Task<anyhow::Result<TaskVariables>> {
405    let language_context_provider = location
406        .buffer
407        .read(cx)
408        .language()
409        .and_then(|language| language.context_provider());
410    cx.spawn(async move |cx| {
411        let baseline = cx
412            .update(|cx| {
413                baseline.build_context(
414                    &captured_variables,
415                    &location,
416                    project_env.clone(),
417                    toolchain_store.clone(),
418                    cx,
419                )
420            })?
421            .await
422            .context("building basic default context")?;
423        captured_variables.extend(baseline);
424        if let Some(provider) = language_context_provider {
425            captured_variables.extend(
426                cx.update(|cx| {
427                    provider.build_context(
428                        &captured_variables,
429                        &location,
430                        project_env,
431                        toolchain_store,
432                        cx,
433                    )
434                })?
435                .await
436                .context("building provider context")?,
437            );
438        }
439        Ok(captured_variables)
440    })
441}