task_store.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4};
  5
  6use anyhow::Context as _;
  7use collections::HashMap;
  8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  9use language::{
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11    proto::{deserialize_anchor, serialize_anchor},
 12};
 13use rpc::{AnyProtoClient, TypedEnvelope, proto};
 14use settings::{InvalidSettingsError, SettingsLocation};
 15use task::{TaskContext, TaskVariables, VariableName};
 16use text::{BufferId, OffsetRangeExt};
 17use util::ResultExt;
 18
 19use crate::{
 20    BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
 21    worktree_store::WorktreeStore,
 22};
 23
 24#[allow(clippy::large_enum_variant)] // platform-dependent warning
 25pub enum TaskStore {
 26    Functional(StoreState),
 27    Noop,
 28}
 29
 30pub struct StoreState {
 31    mode: StoreMode,
 32    task_inventory: Entity<Inventory>,
 33    buffer_store: WeakEntity<BufferStore>,
 34    worktree_store: Entity<WorktreeStore>,
 35    toolchain_store: Arc<dyn LanguageToolchainStore>,
 36}
 37
 38enum StoreMode {
 39    Local {
 40        downstream_client: Option<(AnyProtoClient, u64)>,
 41        environment: Entity<ProjectEnvironment>,
 42    },
 43    Remote {
 44        upstream_client: AnyProtoClient,
 45        project_id: u64,
 46    },
 47}
 48
 49impl EventEmitter<crate::Event> for TaskStore {}
 50
 51#[derive(Debug)]
 52pub enum TaskSettingsLocation<'a> {
 53    Global(&'a Path),
 54    Worktree(SettingsLocation<'a>),
 55}
 56
 57impl TaskStore {
 58    pub fn init(client: Option<&AnyProtoClient>) {
 59        if let Some(client) = client {
 60            client.add_entity_request_handler(Self::handle_task_context_for_location);
 61        }
 62    }
 63
 64    async fn handle_task_context_for_location(
 65        store: Entity<Self>,
 66        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 67        mut cx: AsyncApp,
 68    ) -> anyhow::Result<proto::TaskContext> {
 69        let location = envelope
 70            .payload
 71            .location
 72            .context("no location given for task context handling")?;
 73        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 74            Ok(match store {
 75                TaskStore::Functional(state) => (
 76                    state.buffer_store.clone(),
 77                    match &state.mode {
 78                        StoreMode::Local { .. } => false,
 79                        StoreMode::Remote { .. } => true,
 80                    },
 81                ),
 82                TaskStore::Noop => {
 83                    anyhow::bail!("empty task store cannot handle task context requests")
 84                }
 85            })
 86        })??;
 87        let buffer_store = buffer_store
 88            .upgrade()
 89            .context("no buffer store when handling task context request")?;
 90
 91        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 92            format!(
 93                "cannot handle task context request for invalid buffer id: {}",
 94                location.buffer_id
 95            )
 96        })?;
 97
 98        let start = location
 99            .start
