task_store.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4};
  5
  6use anyhow::Context as _;
  7use collections::HashMap;
  8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  9use language::{
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11    proto::{deserialize_anchor, serialize_anchor},
 12};
 13use rpc::{AnyProtoClient, TypedEnvelope, proto};
 14use settings::{InvalidSettingsError, SettingsLocation};
 15use task::{TaskContext, TaskVariables, VariableName};
 16use text::{BufferId, OffsetRangeExt};
 17use util::ResultExt;
 18
 19use crate::{
 20    BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
 21    worktree_store::WorktreeStore,
 22};
 23
/// Holds a project's task [`Inventory`] and resolves [`TaskContext`]s for buffer locations.
///
/// A `Noop` store has no inventory: it ignores settings updates, yields no task contexts,
/// and rejects incoming task-context requests.
// NOTE(review): this spot previously carried only the comment "platform-dependent warning",
// which reads like the rationale for a lint suppression (presumably
// `clippy::large_enum_variant`, since `Functional` is much larger than `Noop`) that is not
// present here — confirm whether the `#[allow]` was dropped intentionally.
pub enum TaskStore {
    /// A working store, computing contexts locally or requesting them from an upstream host.
    Functional(StoreState),
    /// An inert store with no inventory.
    Noop,
}
 29
/// State of a [`TaskStore::Functional`] store, shared by its local and remote modes.
pub struct StoreState {
    /// Whether task contexts are computed here or requested from an upstream host.
    mode: StoreMode,
    /// Inventory fed from file-based task and debug-scenario definitions.
    task_inventory: Entity<Inventory>,
    /// Weak handle; upgraded when a task-context request has to resolve a buffer.
    buffer_store: WeakEntity<BufferStore>,
    /// Used to build the basic task context and, locally, to find the buffer's worktree
    /// root, which becomes the task `cwd`.
    worktree_store: Entity<WorktreeStore>,
    /// Passed through to context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
}
 37
