task_store.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::Context as _;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt as _;
  7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
  8use language::{
  9    proto::{deserialize_anchor, serialize_anchor},
 10    ContextProvider as _, Location,
 11};
 12use rpc::{proto, AnyProtoClient, TypedEnvelope};
 13use settings::{watch_config_file, SettingsLocation};
 14use task::{TaskContext, TaskVariables, VariableName};
 15use text::BufferId;
 16use util::ResultExt;
 17
 18use crate::{
 19    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 20    ProjectEnvironment,
 21};
 22
/// Holds a project's task inventory and resolves task contexts, either in-process or by
/// asking an upstream peer.
pub enum TaskStore {
    /// A working store; see [`StoreState`].
    Functional(StoreState),
    /// An inert store: it has no inventory, resolves task contexts to `None`, ignores user
    /// task updates, and rejects incoming task context requests with an error.
    Noop,
}
 27
/// State backing a [`TaskStore::Functional`] store, for both the local and remote flavours.
pub struct StoreState {
    /// Whether contexts are computed locally or requested from an upstream peer.
    mode: StoreMode,
    /// Inventory that file-based task definitions are loaded into.
    task_inventory: Model<Inventory>,
    /// Weak handle used to look up buffers when serving task context requests from peers.
    buffer_store: WeakModel<BufferStore>,
    /// Used to resolve worktrees (and their paths) when building task contexts.
    worktree_store: Model<WorktreeStore>,
    /// Held only to keep the global tasks file watcher alive (presumably dropping the `Task`
    /// cancels it — gpui convention, not shown here).
    _global_task_config_watcher: Task<()>,
}
 35
/// Where a functional store computes task contexts.
enum StoreMode {
    /// Task contexts are computed in this process.
    Local {
        /// The client and remote project id recorded by `shared`; cleared by `unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Source of the project environment variables placed into task contexts.
        environment: Model<ProjectEnvironment>,
    },
    /// Task contexts are requested over RPC from the upstream peer.
    Remote {
        /// Client used to send `TaskContextForLocation` requests upstream.
        upstream_client: AnyProtoClient,
        /// Project id included in every upstream request.
        project_id: u64,
    },
}
 46
