task_store.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::Context as _;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt as _;
  7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
  8use language::{
  9    proto::{deserialize_anchor, serialize_anchor},
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11};
 12use rpc::{proto, AnyProtoClient, TypedEnvelope};
 13use settings::{watch_config_file, SettingsLocation};
 14use task::{TaskContext, TaskVariables, VariableName};
 15use text::BufferId;
 16use util::ResultExt;
 17
 18use crate::{
 19    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 20    ProjectEnvironment,
 21};
 22
/// Owns the project's task inventory and computes task contexts.
///
/// A `Functional` store does the actual work, either locally or by forwarding requests to an
/// upstream peer (see [`StoreState`]). A `Noop` store:
/// - has no inventory;
/// - ignores user task updates;
/// - resolves every task context to `None`;
/// - rejects task context requests coming from peers.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by [`StoreState`].
    Functional(StoreState),
    /// An inert store that performs no task-related work.
    Noop,
}
 28
/// State held by a [`TaskStore::Functional`] store.
pub struct StoreState {
    /// Whether task contexts are computed in-process (`Local`) or requested upstream (`Remote`).
    mode: StoreMode,
    /// The task inventory; the global (user) tasks file is loaded into it.
    task_inventory: Model<Inventory>,
    /// Weak handle, upgraded when a peer's task context request is handled; that request fails
    /// if the buffer store is gone by then.
    buffer_store: WeakModel<BufferStore>,
    /// Used to resolve the worktree (and its root directory) of a location's buffer.
    worktree_store: Model<WorktreeStore>,
    /// Passed through to the context providers that build task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
    /// Handle of the task watching the global tasks file. It is never read; it is held so the
    /// watcher lives as long as the store (presumably dropping a gpui `Task` cancels it —
    /// confirm against gpui).
    _global_task_config_watcher: Task<()>,
}
 37
/// How a functional [`TaskStore`] obtains task contexts.
enum StoreMode {
    /// Task contexts are computed in this process.
    Local {
        /// `Some((client, remote_id))` between `TaskStore::shared` and `TaskStore::unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Supplies the project environment that ends up in `TaskContext::project_env`.
        environment: Model<ProjectEnvironment>,
    },
    /// Task contexts are requested from the upstream peer.
    Remote {
        /// Client used to send `TaskContextForLocation` requests upstream.
        upstream_client: AnyProtoClient,
        /// Project id included in every upstream request.
        project_id: u64,
    },
}
 48
// Lets the store emit project events, e.g. `crate::Event::Toast` when the global tasks file
// fails to load.
impl EventEmitter<crate::Event> for TaskStore {}
 50
 51impl TaskStore {
 52    pub fn init(client: Option<&AnyProtoClient>) {
 53        if let Some(client) = client {
 54            client.add_model_request_handler(Self::handle_task_context_for_location);
 55        }
 56    }
 57
 58    async fn handle_task_context_for_location(
 59        store: Model<Self>,
 60        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 61        mut cx: AsyncAppContext,
 62    ) -> anyhow::Result<proto::TaskContext> {
 63        let location = envelope
 64            .payload
 65            .location
 66            .context("no location given for task context handling")?;
 67        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 68            Ok(match store {
 69                TaskStore::Functional(state) => (
 70                    state.buffer_store.clone(),
 71                    match &state.mode {
 72                        StoreMode::Local { .. } => false,
 73                        StoreMode::Remote { .. } => true,
 74                    },
 75                ),
 76                TaskStore::Noop => {
 77                    anyhow::bail!("empty task store cannot handle task context requests")
 78                }
 79            })
 80        })??;
 81        let buffer_store = buffer_store
 82            .upgrade()
 83            .context("no buffer store when handling task context request")?;
 84
 85        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 86            format!(
 87                "cannot handle task context request for invalid buffer id: {}",
 88                location.buffer_id
 89            )
 90        })?;
 91
 92        let start = location
 93            .start
 94            .and_then(deserialize_anchor)
 95            .context("missing task context location start")?;
 96        let end = location
 97            .end
 98            .and_then(deserialize_anchor)
 99            .context("missing task context location end")?;
