task_store.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::Context as _;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt as _;
  7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
  8use language::{
  9    proto::{deserialize_anchor, serialize_anchor},
 10    ContextProvider as _, LanguageToolchainStore, Location,
 11};
 12use rpc::{proto, AnyProtoClient, TypedEnvelope};
 13use settings::{watch_config_file, SettingsLocation};
 14use task::{TaskContext, TaskVariables, VariableName};
 15use text::{BufferId, OffsetRangeExt};
 16use util::ResultExt;
 17
 18use crate::{
 19    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 20    ProjectEnvironment,
 21};
 22
 23#[allow(clippy::large_enum_variant)] // platform-dependent warning
 24pub enum TaskStore {
 25    Functional(StoreState),
 26    Noop,
 27}
 28
 29pub struct StoreState {
 30    mode: StoreMode,
 31    task_inventory: Model<Inventory>,
 32    buffer_store: WeakModel<BufferStore>,
 33    worktree_store: Model<WorktreeStore>,
 34    toolchain_store: Arc<dyn LanguageToolchainStore>,
 35    _global_task_config_watcher: Task<()>,
 36}
 37
 38enum StoreMode {
 39    Local {
 40        downstream_client: Option<(AnyProtoClient, u64)>,
 41        environment: Model<ProjectEnvironment>,
 42    },
 43    Remote {
 44        upstream_client: AnyProtoClient,
 45        project_id: u64,
 46    },
 47}
 48
 49impl EventEmitter<crate::Event> for TaskStore {}
 50
 51impl TaskStore {
 52    pub fn init(client: Option<&AnyProtoClient>) {
 53        if let Some(client) = client {
 54            client.add_model_request_handler(Self::handle_task_context_for_location);
 55        }
 56    }
 57
 58    async fn handle_task_context_for_location(
 59        store: Model<Self>,
 60        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 61        mut cx: AsyncAppContext,
 62    ) -> anyhow::Result<proto::TaskContext> {
 63        let location = envelope
 64            .payload
 65            .location
 66            .context("no location given for task context handling")?;
 67        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 68            Ok(match store {
 69                TaskStore::Functional(state) => (
 70                    state.buffer_store.clone(),
 71                    match &state.mode {
 72                        StoreMode::Local { .. } => false,
 73                        StoreMode::Remote { .. } => true,
 74                    },
 75                ),
 76                TaskStore::Noop => {
 77                    anyhow::bail!("empty task store cannot handle task context requests")
 78                }
 79            })
 80        })??;
 81        let buffer_store = buffer_store
 82            .upgrade()
 83            .context("no buffer store when handling task context request")?;
 84
 85        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 86            format!(
 87                "cannot handle task context request for invalid buffer id: {}",
 88                location.buffer_id
 89            )
 90        })?;
 91
 92        let start = location
 93            .start
 94            .and_then(deserialize_anchor)
 95            .context("missing task context location start")?;
 96        let end = location
 97            .end
 98            .and_then(deserialize_anchor)
 99            .context("missing task context location end")?;