// Lets the store emit project events, e.g. `crate::Event::Notification` when the global
// tasks file cannot be loaded.
impl EventEmitter<crate::Event> for TaskStore {}
 48
 49impl TaskStore {
 50    pub fn init(client: Option<&AnyProtoClient>) {
 51        if let Some(client) = client {
 52            client.add_model_request_handler(Self::handle_task_context_for_location);
 53        }
 54    }
 55
 56    async fn handle_task_context_for_location(
 57        store: Model<Self>,
 58        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 59        mut cx: AsyncAppContext,
 60    ) -> anyhow::Result<proto::TaskContext> {
 61        let location = envelope
 62            .payload
 63            .location
 64            .context("no location given for task context handling")?;
 65        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 66            Ok(match store {
 67                TaskStore::Functional(state) => (
 68                    state.buffer_store.clone(),
 69                    match &state.mode {
 70                        StoreMode::Local { .. } => false,
 71                        StoreMode::Remote { .. } => true,
 72                    },
 73                ),
 74                TaskStore::Noop => {
 75                    anyhow::bail!("empty task store cannot handle task context requests")
 76                }
 77            })
 78        })??;
 79        let buffer_store = buffer_store
 80            .upgrade()
 81            .context("no buffer store when handling task context request")?;
 82
 83        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 84            format!(
 85                "cannot handle task context request for invalid buffer id: {}",
 86                location.buffer_id
 87            )
 88        })?;
 89
 90        let start = location
 91            .start
 92            .and_then(deserialize_anchor)
 93            .context("missing task context location start")?;
 94        let end = location
 95            .end
 96            .and_then(deserialize_anchor)
 97            .context("missing task context location end")?;
 98        let buffer = buffer_store
 99            .update(&mut cx, |buffer_store, cx| {
100                if is_remote {
101                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
102                } else {
103                    Task::ready(
104                        buffer_store
105                            .get(buffer_id)
106                            .with_context(|| format!("no local buffer with id {buffer_id}")),
107                    )
108                }
109            })?
110            .await?;
111
112        let location = Location {
113            buffer,
114            range: start..end,
115        };
116        let context_task = store.update(&mut cx, |store, cx| {
117            let captured_variables = {
118                let mut variables = TaskVariables::from_iter(
119                    envelope
120                        .payload
121                        .task_variables
122                        .into_iter()
123                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
124                );
125
126                for range in location
127                    .buffer
128                    .read(cx)
129                    .snapshot()
130                    .runnable_ranges(location.range.clone())
131                {
132                    for (capture_name, value) in range.extra_captures {
133                        variables.insert(VariableName::Custom(capture_name.into()), value);
134                    }
135                }
136                variables
137            };
138            store.task_context_for_location(captured_variables, location, cx)
139        })?;
140        let task_context = context_task.await.unwrap_or_default();
141        Ok(proto::TaskContext {
142            project_env: task_context.project_env.into_iter().collect(),
143            cwd: task_context
144                .cwd
145                .map(|cwd| cwd.to_string_lossy().to_string()),
146            task_variables: task_context
147                .task_variables
148                .into_iter()
149                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150                .collect(),
151        })
152    }
153
154    pub fn local(
155        fs: Arc<dyn Fs>,
156        buffer_store: WeakModel<BufferStore>,
157        worktree_store: Model<WorktreeStore>,
158        environment: Model<ProjectEnvironment>,
159        cx: &mut ModelContext<'_, Self>,
160    ) -> Self {
161        Self::Functional(StoreState {
162            mode: StoreMode::Local {
163                downstream_client: None,
164                environment,
165            },
166            task_inventory: Inventory::new(cx),
167            buffer_store,
168            worktree_store,
169            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
170        })
171    }
172
173    pub fn remote(
174        fs: Arc<dyn Fs>,
175        buffer_store: WeakModel<BufferStore>,
176        worktree_store: Model<WorktreeStore>,
177        upstream_client: AnyProtoClient,
178        project_id: u64,
179        cx: &mut ModelContext<'_, Self>,
180    ) -> Self {
181        Self::Functional(StoreState {
182            mode: StoreMode::Remote {
183                upstream_client,
184                project_id,
185            },
186            task_inventory: Inventory::new(cx),
187            buffer_store,
188            worktree_store,
189            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
190        })
191    }
192
193    pub fn task_context_for_location(
194        &self,
195        captured_variables: TaskVariables,
196        location: Location,
197        cx: &mut AppContext,
198    ) -> Task<Option<TaskContext>> {
199        match self {
200            TaskStore::Functional(state) => match &state.mode {
201                StoreMode::Local { environment, .. } => local_task_context_for_location(
202                    state.worktree_store.clone(),
203                    environment.clone(),
204                    captured_variables,
205                    location,
206                    cx,
207                ),
208                StoreMode::Remote {
209                    upstream_client,
210                    project_id,
211                } => remote_task_context_for_location(
212                    *project_id,
213                    upstream_client,
214                    state.worktree_store.clone(),
215                    captured_variables,
216                    location,
217                    cx,
218                ),
219            },
220            TaskStore::Noop => Task::ready(None),
221        }
222    }
223
224    pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
225        match self {
226            TaskStore::Functional(state) => Some(&state.task_inventory),
227            TaskStore::Noop => None,
228        }
229    }
230
231    pub fn shared(
232        &mut self,
233        remote_id: u64,
234        new_downstream_client: AnyProtoClient,
235        _cx: &mut AppContext,
236    ) {
237        if let Self::Functional(StoreState {
238            mode: StoreMode::Local {
239                downstream_client, ..
240            },
241            ..
242        }) = self
243        {
244            *downstream_client = Some((new_downstream_client, remote_id));
245        }
246    }
247
248    pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
249        if let Self::Functional(StoreState {
250            mode: StoreMode::Local {
251                downstream_client, ..
252            },
253            ..
254        }) = self
255        {
256            *downstream_client = None;
257        }
258    }
259
260    pub(super) fn update_user_tasks(
261        &self,
262        location: Option<SettingsLocation<'_>>,
263        raw_tasks_json: Option<&str>,
264        cx: &mut ModelContext<'_, Self>,
265    ) -> anyhow::Result<()> {
266        let task_inventory = match self {
267            TaskStore::Functional(state) => &state.task_inventory,
268            TaskStore::Noop => return Ok(()),
269        };
270        let raw_tasks_json = raw_tasks_json
271            .map(|json| json.trim())
272            .filter(|json| !json.is_empty());
273
274        task_inventory.update(cx, |inventory, _| {
275            inventory.update_file_based_tasks(location, raw_tasks_json)
276        })
277    }
278
279    fn subscribe_to_global_task_file_changes(
280        fs: Arc<dyn Fs>,
281        cx: &mut ModelContext<'_, Self>,
282    ) -> Task<()> {
283        let mut user_tasks_file_rx =
284            watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
285        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
286        cx.spawn(move |task_store, mut cx| async move {
287            if let Some(user_tasks_content) = user_tasks_content {
288                let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
289                    task_store
290                        .update_user_tasks(None, Some(&user_tasks_content), cx)
291                        .log_err();
292                }) else {
293                    return;
294                };
295            }
296            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
297                let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
298                    let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
299                    if let Err(err) = &result {
300                        log::error!("Failed to load user tasks: {err}");
301                        cx.emit(crate::Event::Notification(format!(
302                            "Invalid global tasks file\n{err}"
303                        )));
304                    }
305                    cx.refresh();
306                }) else {
307                    break; // App dropped
308                };
309            }
310        })
311    }
312}
313
314fn local_task_context_for_location(
315    worktree_store: Model<WorktreeStore>,
316    environment: Model<ProjectEnvironment>,
317    captured_variables: TaskVariables,
318    location: Location,
319    cx: &AppContext,
320) -> Task<Option<TaskContext>> {
321    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
322    let worktree_abs_path = worktree_id
323        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
324        .map(|worktree| worktree.read(cx).abs_path());
325
326    cx.spawn(|mut cx| async move {
327        let worktree_abs_path = worktree_abs_path.clone();
328        let project_env = environment
329            .update(&mut cx, |environment, cx| {
330                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
331            })
332            .ok()?
333            .await;
334
335        let mut task_variables = cx
336            .update(|cx| {
337                combine_task_variables(
338                    captured_variables,
339                    location,
340                    project_env.as_ref(),
341                    BasicContextProvider::new(worktree_store),
342                    cx,
343                )
344                .log_err()
345            })
346            .ok()
347            .flatten()?;
348        // Remove all custom entries starting with _, as they're not intended for use by the end user.
349        task_variables.sweep();
350
351        Some(TaskContext {
352            project_env: project_env.unwrap_or_default(),
353            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
354            task_variables,
355        })
356    })
357}
358
359fn remote_task_context_for_location(
360    project_id: u64,
361    upstream_client: &AnyProtoClient,
362    worktree_store: Model<WorktreeStore>,
363    captured_variables: TaskVariables,
364    location: Location,
365    cx: &mut AppContext,
366) -> Task<Option<TaskContext>> {
367    // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
368    let mut remote_context = BasicContextProvider::new(worktree_store)
369        .build_context(&TaskVariables::default(), &location, None, cx)
370        .log_err()
371        .unwrap_or_default();
372    remote_context.extend(captured_variables);
373
374    let context_task = upstream_client.request(proto::TaskContextForLocation {
375        project_id,
376        location: Some(proto::Location {
377            buffer_id: location.buffer.read(cx).remote_id().into(),
378            start: Some(serialize_anchor(&location.range.start)),
379            end: Some(serialize_anchor(&location.range.end)),
380        }),
381        task_variables: remote_context
382            .into_iter()
383            .map(|(k, v)| (k.to_string(), v))
384            .collect(),
385    });
386    cx.spawn(|_| async move {
387        let task_context = context_task.await.log_err()?;
388        Some(TaskContext {
389            cwd: task_context.cwd.map(PathBuf::from),
390            task_variables: task_context
391                .task_variables
392                .into_iter()
393                .filter_map(
394                    |(variable_name, variable_value)| match variable_name.parse() {
395                        Ok(variable_name) => Some((variable_name, variable_value)),
396                        Err(()) => {
397                            log::error!("Unknown variable name: {variable_name}");
398                            None
399                        }
400                    },
401                )
402                .collect(),
403            project_env: task_context.project_env.into_iter().collect(),
404        })
405    })
406}
407
408fn combine_task_variables(
409    mut captured_variables: TaskVariables,
410    location: Location,
411    project_env: Option<&HashMap<String, String>>,
412    baseline: BasicContextProvider,
413    cx: &mut AppContext,
414) -> anyhow::Result<TaskVariables> {
415    let language_context_provider = location
416        .buffer
417        .read(cx)
418        .language()
419        .and_then(|language| language.context_provider());
420    let baseline = baseline
421        .build_context(&captured_variables, &location, project_env, cx)
422        .context("building basic default context")?;
423    captured_variables.extend(baseline);
424    if let Some(provider) = language_context_provider {
425        captured_variables.extend(
426            provider
427                .build_context(&captured_variables, &location, project_env, cx)
428                .context("building provider context")?,
429        );
430    }
431    Ok(captured_variables)
432}