100        let buffer = buffer_store
101            .update(&mut cx, |buffer_store, cx| {
102                if is_remote {
103                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
104                } else {
105                    Task::ready(
106                        buffer_store
107                            .get(buffer_id)
108                            .with_context(|| format!("no local buffer with id {buffer_id}")),
109                    )
110                }
111            })?
112            .await?;
113
114        let location = Location {
115            buffer,
116            range: start..end,
117        };
118        let context_task = store.update(&mut cx, |store, cx| {
119            let captured_variables = {
120                let mut variables = TaskVariables::from_iter(
121                    envelope
122                        .payload
123                        .task_variables
124                        .into_iter()
125                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126                );
127
128                for range in location
129                    .buffer
130                    .read(cx)
131                    .snapshot()
132                    .runnable_ranges(location.range.clone())
133                {
134                    for (capture_name, value) in range.extra_captures {
135                        variables.insert(VariableName::Custom(capture_name.into()), value);
136                    }
137                }
138                variables
139            };
140            store.task_context_for_location(captured_variables, location, cx)
141        })?;
142        let task_context = context_task.await.unwrap_or_default();
143        Ok(proto::TaskContext {
144            project_env: task_context.project_env.into_iter().collect(),
145            cwd: task_context
146                .cwd
147                .map(|cwd| cwd.to_string_lossy().to_string()),
148            task_variables: task_context
149                .task_variables
150                .into_iter()
151                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
152                .collect(),
153        })
154    }
155
156    pub fn local(
157        fs: Arc<dyn Fs>,
158        buffer_store: WeakModel<BufferStore>,
159        worktree_store: Model<WorktreeStore>,
160        toolchain_store: Arc<dyn LanguageToolchainStore>,
161        environment: Model<ProjectEnvironment>,
162        cx: &mut ModelContext<'_, Self>,
163    ) -> Self {
164        Self::Functional(StoreState {
165            mode: StoreMode::Local {
166                downstream_client: None,
167                environment,
168            },
169            task_inventory: Inventory::new(cx),
170            buffer_store,
171            toolchain_store,
172            worktree_store,
173            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
174        })
175    }
176
177    pub fn remote(
178        fs: Arc<dyn Fs>,
179        buffer_store: WeakModel<BufferStore>,
180        worktree_store: Model<WorktreeStore>,
181        toolchain_store: Arc<dyn LanguageToolchainStore>,
182        upstream_client: AnyProtoClient,
183        project_id: u64,
184        cx: &mut ModelContext<'_, Self>,
185    ) -> Self {
186        Self::Functional(StoreState {
187            mode: StoreMode::Remote {
188                upstream_client,
189                project_id,
190            },
191            task_inventory: Inventory::new(cx),
192            buffer_store,
193            toolchain_store,
194            worktree_store,
195            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
196        })
197    }
198
199    pub fn task_context_for_location(
200        &self,
201        captured_variables: TaskVariables,
202        location: Location,
203        cx: &mut AppContext,
204    ) -> Task<Option<TaskContext>> {
205        match self {
206            TaskStore::Functional(state) => match &state.mode {
207                StoreMode::Local { environment, .. } => local_task_context_for_location(
208                    state.worktree_store.clone(),
209                    state.toolchain_store.clone(),
210                    environment.clone(),
211                    captured_variables,
212                    location,
213                    cx,
214                ),
215                StoreMode::Remote {
216                    upstream_client,
217                    project_id,
218                } => remote_task_context_for_location(
219                    *project_id,
220                    upstream_client.clone(),
221                    state.worktree_store.clone(),
222                    captured_variables,
223                    location,
224                    state.toolchain_store.clone(),
225                    cx,
226                ),
227            },
228            TaskStore::Noop => Task::ready(None),
229        }
230    }
231
232    pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
233        match self {
234            TaskStore::Functional(state) => Some(&state.task_inventory),
235            TaskStore::Noop => None,
236        }
237    }
238
239    pub fn shared(
240        &mut self,
241        remote_id: u64,
242        new_downstream_client: AnyProtoClient,
243        _cx: &mut AppContext,
244    ) {
245        if let Self::Functional(StoreState {
246            mode: StoreMode::Local {
247                downstream_client, ..
248            },
249            ..
250        }) = self
251        {
252            *downstream_client = Some((new_downstream_client, remote_id));
253        }
254    }
255
256    pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
257        if let Self::Functional(StoreState {
258            mode: StoreMode::Local {
259                downstream_client, ..
260            },
261            ..
262        }) = self
263        {
264            *downstream_client = None;
265        }
266    }
267
268    pub(super) fn update_user_tasks(
269        &self,
270        location: Option<SettingsLocation<'_>>,
271        raw_tasks_json: Option<&str>,
272        cx: &mut ModelContext<'_, Self>,
273    ) -> anyhow::Result<()> {
274        let task_inventory = match self {
275            TaskStore::Functional(state) => &state.task_inventory,
276            TaskStore::Noop => return Ok(()),
277        };
278        let raw_tasks_json = raw_tasks_json
279            .map(|json| json.trim())
280            .filter(|json| !json.is_empty());
281
282        task_inventory.update(cx, |inventory, _| {
283            inventory.update_file_based_tasks(location, raw_tasks_json)
284        })
285    }
286
287    fn subscribe_to_global_task_file_changes(
288        fs: Arc<dyn Fs>,
289        cx: &mut ModelContext<'_, Self>,
290    ) -> Task<()> {
291        let mut user_tasks_file_rx =
292            watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
293        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
294        cx.spawn(move |task_store, mut cx| async move {
295            if let Some(user_tasks_content) = user_tasks_content {
296                let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
297                    task_store
298                        .update_user_tasks(None, Some(&user_tasks_content), cx)
299                        .log_err();
300                }) else {
301                    return;
302                };
303            }
304            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
305                let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
306                    let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
307                    if let Err(err) = &result {
308                        log::error!("Failed to load user tasks: {err}");
309                        cx.emit(crate::Event::Toast {
310                            notification_id: "load-user-tasks".into(),
311                            message: format!("Invalid global tasks file\n{err}"),
312                        });
313                    }
314                    cx.refresh();
315                }) else {
316                    break; // App dropped
317                };
318            }
319        })
320    }
321}
322
323fn local_task_context_for_location(
324    worktree_store: Model<WorktreeStore>,
325    toolchain_store: Arc<dyn LanguageToolchainStore>,
326    environment: Model<ProjectEnvironment>,
327    captured_variables: TaskVariables,
328    location: Location,
329    cx: &AppContext,
330) -> Task<Option<TaskContext>> {
331    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
332    let worktree_abs_path = worktree_id
333        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
334        .and_then(|worktree| worktree.read(cx).root_dir());
335
336    cx.spawn(|mut cx| async move {
337        let worktree_abs_path = worktree_abs_path.clone();
338        let project_env = environment
339            .update(&mut cx, |environment, cx| {
340                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
341            })
342            .ok()?
343            .await;
344
345        let mut task_variables = cx
346            .update(|cx| {
347                combine_task_variables(
348                    captured_variables,
349                    location,
350                    project_env.clone(),
351                    BasicContextProvider::new(worktree_store),
352                    toolchain_store,
353                    cx,
354                )
355            })
356            .ok()?
357            .await
358            .log_err()?;
359        // Remove all custom entries starting with _, as they're not intended for use by the end user.
360        task_variables.sweep();
361
362        Some(TaskContext {
363            project_env: project_env.unwrap_or_default(),
364            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
365            task_variables,
366        })
367    })
368}
369
370fn remote_task_context_for_location(
371    project_id: u64,
372    upstream_client: AnyProtoClient,
373    worktree_store: Model<WorktreeStore>,
374    captured_variables: TaskVariables,
375    location: Location,
376    toolchain_store: Arc<dyn LanguageToolchainStore>,
377    cx: &mut AppContext,
378) -> Task<Option<TaskContext>> {
379    cx.spawn(|cx| async move {
380        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
381        let mut remote_context = cx
382            .update(|cx| {
383                BasicContextProvider::new(worktree_store).build_context(
384                    &TaskVariables::default(),
385                    &location,
386                    None,
387                    toolchain_store,
388                    cx,
389                )
390            })
391            .ok()?
392            .await
393            .log_err()
394            .unwrap_or_default();
395        remote_context.extend(captured_variables);
396
397        let buffer_id = cx
398            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
399            .ok()?;
400        let context_task = upstream_client.request(proto::TaskContextForLocation {
401            project_id,
402            location: Some(proto::Location {
403                buffer_id,
404                start: Some(serialize_anchor(&location.range.start)),
405                end: Some(serialize_anchor(&location.range.end)),
406            }),
407            task_variables: remote_context
408                .into_iter()
409                .map(|(k, v)| (k.to_string(), v))
410                .collect(),
411        });
412        let task_context = context_task.await.log_err()?;
413        Some(TaskContext {
414            cwd: task_context.cwd.map(PathBuf::from),
415            task_variables: task_context
416                .task_variables
417                .into_iter()
418                .filter_map(
419                    |(variable_name, variable_value)| match variable_name.parse() {
420                        Ok(variable_name) => Some((variable_name, variable_value)),
421                        Err(()) => {
422                            log::error!("Unknown variable name: {variable_name}");
423                            None
424                        }
425                    },
426                )
427                .collect(),
428            project_env: task_context.project_env.into_iter().collect(),
429        })
430    })
431}
432
433fn combine_task_variables(
434    mut captured_variables: TaskVariables,
435    location: Location,
436    project_env: Option<HashMap<String, String>>,
437    baseline: BasicContextProvider,
438    toolchain_store: Arc<dyn LanguageToolchainStore>,
439    cx: &mut AppContext,
440) -> Task<anyhow::Result<TaskVariables>> {
441    let language_context_provider = location
442        .buffer
443        .read(cx)
444        .language()
445        .and_then(|language| language.context_provider());
446    cx.spawn(move |cx| async move {
447        let baseline = cx
448            .update(|cx| {
449                baseline.build_context(
450                    &captured_variables,
451                    &location,
452                    project_env.clone(),
453                    toolchain_store.clone(),
454                    cx,
455                )
456            })?
457            .await
458            .context("building basic default context")?;
459        captured_variables.extend(baseline);
460        if let Some(provider) = language_context_provider {
461            captured_variables.extend(
462                cx.update(|cx| {
463                    provider.build_context(
464                        &captured_variables,
465                        &location,
466                        project_env,
467                        toolchain_store,
468                        cx,
469                    )
470                })?
471                .await
472                .context("building provider context")?,
473            );
474        }
475        Ok(captured_variables)
476    })
477}