/// How a [`TaskStore`] obtains task contexts.
enum StoreMode {
    /// Task contexts are built in this process.
    Local {
        /// Set by `TaskStore::shared` to the downstream client and remote project id,
        /// cleared by `TaskStore::unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Provides the buffer's project environment variables.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from the upstream host over RPC.
    Remote {
        /// Client used to send `TaskContextForLocation` requests upstream.
        upstream_client: AnyProtoClient,
        /// Project id sent along with each upstream request.
        project_id: u64,
    },
}
 48
// Allows `TaskStore` entities to emit the project-level `crate::Event` type.
impl EventEmitter<crate::Event> for TaskStore {}
 50
/// Origin of the file-based task / debug-scenario definitions handed to
/// `TaskStore::update_user_tasks` and `TaskStore::update_user_debug_scenarios`.
#[derive(Debug)]
pub enum TaskSettingsLocation<'a> {
    /// Definitions read from a global location, identified by its path.
    Global(&'a Path),
    /// Definitions read from a worktree settings location.
    Worktree(SettingsLocation<'a>),
}
 56
 57impl TaskStore {
 58    pub fn init(client: Option<&AnyProtoClient>) {
 59        if let Some(client) = client {
 60            client.add_entity_request_handler(Self::handle_task_context_for_location);
 61        }
 62    }
 63
 64    async fn handle_task_context_for_location(
 65        store: Entity<Self>,
 66        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 67        mut cx: AsyncApp,
 68    ) -> anyhow::Result<proto::TaskContext> {
 69        let location = envelope
 70            .payload
 71            .location
 72            .context("no location given for task context handling")?;
 73        let (buffer_store, is_remote) = store.read_with(&mut cx, |store, _| {
 74            Ok(match store {
 75                TaskStore::Functional(state) => (
 76                    state.buffer_store.clone(),
 77                    match &state.mode {
 78                        StoreMode::Local { .. } => false,
 79                        StoreMode::Remote { .. } => true,
 80                    },
 81                ),
 82                TaskStore::Noop => {
 83                    anyhow::bail!("empty task store cannot handle task context requests")
 84                }
 85            })
 86        })??;
 87        let buffer_store = buffer_store
 88            .upgrade()
 89            .context("no buffer store when handling task context request")?;
 90
 91        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 92            format!(
 93                "cannot handle task context request for invalid buffer id: {}",
 94                location.buffer_id
 95            )
 96        })?;
 97
 98        let start = location
 99            .start
100            .and_then(deserialize_anchor)
101            .context("missing task context location start")?;
102        let end = location
103            .end
104            .and_then(deserialize_anchor)
105            .context("missing task context location end")?;
106        let buffer = buffer_store
107            .update(&mut cx, |buffer_store, cx| {
108                if is_remote {
109                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
110                } else {
111                    Task::ready(
112                        buffer_store
113                            .get(buffer_id)
114                            .with_context(|| format!("no local buffer with id {buffer_id}")),
115                    )
116                }
117            })?
118            .await?;
119
120        let location = Location {
121            buffer,
122            range: start..end,
123        };
124        let context_task = store.update(&mut cx, |store, cx| {
125            let captured_variables = {
126                let mut variables = TaskVariables::from_iter(
127                    envelope
128                        .payload
129                        .task_variables
130                        .into_iter()
131                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132                );
133
134                let snapshot = location.buffer.read(cx).snapshot();
135                let range = location.range.to_offset(&snapshot);
136
137                for range in snapshot.runnable_ranges(range) {
138                    for (capture_name, value) in range.extra_captures {
139                        variables.insert(VariableName::Custom(capture_name.into()), value);
140                    }
141                }
142                variables
143            };
144            store.task_context_for_location(captured_variables, location, cx)
145        })?;
146        let task_context = context_task.await.unwrap_or_default();
147        Ok(proto::TaskContext {
148            project_env: task_context.project_env.into_iter().collect(),
149            cwd: task_context
150                .cwd
151                .map(|cwd| cwd.to_string_lossy().to_string()),
152            task_variables: task_context
153                .task_variables
154                .into_iter()
155                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156                .collect(),
157        })
158    }
159
160    pub fn local(
161        buffer_store: WeakEntity<BufferStore>,
162        worktree_store: Entity<WorktreeStore>,
163        toolchain_store: Arc<dyn LanguageToolchainStore>,
164        environment: Entity<ProjectEnvironment>,
165        cx: &mut Context<Self>,
166    ) -> Self {
167        Self::Functional(StoreState {
168            mode: StoreMode::Local {
169                downstream_client: None,
170                environment,
171            },
172            task_inventory: Inventory::new(cx),
173            buffer_store,
174            toolchain_store,
175            worktree_store,
176        })
177    }
178
179    pub fn remote(
180        buffer_store: WeakEntity<BufferStore>,
181        worktree_store: Entity<WorktreeStore>,
182        toolchain_store: Arc<dyn LanguageToolchainStore>,
183        upstream_client: AnyProtoClient,
184        project_id: u64,
185        cx: &mut Context<Self>,
186    ) -> Self {
187        Self::Functional(StoreState {
188            mode: StoreMode::Remote {
189                upstream_client,
190                project_id,
191            },
192            task_inventory: Inventory::new(cx),
193            buffer_store,
194            toolchain_store,
195            worktree_store,
196        })
197    }
198
199    pub fn task_context_for_location(
200        &self,
201        captured_variables: TaskVariables,
202        location: Location,
203        cx: &mut App,
204    ) -> Task<Option<TaskContext>> {
205        match self {
206            TaskStore::Functional(state) => match &state.mode {
207                StoreMode::Local { environment, .. } => local_task_context_for_location(
208                    state.worktree_store.clone(),
209                    state.toolchain_store.clone(),
210                    environment.clone(),
211                    captured_variables,
212                    location,
213                    cx,
214                ),
215                StoreMode::Remote {
216                    upstream_client,
217                    project_id,
218                } => remote_task_context_for_location(
219                    *project_id,
220                    upstream_client.clone(),
221                    state.worktree_store.clone(),
222                    captured_variables,
223                    location,
224                    state.toolchain_store.clone(),
225                    cx,
226                ),
227            },
228            TaskStore::Noop => Task::ready(None),
229        }
230    }
231
232    pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233        match self {
234            TaskStore::Functional(state) => Some(&state.task_inventory),
235            TaskStore::Noop => None,
236        }
237    }
238
239    pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240        if let Self::Functional(StoreState {
241            mode: StoreMode::Local {
242                downstream_client, ..
243            },
244            ..
245        }) = self
246        {
247            *downstream_client = Some((new_downstream_client, remote_id));
248        }
249    }
250
251    pub fn unshared(&mut self, _: &mut Context<Self>) {
252        if let Self::Functional(StoreState {
253            mode: StoreMode::Local {
254                downstream_client, ..
255            },
256            ..
257        }) = self
258        {
259            *downstream_client = None;
260        }
261    }
262
263    pub(super) fn update_user_tasks(
264        &self,
265        location: TaskSettingsLocation<'_>,
266        raw_tasks_json: Option<&str>,
267        cx: &mut Context<Self>,
268    ) -> Result<(), InvalidSettingsError> {
269        let task_inventory = match self {
270            TaskStore::Functional(state) => &state.task_inventory,
271            TaskStore::Noop => return Ok(()),
272        };
273        let raw_tasks_json = raw_tasks_json
274            .map(|json| json.trim())
275            .filter(|json| !json.is_empty());
276
277        task_inventory.update(cx, |inventory, _| {
278            inventory.update_file_based_tasks(location, raw_tasks_json)
279        })
280    }
281
282    pub(super) fn update_user_debug_scenarios(
283        &self,
284        location: TaskSettingsLocation<'_>,
285        raw_tasks_json: Option<&str>,
286        cx: &mut Context<Self>,
287    ) -> Result<(), InvalidSettingsError> {
288        let task_inventory = match self {
289            TaskStore::Functional(state) => &state.task_inventory,
290            TaskStore::Noop => return Ok(()),
291        };
292        let raw_tasks_json = raw_tasks_json
293            .map(|json| json.trim())
294            .filter(|json| !json.is_empty());
295
296        task_inventory.update(cx, |inventory, _| {
297            inventory.update_file_based_scenarios(location, raw_tasks_json)
298        })
299    }
300}
301
302fn local_task_context_for_location(
303    worktree_store: Entity<WorktreeStore>,
304    toolchain_store: Arc<dyn LanguageToolchainStore>,
305    environment: Entity<ProjectEnvironment>,
306    captured_variables: TaskVariables,
307    location: Location,
308    cx: &App,
309) -> Task<Option<TaskContext>> {
310    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
311    let worktree_abs_path = worktree_id
312        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
313        .and_then(|worktree| worktree.read(cx).root_dir());
314
315    cx.spawn(async move |cx| {
316        let project_env = environment
317            .update(cx, |environment, cx| {
318                environment.get_buffer_environment(&location.buffer, &worktree_store, cx)
319            })
320            .ok()?
321            .await;
322
323        let mut task_variables = cx
324            .update(|cx| {
325                combine_task_variables(
326                    captured_variables,
327                    location,
328                    project_env.clone(),
329                    BasicContextProvider::new(worktree_store),
330                    toolchain_store,
331                    cx,
332                )
333            })
334            .ok()?
335            .await
336            .log_err()?;
337        // Remove all custom entries starting with _, as they're not intended for use by the end user.
338        task_variables.sweep();
339
340        Some(TaskContext {
341            project_env: project_env.unwrap_or_default(),
342            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
343            task_variables,
344        })
345    })
346}
347
348fn remote_task_context_for_location(
349    project_id: u64,
350    upstream_client: AnyProtoClient,
351    worktree_store: Entity<WorktreeStore>,
352    captured_variables: TaskVariables,
353    location: Location,
354    toolchain_store: Arc<dyn LanguageToolchainStore>,
355    cx: &mut App,
356) -> Task<Option<TaskContext>> {
357    cx.spawn(async move |cx| {
358        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
359        let mut remote_context = cx
360            .update(|cx| {
361                BasicContextProvider::new(worktree_store).build_context(
362                    &TaskVariables::default(),
363                    &location,
364                    None,
365                    toolchain_store,
366                    cx,
367                )
368            })
369            .ok()?
370            .await
371            .log_err()
372            .unwrap_or_default();
373        remote_context.extend(captured_variables);
374
375        let buffer_id = cx
376            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
377            .ok()?;
378        let context_task = upstream_client.request(proto::TaskContextForLocation {
379            project_id,
380            location: Some(proto::Location {
381                buffer_id,
382                start: Some(serialize_anchor(&location.range.start)),
383                end: Some(serialize_anchor(&location.range.end)),
384            }),
385            task_variables: remote_context
386                .into_iter()
387                .map(|(k, v)| (k.to_string(), v))
388                .collect(),
389        });
390        let task_context = context_task.await.log_err()?;
391        Some(TaskContext {
392            cwd: task_context.cwd.map(PathBuf::from),
393            task_variables: task_context
394                .task_variables
395                .into_iter()
396                .filter_map(
397                    |(variable_name, variable_value)| match variable_name.parse() {
398                        Ok(variable_name) => Some((variable_name, variable_value)),
399                        Err(()) => {
400                            log::error!("Unknown variable name: {variable_name}");
401                            None
402                        }
403                    },
404                )
405                .collect(),
406            project_env: task_context.project_env.into_iter().collect(),
407        })
408    })
409}
410
411fn combine_task_variables(
412    mut captured_variables: TaskVariables,
413    location: Location,
414    project_env: Option<HashMap<String, String>>,
415    baseline: BasicContextProvider,
416    toolchain_store: Arc<dyn LanguageToolchainStore>,
417    cx: &mut App,
418) -> Task<anyhow::Result<TaskVariables>> {
419    let language_context_provider = location
420        .buffer
421        .read(cx)
422        .language()
423        .and_then(|language| language.context_provider());
424    cx.spawn(async move |cx| {
425        let baseline = cx
426            .update(|cx| {
427                baseline.build_context(
428                    &captured_variables,
429                    &location,
430                    project_env.clone(),
431                    toolchain_store.clone(),
432                    cx,
433                )
434            })?
435            .await
436            .context("building basic default context")?;
437        captured_variables.extend(baseline);
438        if let Some(provider) = language_context_provider {
439            captured_variables.extend(
440                cx.update(|cx| {
441                    provider.build_context(
442                        &captured_variables,
443                        &location,
444                        project_env,
445                        toolchain_store,
446                        cx,
447                    )
448                })?
449                .await
450                .context("building provider context")?,
451            );
452        }
453        Ok(captured_variables)
454    })
455}