task_store.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::Context as _;
  4use collections::HashMap;
  5use fs::Fs;
  6use futures::StreamExt as _;
  7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
  8use language::{
  9    proto::{deserialize_anchor, serialize_anchor},
 10    ContextProvider as _, Location,
 11};
 12use rpc::{proto, AnyProtoClient, TypedEnvelope};
 13use settings::{watch_config_file, SettingsLocation};
 14use task::{TaskContext, TaskVariables, VariableName};
 15use text::BufferId;
 16use util::ResultExt;
 17
 18use crate::{
 19    buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
 20    ProjectEnvironment,
 21};
 22
 23pub enum TaskStore {
 24    Functional(StoreState),
 25    Noop,
 26}
 27
 28pub struct StoreState {
 29    mode: StoreMode,
 30    task_inventory: Model<Inventory>,
 31    buffer_store: WeakModel<BufferStore>,
 32    worktree_store: Model<WorktreeStore>,
 33    _global_task_config_watcher: Task<()>,
 34}
 35
 36enum StoreMode {
 37    Local {
 38        downstream_client: Option<(AnyProtoClient, u64)>,
 39        environment: Model<ProjectEnvironment>,
 40    },
 41    Remote {
 42        upstream_client: AnyProtoClient,
 43        project_id: u64,
 44    },
 45}
 46
 47impl EventEmitter<crate::Event> for TaskStore {}
 48
 49impl TaskStore {
 50    pub fn init(client: Option<&AnyProtoClient>) {
 51        if let Some(client) = client {
 52            client.add_model_request_handler(Self::handle_task_context_for_location);
 53        }
 54    }
 55
 56    async fn handle_task_context_for_location(
 57        store: Model<Self>,
 58        envelope: TypedEnvelope<proto::TaskContextForLocation>,
 59        mut cx: AsyncAppContext,
 60    ) -> anyhow::Result<proto::TaskContext> {
 61        let location = envelope
 62            .payload
 63            .location
 64            .context("no location given for task context handling")?;
 65        let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
 66            Ok(match store {
 67                TaskStore::Functional(state) => (
 68                    state.buffer_store.clone(),
 69                    match &state.mode {
 70                        StoreMode::Local { .. } => false,
 71                        StoreMode::Remote { .. } => true,
 72                    },
 73                ),
 74                TaskStore::Noop => {
 75                    anyhow::bail!("empty task store cannot handle task context requests")
 76                }
 77            })
 78        })??;
 79        let buffer_store = buffer_store
 80            .upgrade()
 81            .context("no buffer store when handling task context request")?;
 82
 83        let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
 84            format!(
 85                "cannot handle task context request for invalid buffer id: {}",
 86                location.buffer_id
 87            )
 88        })?;
 89
 90        let start = location
 91            .start
 92            .and_then(deserialize_anchor)
 93            .context("missing task context location start")?;
 94        let end = location
 95            .end
 96            .and_then(deserialize_anchor)
 97            .context("missing task context location end")?;
 98        let buffer = buffer_store
 99            .update(&mut cx, |buffer_store, cx| {
100                if is_remote {
101                    buffer_store.wait_for_remote_buffer(buffer_id, cx)
102                } else {
103                    Task::ready(
104                        buffer_store
105                            .get(buffer_id)
106                            .with_context(|| format!("no local buffer with id {buffer_id}")),
107                    )
108                }
109            })?
110            .await?;
111
112        let location = Location {
113            buffer,
114            range: start..end,
115        };
116        let context_task = store.update(&mut cx, |store, cx| {
117            let captured_variables = {
118                let mut variables = TaskVariables::from_iter(
119                    envelope
120                        .payload
121                        .task_variables
122                        .into_iter()
123                        .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
124                );
125
126                for range in location
127                    .buffer
128                    .read(cx)
129                    .snapshot()
130                    .runnable_ranges(location.range.clone())
131                {
132                    for (capture_name, value) in range.extra_captures {
133                        variables.insert(VariableName::Custom(capture_name.into()), value);
134                    }
135                }
136                variables
137            };
138            store.task_context_for_location(captured_variables, location, cx)
139        })?;
140        let task_context = context_task.await.unwrap_or_default();
141        Ok(proto::TaskContext {
142            project_env: task_context.project_env.into_iter().collect(),
143            cwd: task_context
144                .cwd
145                .map(|cwd| cwd.to_string_lossy().to_string()),
146            task_variables: task_context
147                .task_variables
148                .into_iter()
149                .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150                .collect(),
151        })
152    }
153
154    pub fn local(
155        fs: Arc<dyn Fs>,
156        buffer_store: WeakModel<BufferStore>,
157        worktree_store: Model<WorktreeStore>,
158        environment: Model<ProjectEnvironment>,
159        cx: &mut ModelContext<'_, Self>,
160    ) -> Self {
161        Self::Functional(StoreState {
162            mode: StoreMode::Local {
163                downstream_client: None,
164                environment,
165            },
166            task_inventory: Inventory::new(cx),
167            buffer_store,
168            worktree_store,
169            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
170        })
171    }
172
173    pub fn remote(
174        fs: Arc<dyn Fs>,
175        buffer_store: WeakModel<BufferStore>,
176        worktree_store: Model<WorktreeStore>,
177        upstream_client: AnyProtoClient,
178        project_id: u64,
179        cx: &mut ModelContext<'_, Self>,
180    ) -> Self {
181        Self::Functional(StoreState {
182            mode: StoreMode::Remote {
183                upstream_client,
184                project_id,
185            },
186            task_inventory: Inventory::new(cx),
187            buffer_store,
188            worktree_store,
189            _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
190        })
191    }
192
193    pub fn task_context_for_location(
194        &self,
195        captured_variables: TaskVariables,
196        location: Location,
197        cx: &mut AppContext,
198    ) -> Task<Option<TaskContext>> {
199        match self {
200            TaskStore::Functional(state) => match &state.mode {
201                StoreMode::Local { environment, .. } => local_task_context_for_location(
202                    state.worktree_store.clone(),
203                    environment.clone(),
204                    captured_variables,
205                    location,
206                    cx,
207                ),
208                StoreMode::Remote {
209                    upstream_client,
210                    project_id,
211                } => remote_task_context_for_location(
212                    *project_id,
213                    upstream_client,
214                    state.worktree_store.clone(),
215                    captured_variables,
216                    location,
217                    cx,
218                ),
219            },
220            TaskStore::Noop => Task::ready(None),
221        }
222    }
223
224    pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
225        match self {
226            TaskStore::Functional(state) => Some(&state.task_inventory),
227            TaskStore::Noop => None,
228        }
229    }
230
231    pub fn shared(
232        &mut self,
233        remote_id: u64,
234        new_downstream_client: AnyProtoClient,
235        _cx: &mut AppContext,
236    ) {
237        if let Self::Functional(StoreState {
238            mode: StoreMode::Local {
239                downstream_client, ..
240            },
241            ..
242        }) = self
243        {
244            *downstream_client = Some((new_downstream_client, remote_id));
245        }
246    }
247
248    pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
249        if let Self::Functional(StoreState {
250            mode: StoreMode::Local {
251                downstream_client, ..
252            },
253            ..
254        }) = self
255        {
256            *downstream_client = None;
257        }
258    }
259
260    pub(super) fn update_user_tasks(
261        &self,
262        location: Option<SettingsLocation<'_>>,
263        raw_tasks_json: Option<&str>,
264        cx: &mut ModelContext<'_, Self>,
265    ) -> anyhow::Result<()> {
266        let task_inventory = match self {
267            TaskStore::Functional(state) => &state.task_inventory,
268            TaskStore::Noop => return Ok(()),
269        };
270        let raw_tasks_json = raw_tasks_json
271            .map(|json| json.trim())
272            .filter(|json| !json.is_empty());
273
274        task_inventory.update(cx, |inventory, _| {
275            inventory.update_file_based_tasks(location, raw_tasks_json)
276        })
277    }
278
279    fn subscribe_to_global_task_file_changes(
280        fs: Arc<dyn Fs>,
281        cx: &mut ModelContext<'_, Self>,
282    ) -> Task<()> {
283        let mut user_tasks_file_rx =
284            watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
285        let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
286        cx.spawn(move |task_store, mut cx| async move {
287            if let Some(user_tasks_content) = user_tasks_content {
288                let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
289                    task_store
290                        .update_user_tasks(None, Some(&user_tasks_content), cx)
291                        .log_err();
292                }) else {
293                    return;
294                };
295            }
296            while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
297                let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
298                    let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
299                    if let Err(err) = &result {
300                        log::error!("Failed to load user tasks: {err}");
301                        cx.emit(crate::Event::Toast {
302                            notification_id: "load-user-tasks".into(),
303                            message: format!("Invalid global tasks file\n{err}"),
304                        });
305                    }
306                    cx.refresh();
307                }) else {
308                    break; // App dropped
309                };
310            }
311        })
312    }
313}
314
315fn local_task_context_for_location(
316    worktree_store: Model<WorktreeStore>,
317    environment: Model<ProjectEnvironment>,
318    captured_variables: TaskVariables,
319    location: Location,
320    cx: &AppContext,
321) -> Task<Option<TaskContext>> {
322    let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
323    let worktree_abs_path = worktree_id
324        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
325        .map(|worktree| worktree.read(cx).abs_path());
326
327    cx.spawn(|mut cx| async move {
328        let worktree_abs_path = worktree_abs_path.clone();
329        let project_env = environment
330            .update(&mut cx, |environment, cx| {
331                environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
332            })
333            .ok()?
334            .await;
335
336        let mut task_variables = cx
337            .update(|cx| {
338                combine_task_variables(
339                    captured_variables,
340                    location,
341                    project_env.as_ref(),
342                    BasicContextProvider::new(worktree_store),
343                    cx,
344                )
345                .log_err()
346            })
347            .ok()
348            .flatten()?;
349        // Remove all custom entries starting with _, as they're not intended for use by the end user.
350        task_variables.sweep();
351
352        Some(TaskContext {
353            project_env: project_env.unwrap_or_default(),
354            cwd: worktree_abs_path.map(|p| p.to_path_buf()),
355            task_variables,
356        })
357    })
358}
359
360fn remote_task_context_for_location(
361    project_id: u64,
362    upstream_client: &AnyProtoClient,
363    worktree_store: Model<WorktreeStore>,
364    captured_variables: TaskVariables,
365    location: Location,
366    cx: &mut AppContext,
367) -> Task<Option<TaskContext>> {
368    // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
369    let mut remote_context = BasicContextProvider::new(worktree_store)
370        .build_context(&TaskVariables::default(), &location, None, cx)
371        .log_err()
372        .unwrap_or_default();
373    remote_context.extend(captured_variables);
374
375    let context_task = upstream_client.request(proto::TaskContextForLocation {
376        project_id,
377        location: Some(proto::Location {
378            buffer_id: location.buffer.read(cx).remote_id().into(),
379            start: Some(serialize_anchor(&location.range.start)),
380            end: Some(serialize_anchor(&location.range.end)),
381        }),
382        task_variables: remote_context
383            .into_iter()
384            .map(|(k, v)| (k.to_string(), v))
385            .collect(),
386    });
387    cx.spawn(|_| async move {
388        let task_context = context_task.await.log_err()?;
389        Some(TaskContext {
390            cwd: task_context.cwd.map(PathBuf::from),
391            task_variables: task_context
392                .task_variables
393                .into_iter()
394                .filter_map(
395                    |(variable_name, variable_value)| match variable_name.parse() {
396                        Ok(variable_name) => Some((variable_name, variable_value)),
397                        Err(()) => {
398                            log::error!("Unknown variable name: {variable_name}");
399                            None
400                        }
401                    },
402                )
403                .collect(),
404            project_env: task_context.project_env.into_iter().collect(),
405        })
406    })
407}
408
409fn combine_task_variables(
410    mut captured_variables: TaskVariables,
411    location: Location,
412    project_env: Option<&HashMap<String, String>>,
413    baseline: BasicContextProvider,
414    cx: &mut AppContext,
415) -> anyhow::Result<TaskVariables> {
416    let language_context_provider = location
417        .buffer
418        .read(cx)
419        .language()
420        .and_then(|language| language.context_provider());
421    let baseline = baseline
422        .build_context(&captured_variables, &location, project_env, cx)
423        .context("building basic default context")?;
424    captured_variables.extend(baseline);
425    if let Some(provider) = language_context_provider {
426        captured_variables.extend(
427            provider
428                .build_context(&captured_variables, &location, project_env, cx)
429                .context("building provider context")?,
430        );
431    }
432    Ok(captured_variables)
433}