100            .and_then(deserialize_anchor)
101            .context("missing task context location start")?;
102        let end = location
103            .end
104            .and_then(deserialize_anchor)
105            .context("missing task context location end")?;
106        let buffer = buffer_store
107            .update(&mut cx, |buffer_store, cx| {
108                if is_remote {
109                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
110                } else {
111                    Task::ready(
112                        buffer_store
113                            .get(buffer_id)
114                            .with_context(|| format!("no local buffer with id {buffer_id}")),
115                    )
116                }
117            })?
118            .await?;
119
120        let location = Location {
121            buffer,
122            range: start..end,
123        };
124        let context_task = store.update(&mut cx, |store, cx| {
125            let captured_variables = {
126                let mut variables = TaskVariables::from_iter(
127                    envelope
128                        .payload
129                        .task_variables
130                        .into_iter()
131                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132                );
133
134                let snapshot = location.buffer.read(cx).snapshot();
135                let range = location.range.to_offset(&snapshot);
136
137                for range in snapshot.runnable_ranges(range) {
138                    for (capture_name, value) in range.extra_captures {
139                        variables.insert(VariableName::Custom(capture_name.into()), value);
140                    }
141                }
142                variables
143            };
144            store.task_context_for_location(captured_variables, location, cx)
145        })?;
146        let task_context = context_task.await.unwrap_or_default();
147        Ok(proto::TaskContext {
148            project_env: task_context.project_env.into_iter().collect(),
149            cwd: task_context
150                .cwd
151                .map(|cwd| cwd.to_string_lossy().to_string()),
152            task_variables: task_context
153                .task_variables
154                .into_iter()
155                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156                .collect(),
157        })
158    }
159
160    pub fn local(
161        buffer_store: WeakEntity<BufferStore>,
162        worktree_store: Entity<WorktreeStore>,
163        toolchain_store: Arc<dyn LanguageToolchainStore>,
164        environment: Entity<ProjectEnvironment>,
165        cx: &mut Context<Self>,
166    ) -> Self {
167        Self::Functional(StoreState {
168            mode: StoreMode::Local {
169                downstream_client: None,
170                environment,
171            },
172            task_inventory: Inventory::new(cx),
173            buffer_store,
174            toolchain_store,
175            worktree_store,
176        })
177    }
178
179    pub fn remote(
180        buffer_store: WeakEntity<BufferStore>,
181        worktree_store: Entity<WorktreeStore>,
182        toolchain_store: Arc<dyn LanguageToolchainStore>,
183        upstream_client: AnyProtoClient,
184        project_id: u64,
185        cx: &mut Context<Self>,
186    ) -> Self {
187        Self::Functional(StoreState {
188            mode: StoreMode::Remote {
189                upstream_client,
190                project_id,
191            },
192            task_inventory: Inventory::new(cx),
193            buffer_store,
194            toolchain_store,
195            worktree_store,
196        })
197    }
198
199    pub fn task_context_for_location(
200        &self,
201        captured_variables: TaskVariables,
202        location: Location,
203        cx: &mut App,
204    ) -> Task<Option<TaskContext>> {
205        match self {
206            TaskStore::Functional(state) => match &state.mode {
207                StoreMode::Local { environment, .. } => local_task_context_for_location(
208                    state.worktree_store.clone(),
209                    state.toolchain_store.clone(),
210                    environment.clone(),
211                    captured_variables,
212                    location,
213                    cx,
214                ),
215                StoreMode::Remote {
216                    upstream_client,
217                    project_id,
218                } => remote_task_context_for_location(
219                    *project_id,
220                    upstream_client.clone(),
221                    state.worktree_store.clone(),
222                    captured_variables,
223                    location,
224                    state.toolchain_store.clone(),
225                    cx,
226                ),
227            },
228            TaskStore::Noop => Task::ready(None),
229        }
230    }
231
232    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233        match self {
234            TaskStore::Functional(state) => Some(&state.task_inventory),
235            TaskStore::Noop => None,
236        }
237    }
238
239    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240        if let Self::Functional(StoreState {
241            mode: StoreMode::Local {
242                downstream_client, ..
243            },
244            ..
245        }) = self
246        {
247            *downstream_client = Some((new_downstream_client, remote_id));
248        }
249    }
250
251    pub fn unshared(&mut self, _: &mut Context<Self>) {
252        if let Self::Functional(StoreState {
253            mode: StoreMode::Local {
254                downstream_client, ..
255            },
256            ..
257        }) = self
258        {
259            *downstream_client = None;
260        }
261    }
262
263    pub(super) fn update_user_tasks(
264        &self,
265        location: TaskSettingsLocation<'_>,
266        raw_tasks_json: Option<&str>,
267        cx: &mut Context<Self>,
268    ) -> Result<(), InvalidSettingsError> {
269        let task_inventory = match self {
270            TaskStore::Functional(state) => &state.task_inventory,
271            TaskStore::Noop => return Ok(()),
272        };
273        let raw_tasks_json = raw_tasks_json
274            .map(|json| json.trim())
275            .filter(|json| !json.is_empty());
276
277        task_inventory.update(cx, |inventory, _| {
278            inventory.update_file_based_tasks(location, raw_tasks_json)
279        })
280    }
281
282    pub(super) fn update_user_debug_scenarios(
283        &self,
284        location: TaskSettingsLocation<'_>,
285        raw_tasks_json: Option<&str>,
286        cx: &mut Context<Self>,
287    ) -> Result<(), InvalidSettingsError> {
288        let task_inventory = match self {
289            TaskStore::Functional(state) => &state.task_inventory,
290            TaskStore::Noop => return Ok(()),
291        };
292        let raw_tasks_json = raw_tasks_json
293            .map(|json| json.trim())
294            .filter(|json| !json.is_empty());
295
296        task_inventory.update(cx, |inventory, _| {
297            inventory.update_file_based_scenarios(location, raw_tasks_json)
298        })
299    }
300}
301
302fn local_task_context_for_location(
303    worktree_store: Entity<WorktreeStore>,
304    toolchain_store: Arc<dyn LanguageToolchainStore>,
305    environment: Entity<ProjectEnvironment>,
306    captured_variables: TaskVariables,
307    location: Location,
308    cx: &App,
309) -> Task<Option<TaskContext>> {
310    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
311    let worktree_abs_path = worktree_id
312        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
313        .and_then(|worktree| worktree.read(cx).root_dir());
314
315    cx.spawn(async move |cx| {
316        let project_env = environment
317            .update(cx, |environment, cx| {
318                environment.get_buffer_environment(
319                    location.buffer.clone(),
320                    worktree_store.clone(),
321                    cx,
322                )
323            })
324            .ok()?
325            .await;
326
327        let mut task_variables = cx
328            .update(|cx| {
329                combine_task_variables(
330                    captured_variables,
331                    location,
332                    project_env.clone(),
333                    BasicContextProvider::new(worktree_store),
334                    toolchain_store,
335                    cx,
336                )
337            })
338            .ok()?
339            .await
340            .log_err()?;
341        // Remove all custom entries starting with _, as they're not intended for use by the end user.
342        task_variables.sweep();
343
344        Some(TaskContext {
345            project_env: project_env.unwrap_or_default(),
346            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
347            task_variables,
348        })
349    })
350}
351
352fn remote_task_context_for_location(
353    project_id: u64,
354    upstream_client: AnyProtoClient,
355    worktree_store: Entity<WorktreeStore>,
356    captured_variables: TaskVariables,
357    location: Location,
358    toolchain_store: Arc<dyn LanguageToolchainStore>,
359    cx: &mut App,
360) -> Task<Option<TaskContext>> {
361    cx.spawn(async move |cx| {
362        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
363        let mut remote_context = cx
364            .update(|cx| {
365                BasicContextProvider::new(worktree_store).build_context(
366                    &TaskVariables::default(),
367                    &location,
368                    None,
369                    toolchain_store,
370                    cx,
371                )
372            })
373            .ok()?
374            .await
375            .log_err()
376            .unwrap_or_default();
377        remote_context.extend(captured_variables);
378
379        let buffer_id = cx
380            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
381            .ok()?;
382        let context_task = upstream_client.request(proto::TaskContextForLocation {
383            project_id,
384            location: Some(proto::Location {
385                buffer_id,
386                start: Some(serialize_anchor(&location.range.start)),
387                end: Some(serialize_anchor(&location.range.end)),
388            }),
389            task_variables: remote_context
390                .into_iter()
391                .map(|(k, v)| (k.to_string(), v))
392                .collect(),
393        });
394        let task_context = context_task.await.log_err()?;
395        Some(TaskContext {
396            cwd: task_context.cwd.map(PathBuf::from),
397            task_variables: task_context
398                .task_variables
399                .into_iter()
400                .filter_map(
401                    |(variable_name, variable_value)| match variable_name.parse() {
402                        Ok(variable_name) => Some((variable_name, variable_value)),
403                        Err(()) => {
404                            log::error!("Unknown variable name: {variable_name}");
405                            None
406                        }
407                    },
408                )
409                .collect(),
410            project_env: task_context.project_env.into_iter().collect(),
411        })
412    })
413}
414
415fn combine_task_variables(
416    mut captured_variables: TaskVariables,
417    location: Location,
418    project_env: Option<HashMap<String, String>>,
419    baseline: BasicContextProvider,
420    toolchain_store: Arc<dyn LanguageToolchainStore>,
421    cx: &mut App,
422) -> Task<anyhow::Result<TaskVariables>> {
423    let language_context_provider = location
424        .buffer
425        .read(cx)
426        .language()
427        .and_then(|language| language.context_provider());
428    cx.spawn(async move |cx| {
429        let baseline = cx
430            .update(|cx| {
431                baseline.build_context(
432                    &captured_variables,
433                    &location,
434                    project_env.clone(),
435                    toolchain_store.clone(),
436                    cx,
437                )
438            })?
439            .await
440            .context("building basic default context")?;
441        captured_variables.extend(baseline);
442        if let Some(provider) = language_context_provider {
443            captured_variables.extend(
444                cx.update(|cx| {
445                    provider.build_context(
446                        &captured_variables,
447                        &location,
448                        project_env,
449                        toolchain_store,
450                        cx,
451                    )
452                })?
453                .await
454                .context("building provider context")?,
455            );
456        }
457        Ok(captured_variables)
458    })
459}