// task_store.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4};
  5
  6use anyhow::Context as _;
  7use collections::HashMap;
  8use fs::Fs;
  9use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
 10use language::{
 11    ContextLocation, ContextProvider as _, LanguageToolchainStore, Location,
 12    proto::{deserialize_anchor, serialize_anchor},
 13};
 14use rpc::{AnyProtoClient, TypedEnvelope, proto};
 15use settings::{InvalidSettingsError, SettingsLocation};
 16use task::{TaskContext, TaskVariables, VariableName};
 17use text::{BufferId, OffsetRangeExt};
 18use util::ResultExt;
 19
 20use crate::{
 21    BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
 22    git_store::GitStore, worktree_store::WorktreeStore,
 23};
 24
 25// platform-dependent warning
 26pub enum TaskStore {
 27    Functional(StoreState),
 28    Noop,
 29}
 30
 31pub struct StoreState {
 32    mode: StoreMode,
 33    task_inventory: Entity<Inventory>,
 34    buffer_store: WeakEntity<BufferStore>,
 35    worktree_store: Entity<WorktreeStore>,
 36    git_store: Entity<GitStore>,
 37    toolchain_store: Arc<dyn LanguageToolchainStore>,
 38}
 39
 40enum StoreMode {
 41    Local {
 42        downstream_client: Option<(AnyProtoClient, u64)>,
 43        environment: Entity<ProjectEnvironment>,
 44    },
 45    Remote {
 46        upstream_client: AnyProtoClient,
 47        project_id: u64,
 48    },
 49}
 50
 51impl EventEmitter<crate::Event> for TaskStore {}
 52
 53#[derive(Debug)]
 54pub enum TaskSettingsLocation<'a> {
 55    Global(&'a Path),
 56    Worktree(SettingsLocation<'a>),
 57}
 58
 59impl TaskStore {
 60    pub fn init(client: Option<&AnyProtoClient>) {
 61        if let Some(client) = client {
 62            client.add_entity_request_handler(Self::handle_task_context_for_location);
 63        }
 64    }
 65
 66    async fn handle_task_context_for_location(
 67        store: Entity<Self>,
 68        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 69        mut cx: AsyncApp,
 70    ) -> anyhow::Result<proto::TaskContext> {
 71        let location = envelope
 72            .payload
 73            .location
 74            .context("no location given for task context handling")?;
 75        let (buffer_store, is_remote) = store.read_with(&cx, |store, _| {
 76            Ok(match store {
 77                TaskStore::Functional(state) => (
 78                    state.buffer_store.clone(),
 79                    match &state.mode {
 80                        StoreMode::Local { .. } => false,
 81                        StoreMode::Remote { .. } => true,
 82                    },
 83                ),
 84                TaskStore::Noop => {
 85                    anyhow::bail!("empty task store cannot handle task context requests")
 86                }
 87            })
 88        })?;
 89        let buffer_store = buffer_store
 90            .upgrade()
 91            .context("no buffer store when handling task context request")?;
 92
 93        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 94            format!(
 95                "cannot handle task context request for invalid buffer id: {}",
 96                location.buffer_id
 97            )
 98        })?;
 99
100        let start = location
101            .start
102            .and_then(deserialize_anchor)
103            .context("missing task context location start")?;
104        let end = location
105            .end
106            .and_then(deserialize_anchor)
107            .context("missing task context location end")?;
108        let buffer = buffer_store
109            .update(&mut cx, |buffer_store, cx| {
110                if is_remote {
111                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
112                } else {
113                    Task::ready(
114                        buffer_store
115                            .get(buffer_id)
116                            .with_context(|| format!("no local buffer with id {buffer_id}")),
117                    )
118                }
119            })
120            .await?;
121
122        let location = Location {
123            buffer,
124            range: start..end,
125        };
126        let context_task = store.update(&mut cx, |store, cx| {
127            let captured_variables = {
128                let mut variables = TaskVariables::from_iter(
129                    envelope
130                        .payload
131                        .task_variables
132                        .into_iter()
133                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
134                );
135
136                let snapshot = location.buffer.read(cx).snapshot();
137                let range = location.range.to_offset(&snapshot);
138
139                for range in snapshot.runnable_ranges(range) {
140                    for (capture_name, value) in range.extra_captures {
141                        variables.insert(VariableName::Custom(capture_name.into()), value);
142                    }
143                }
144                variables
145            };
146            store.task_context_for_location(captured_variables, location, cx)
147        });
148        let task_context = context_task.await.unwrap_or_default();
149        Ok(proto::TaskContext {
150            project_env: task_context.project_env.into_iter().collect(),
151            cwd: task_context
152                .cwd
153                .map(|cwd| cwd.to_string_lossy().into_owned()),
154            task_variables: task_context
155                .task_variables
156                .into_iter()
157                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
158                .collect(),
159        })
160    }
161
162    pub fn local(
163        buffer_store: WeakEntity<BufferStore>,
164        worktree_store: Entity<WorktreeStore>,
165        toolchain_store: Arc<dyn LanguageToolchainStore>,
166        environment: Entity<ProjectEnvironment>,
167        git_store: Entity<GitStore>,
168        cx: &mut Context<Self>,
169    ) -> Self {
170        Self::Functional(StoreState {
171            mode: StoreMode::Local {
172                downstream_client: None,
173                environment,
174            },
175            task_inventory: Inventory::new(cx),
176            buffer_store,
177            git_store,
178            toolchain_store,
179            worktree_store,
180        })
181    }
182
183    pub fn remote(
184        buffer_store: WeakEntity<BufferStore>,
185        worktree_store: Entity<WorktreeStore>,
186        toolchain_store: Arc<dyn LanguageToolchainStore>,
187        upstream_client: AnyProtoClient,
188        project_id: u64,
189        git_store: Entity<GitStore>,
190        cx: &mut Context<Self>,
191    ) -> Self {
192        Self::Functional(StoreState {
193            mode: StoreMode::Remote {
194                upstream_client,
195                project_id,
196            },
197            task_inventory: Inventory::new(cx),
198            buffer_store,
199            git_store,
200            toolchain_store,
201            worktree_store,
202        })
203    }
204
205    pub fn task_context_for_location(
206        &self,
207        captured_variables: TaskVariables,
208        location: Location,
209        cx: &mut App,
210    ) -> Task<Option<TaskContext>> {
211        match self {
212            TaskStore::Functional(state) => match &state.mode {
213                StoreMode::Local { environment, .. } => local_task_context_for_location(
214                    state.worktree_store.clone(),
215                    state.git_store.clone(),
216                    state.toolchain_store.clone(),
217                    environment.clone(),
218                    captured_variables,
219                    location,
220                    cx,
221                ),
222                StoreMode::Remote {
223                    upstream_client,
224                    project_id,
225                } => remote_task_context_for_location(
226                    *project_id,
227                    upstream_client.clone(),
228                    state.worktree_store.clone(),
229                    state.git_store.clone(),
230                    captured_variables,
231                    location,
232                    state.toolchain_store.clone(),
233                    cx,
234                ),
235            },
236            TaskStore::Noop => Task::ready(None),
237        }
238    }
239
240    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
241        match self {
242            TaskStore::Functional(state) => Some(&state.task_inventory),
243            TaskStore::Noop => None,
244        }
245    }
246
247    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
248        if let Self::Functional(StoreState {
249            mode: StoreMode::Local {
250                downstream_client, ..
251            },
252            ..
253        }) = self
254        {
255            *downstream_client = Some((new_downstream_client, remote_id));
256        }
257    }
258
259    pub fn unshared(&mut self, _: &mut Context<Self>) {
260        if let Self::Functional(StoreState {
261            mode: StoreMode::Local {
262                downstream_client, ..
263            },
264            ..
265        }) = self
266        {
267            *downstream_client = None;
268        }
269    }
270
271    pub(super) fn update_user_tasks(
272        &self,
273        location: TaskSettingsLocation<'_>,
274        raw_tasks_json: Option<&str>,
275        cx: &mut Context<Self>,
276    ) -> Result<(), InvalidSettingsError> {
277        let task_inventory = match self {
278            TaskStore::Functional(state) => &state.task_inventory,
279            TaskStore::Noop => return Ok(()),
280        };
281        let raw_tasks_json = raw_tasks_json
282            .map(|json| json.trim())
283            .filter(|json| !json.is_empty());
284
285        task_inventory.update(cx, |inventory, _| {
286            inventory.update_file_based_tasks(location, raw_tasks_json)
287        })
288    }
289
290    pub(super) fn update_user_debug_scenarios(
291        &self,
292        location: TaskSettingsLocation<'_>,
293        raw_tasks_json: Option<&str>,
294        cx: &mut Context<Self>,
295    ) -> Result<(), InvalidSettingsError> {
296        let task_inventory = match self {
297            TaskStore::Functional(state) => &state.task_inventory,
298            TaskStore::Noop => return Ok(()),
299        };
300        let raw_tasks_json = raw_tasks_json
301            .map(|json| json.trim())
302            .filter(|json| !json.is_empty());
303
304        task_inventory.update(cx, |inventory, _| {
305            inventory.update_file_based_scenarios(location, raw_tasks_json)
306        })
307    }
308}
309
310fn local_task_context_for_location(
311    worktree_store: Entity<WorktreeStore>,
312    git_store: Entity<GitStore>,
313    toolchain_store: Arc<dyn LanguageToolchainStore>,
314    environment: Entity<ProjectEnvironment>,
315    captured_variables: TaskVariables,
316    location: Location,
317    cx: &App,
318) -> Task<Option<TaskContext>> {
319    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
320    let worktree_abs_path = worktree_id
321        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
322        .and_then(|worktree| worktree.read(cx).root_dir());
323    let fs = worktree_store.read(cx).fs();
324
325    cx.spawn(async move |cx| {
326        let project_env = environment
327            .update(cx, |environment, cx| {
328                environment.buffer_environment(&location.buffer, &worktree_store, cx)
329            })
330            .await;
331
332        let mut task_variables = cx
333            .update(|cx| {
334                combine_task_variables(
335                    captured_variables,
336                    fs,
337                    worktree_store.clone(),
338                    location,
339                    project_env.clone(),
340                    BasicContextProvider::new(worktree_store, git_store),
341                    toolchain_store,
342                    cx,
343                )
344            })
345            .await
346            .log_err()?;
347        // Remove all custom entries starting with _, as they're not intended for use by the end user.
348        task_variables.sweep();
349
350        Some(TaskContext {
351            project_env: project_env.unwrap_or_default(),
352            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
353            task_variables,
354        })
355    })
356}
357
358fn remote_task_context_for_location(
359    project_id: u64,
360    upstream_client: AnyProtoClient,
361    worktree_store: Entity<WorktreeStore>,
362    git_store: Entity<GitStore>,
363    captured_variables: TaskVariables,
364    location: Location,
365    toolchain_store: Arc<dyn LanguageToolchainStore>,
366    cx: &mut App,
367) -> Task<Option<TaskContext>> {
368    cx.spawn(async move |cx| {
369        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
370        let mut remote_context = cx
371            .update(|cx| {
372                let worktree_root = worktree_root(&worktree_store, &location, cx);
373
374                BasicContextProvider::new(worktree_store, git_store).build_context(
375                    &TaskVariables::default(),
376                    ContextLocation {
377                        fs: None,
378                        worktree_root,
379                        file_location: &location,
380                    },
381                    None,
382                    toolchain_store,
383                    cx,
384                )
385            })
386            .await
387            .log_err()
388            .unwrap_or_default();
389        remote_context.extend(captured_variables);
390
391        let buffer_id = cx.update(|cx| location.buffer.read(cx).remote_id().to_proto());
392        let context_task = upstream_client.request(proto::TaskContextForLocation {
393            project_id,
394            location: Some(proto::Location {
395                buffer_id,
396                start: Some(serialize_anchor(&location.range.start)),
397                end: Some(serialize_anchor(&location.range.end)),
398            }),
399            task_variables: remote_context
400                .into_iter()
401                .map(|(k, v)| (k.to_string(), v))
402                .collect(),
403        });
404        let task_context = context_task.await.log_err()?;
405        Some(TaskContext {
406            cwd: task_context.cwd.map(PathBuf::from),
407            task_variables: task_context
408                .task_variables
409                .into_iter()
410                .filter_map(
411                    |(variable_name, variable_value)| match variable_name.parse() {
412                        Ok(variable_name) => Some((variable_name, variable_value)),
413                        Err(()) => {
414                            log::error!("Unknown variable name: {variable_name}");
415                            None
416                        }
417                    },
418                )
419                .collect(),
420            project_env: task_context.project_env.into_iter().collect(),
421        })
422    })
423}
424
425fn worktree_root(
426    worktree_store: &Entity<WorktreeStore>,
427    location: &Location,
428    cx: &mut App,
429) -> Option<PathBuf> {
430    location
431        .buffer
432        .read(cx)
433        .file()
434        .map(|f| f.worktree_id(cx))
435        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
436        .and_then(|worktree| {
437            let worktree = worktree.read(cx);
438            if !worktree.is_visible() {
439                return None;
440            }
441            let root_entry = worktree.root_entry()?;
442            if !root_entry.is_dir() {
443                return None;
444            }
445            Some(worktree.absolutize(&root_entry.path))
446        })
447}
448
449fn combine_task_variables(
450    mut captured_variables: TaskVariables,
451    fs: Option<Arc<dyn Fs>>,
452    worktree_store: Entity<WorktreeStore>,
453    location: Location,
454    project_env: Option<HashMap<String, String>>,
455    baseline: BasicContextProvider,
456    toolchain_store: Arc<dyn LanguageToolchainStore>,
457    cx: &mut App,
458) -> Task<anyhow::Result<TaskVariables>> {
459    let language_context_provider = location
460        .buffer
461        .read(cx)
462        .language()
463        .and_then(|language| language.context_provider());
464    cx.spawn(async move |cx| {
465        let baseline = cx
466            .update(|cx| {
467                let worktree_root = worktree_root(&worktree_store, &location, cx);
468                baseline.build_context(
469                    &captured_variables,
470                    ContextLocation {
471                        fs: fs.clone(),
472                        worktree_root,
473                        file_location: &location,
474                    },
475                    project_env.clone(),
476                    toolchain_store.clone(),
477                    cx,
478                )
479            })
480            .await
481            .context("building basic default context")?;
482        captured_variables.extend(baseline);
483        if let Some(provider) = language_context_provider {
484            captured_variables.extend(
485                cx.update(|cx| {
486                    let worktree_root = worktree_root(&worktree_store, &location, cx);
487                    provider.build_context(
488                        &captured_variables,
489                        ContextLocation {
490                            fs,
491                            worktree_root,
492                            file_location: &location,
493                        },
494                        project_env,
495                        toolchain_store,
496                        cx,
497                    )
498                })
499                .await
500                .context("building provider context")?,
501            );
502        }
503        Ok(captured_variables)
504    })
505}