100        let buffer = buffer_store
101            .update(&mut cx, |buffer_store, cx| {
102                if is_remote {
103                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
104                } else {
105                    Task::ready(
106                        buffer_store
107                            .get(buffer_id)
108                            .with_context(|| format!("no local buffer with id {buffer_id}")),
109                    )
110                }
111            })?
112            .await?;
113
114        let location = Location {
115            buffer,
116            range: start..end,
117        };
118        let context_task = store.update(&mut cx, |store, cx| {
119            let captured_variables = {
120                let mut variables = TaskVariables::from_iter(
121                    envelope
122                        .payload
123                        .task_variables
124                        .into_iter()
125                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126                );
127
128                let snapshot = location.buffer.read(cx).snapshot();
129                let range = location.range.to_offset(&snapshot);
130
131                for range in snapshot.runnable_ranges(range) {
132                    for (capture_name, value) in range.extra_captures {
133                        variables.insert(VariableName::Custom(capture_name.into()), value);
134                    }
135                }
136                variables
137            };
138            store.task_context_for_location(captured_variables, location, cx)
139        })?;
140        let task_context = context_task.await.unwrap_or_default();
141        Ok(proto::TaskContext {
142            project_env: task_context.project_env.into_iter().collect(),
143            cwd: task_context
144                .cwd
145                .map(|cwd| cwd.to_string_lossy().to_string()),
146            task_variables: task_context
147                .task_variables
148                .into_iter()
149                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150                .collect(),
151        })
152    }
153
154    pub fn local(
155        fs: Arc<dyn Fs>,
156        buffer_store: WeakModel<BufferStore>,
157        worktree_store: Model<WorktreeStore>,
158        toolchain_store: Arc<dyn LanguageToolchainStore>,
159        environment: Model<ProjectEnvironment>,
160        cx: &mut ModelContext<'_, Self>,
161    ) -> Self {
162        Self::Functional(StoreState {
163            mode: StoreMode::Local {
164                downstream_client: None,
165                environment,
166            },
167            task_inventory: Inventory::new(cx),
168            buffer_store,
169            toolchain_store,
170            worktree_store,
171            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
172        })
173    }
174
175    pub fn remote(
176        fs: Arc<dyn Fs>,
177        buffer_store: WeakModel<BufferStore>,
178        worktree_store: Model<WorktreeStore>,
179        toolchain_store: Arc<dyn LanguageToolchainStore>,
180        upstream_client: AnyProtoClient,
181        project_id: u64,
182        cx: &mut ModelContext<'_, Self>,
183    ) -> Self {
184        Self::Functional(StoreState {
185            mode: StoreMode::Remote {
186                upstream_client,
187                project_id,
188            },
189            task_inventory: Inventory::new(cx),
190            buffer_store,
191            toolchain_store,
192            worktree_store,
193            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
194        })
195    }
196
197    pub fn task_context_for_location(
198        &self,
199        captured_variables: TaskVariables,
200        location: Location,
201        cx: &mut AppContext,
202    ) -> Task<Option<TaskContext>> {
203        match self {
204            TaskStore::Functional(state) => match &state.mode {
205                StoreMode::Local { environment, .. } => local_task_context_for_location(
206                    state.worktree_store.clone(),
207                    state.toolchain_store.clone(),
208                    environment.clone(),
209                    captured_variables,
210                    location,
211                    cx,
212                ),
213                StoreMode::Remote {
214                    upstream_client,
215                    project_id,
216                } => remote_task_context_for_location(
217                    *project_id,
218                    upstream_client.clone(),
219                    state.worktree_store.clone(),
220                    captured_variables,
221                    location,
222                    state.toolchain_store.clone(),
223                    cx,
224                ),
225            },
226            TaskStore::Noop => Task::ready(None),
227        }
228    }
229
230    pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
231        match self {
232            TaskStore::Functional(state) => Some(&state.task_inventory),
233            TaskStore::Noop => None,
234        }
235    }
236
237    pub fn shared(
238        &mut self,
239        remote_id: u64,
240        new_downstream_client: AnyProtoClient,
241        _cx: &mut AppContext,
242    ) {
243        if let Self::Functional(StoreState {
244            mode: StoreMode::Local {
245                downstream_client, ..
246            },
247            ..
248        }) = self
249        {
250            *downstream_client = Some((new_downstream_client, remote_id));
251        }
252    }
253
254    pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
255        if let Self::Functional(StoreState {
256            mode: StoreMode::Local {
257                downstream_client, ..
258            },
259            ..
260        }) = self
261        {
262            *downstream_client = None;
263        }
264    }
265
266    pub(super) fn update_user_tasks(
267        &self,
268        location: Option<SettingsLocation<'_>>,
269        raw_tasks_json: Option<&str>,
270        cx: &mut ModelContext<'_, Self>,
271    ) -> anyhow::Result<()> {
272        let task_inventory = match self {
273            TaskStore::Functional(state) => &state.task_inventory,
274            TaskStore::Noop => return Ok(()),
275        };
276        let raw_tasks_json = raw_tasks_json
277            .map(|json| json.trim())
278            .filter(|json| !json.is_empty());
279
280        task_inventory.update(cx, |inventory, _| {
281            inventory.update_file_based_tasks(location, raw_tasks_json)
282        })
283    }
284
285    fn subscribe_to_global_task_file_changes(
286        fs: Arc<dyn Fs>,
287        cx: &mut ModelContext<'_, Self>,
288    ) -> Task<()> {
289        let mut user_tasks_file_rx =
290            watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
291        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
292        cx.spawn(move |task_store, mut cx| async move {
293            if let Some(user_tasks_content) = user_tasks_content {
294                let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
295                    task_store
296                        .update_user_tasks(None, Some(&user_tasks_content), cx)
297                        .log_err();
298                }) else {
299                    return;
300                };
301            }
302            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
303                let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
304                    let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
305                    if let Err(err) = &result {
306                        log::error!("Failed to load user tasks: {err}");
307                        cx.emit(crate::Event::Toast {
308                            notification_id: "load-user-tasks".into(),
309                            message: format!("Invalid global tasks file\n{err}"),
310                        });
311                    }
312                    cx.refresh();
313                }) else {
314                    break; // App dropped
315                };
316            }
317        })
318    }
319}
320
321fn local_task_context_for_location(
322    worktree_store: Model<WorktreeStore>,
323    toolchain_store: Arc<dyn LanguageToolchainStore>,
324    environment: Model<ProjectEnvironment>,
325    captured_variables: TaskVariables,
326    location: Location,
327    cx: &AppContext,
328) -> Task<Option<TaskContext>> {
329    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
330    let worktree_abs_path = worktree_id
331        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
332        .and_then(|worktree| worktree.read(cx).root_dir());
333
334    cx.spawn(|mut cx| async move {
335        let worktree_abs_path = worktree_abs_path.clone();
336        let project_env = environment
337            .update(&mut cx, |environment, cx| {
338                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
339            })
340            .ok()?
341            .await;
342
343        let mut task_variables = cx
344            .update(|cx| {
345                combine_task_variables(
346                    captured_variables,
347                    location,
348                    project_env.clone(),
349                    BasicContextProvider::new(worktree_store),
350                    toolchain_store,
351                    cx,
352                )
353            })
354            .ok()?
355            .await
356            .log_err()?;
357        // Remove all custom entries starting with _, as they're not intended for use by the end user.
358        task_variables.sweep();
359
360        Some(TaskContext {
361            project_env: project_env.unwrap_or_default(),
362            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
363            task_variables,
364        })
365    })
366}
367
368fn remote_task_context_for_location(
369    project_id: u64,
370    upstream_client: AnyProtoClient,
371    worktree_store: Model<WorktreeStore>,
372    captured_variables: TaskVariables,
373    location: Location,
374    toolchain_store: Arc<dyn LanguageToolchainStore>,
375    cx: &mut AppContext,
376) -> Task<Option<TaskContext>> {
377    cx.spawn(|cx| async move {
378        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
379        let mut remote_context = cx
380            .update(|cx| {
381                BasicContextProvider::new(worktree_store).build_context(
382                    &TaskVariables::default(),
383                    &location,
384                    None,
385                    toolchain_store,
386                    cx,
387                )
388            })
389            .ok()?
390            .await
391            .log_err()
392            .unwrap_or_default();
393        remote_context.extend(captured_variables);
394
395        let buffer_id = cx
396            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
397            .ok()?;
398        let context_task = upstream_client.request(proto::TaskContextForLocation {
399            project_id,
400            location: Some(proto::Location {
401                buffer_id,
402                start: Some(serialize_anchor(&location.range.start)),
403                end: Some(serialize_anchor(&location.range.end)),
404            }),
405            task_variables: remote_context
406                .into_iter()
407                .map(|(k, v)| (k.to_string(), v))
408                .collect(),
409        });
410        let task_context = context_task.await.log_err()?;
411        Some(TaskContext {
412            cwd: task_context.cwd.map(PathBuf::from),
413            task_variables: task_context
414                .task_variables
415                .into_iter()
416                .filter_map(
417                    |(variable_name, variable_value)| match variable_name.parse() {
418                        Ok(variable_name) => Some((variable_name, variable_value)),
419                        Err(()) => {
420                            log::error!("Unknown variable name: {variable_name}");
421                            None
422                        }
423                    },
424                )
425                .collect(),
426            project_env: task_context.project_env.into_iter().collect(),
427        })
428    })
429}
430
431fn combine_task_variables(
432    mut captured_variables: TaskVariables,
433    location: Location,
434    project_env: Option<HashMap<String, String>>,
435    baseline: BasicContextProvider,
436    toolchain_store: Arc<dyn LanguageToolchainStore>,
437    cx: &mut AppContext,
438) -> Task<anyhow::Result<TaskVariables>> {
439    let language_context_provider = location
440        .buffer
441        .read(cx)
442        .language()
443        .and_then(|language| language.context_provider());
444    cx.spawn(move |cx| async move {
445        let baseline = cx
446            .update(|cx| {
447                baseline.build_context(
448                    &captured_variables,
449                    &location,
450                    project_env.clone(),
451                    toolchain_store.clone(),
452                    cx,
453                )
454            })?
455            .await
456            .context("building basic default context")?;
457        captured_variables.extend(baseline);
458        if let Some(provider) = language_context_provider {
459            captured_variables.extend(
460                cx.update(|cx| {
461                    provider.build_context(
462                        &captured_variables,
463                        &location,
464                        project_env,
465                        toolchain_store,
466                        cx,
467                    )
468                })?
469                .await
470                .context("building provider context")?,
471            );
472        }
473        Ok(captured_variables)
474    })
475}