1use std::{path::PathBuf, sync::Arc};
2
3use anyhow::Context as _;
4use collections::HashMap;
5use fs::Fs;
6use futures::StreamExt as _;
7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
8use language::{
9 proto::{deserialize_anchor, serialize_anchor},
10 ContextProvider as _, LanguageToolchainStore, Location,
11};
12use rpc::{proto, AnyProtoClient, TypedEnvelope};
13use settings::{watch_config_file, SettingsLocation};
14use task::{TaskContext, TaskVariables, VariableName};
15use text::BufferId;
16use util::ResultExt;
17
18use crate::{
19 buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
20 ProjectEnvironment,
21};
22
/// Owner of a project's task-related state.
///
/// `Functional` carries the real store state. `Noop` is an empty store:
/// - task-context lookups resolve to `None`;
/// - `task_inventory` returns `None`;
/// - user task updates are ignored;
/// - incoming task-context RPC requests are rejected with an error.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by [`StoreState`].
    Functional(StoreState),
    /// An empty store that provides no task functionality.
    Noop,
}
28
/// State held by a [`TaskStore::Functional`] store.
// NOTE: field order determines drop order; keep `_global_task_config_watcher` where it is
// unless a change in teardown order is intended.
pub struct StoreState {
    /// Whether task contexts are computed locally or requested from an upstream peer.
    mode: StoreMode,
    /// Inventory that receives file-based task definitions (see `update_user_tasks`).
    task_inventory: Model<Inventory>,
    /// Weak handle to the buffer store. It is upgraded on demand when serving
    /// task-context requests, and such requests fail if it is gone.
    buffer_store: WeakModel<BufferStore>,
    /// Worktree store used to resolve worktrees and to build the basic context provider.
    worktree_store: Model<WorktreeStore>,
    /// Toolchain store handed to context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
    /// Task that applies changes of the global tasks file to this store. It is held
    /// here so the store owns it. Presumably dropping it cancels the watcher
    /// (gpui `Task` semantics); TODO confirm.
    _global_task_config_watcher: Task<()>,
}
37
/// How a functional [`TaskStore`] obtains task contexts.
enum StoreMode {
    /// Task contexts are computed in this process.
    Local {
        /// Downstream client and remote project id. Set by `TaskStore::shared`
        /// and cleared by `TaskStore::unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Project environment queried for the worktree's environment variables.
        environment: Model<ProjectEnvironment>,
    },
    /// Task contexts are requested from an upstream peer via `TaskContextForLocation`.
    Remote {
        /// Client used to send requests upstream.
        upstream_client: AnyProtoClient,
        /// Project id included in every upstream request.
        project_id: u64,
    },
}
48
// Allows the store to emit project events, e.g. the `Toast` raised when the
// global tasks file fails to load.
impl EventEmitter<crate::Event> for TaskStore {}
50
51impl TaskStore {
52 pub fn init(client: Option<&AnyProtoClient>) {
53 if let Some(client) = client {
54 client.add_model_request_handler(Self::handle_task_context_for_location);
55 }
56 }
57
58 async fn handle_task_context_for_location(
59 store: Model<Self>,
60 envelope: TypedEnvelope<proto::TaskContextForLocation>,
61 mut cx: AsyncAppContext,
62 ) -> anyhow::Result<proto::TaskContext> {
63 let location = envelope
64 .payload
65 .location
66 .context("no location given for task context handling")?;
67 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
68 Ok(match store {
69 TaskStore::Functional(state) => (
70 state.buffer_store.clone(),
71 match &state.mode {
72 StoreMode::Local { .. } => false,
73 StoreMode::Remote { .. } => true,
74 },
75 ),
76 TaskStore::Noop => {
77 anyhow::bail!("empty task store cannot handle task context requests")
78 }
79 })
80 })??;
81 let buffer_store = buffer_store
82 .upgrade()
83 .context("no buffer store when handling task context request")?;
84
85 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
86 format!(
87 "cannot handle task context request for invalid buffer id: {}",
88 location.buffer_id
89 )
90 })?;
91
92 let start = location
93 .start
94 .and_then(deserialize_anchor)
95 .context("missing task context location start")?;
96 let end = location
97 .end
98 .and_then(deserialize_anchor)
99 .context("missing task context location end")?;
100 let buffer = buffer_store
101 .update(&mut cx, |buffer_store, cx| {
102 if is_remote {
103 buffer_store.wait_for_remote_buffer(buffer_id, cx)
104 } else {
105 Task::ready(
106 buffer_store
107 .get(buffer_id)
108 .with_context(|| format!("no local buffer with id {buffer_id}")),
109 )
110 }
111 })?
112 .await?;
113
114 let location = Location {
115 buffer,
116 range: start..end,
117 };
118 let context_task = store.update(&mut cx, |store, cx| {
119 let captured_variables = {
120 let mut variables = TaskVariables::from_iter(
121 envelope
122 .payload
123 .task_variables
124 .into_iter()
125 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126 );
127
128 for range in location
129 .buffer
130 .read(cx)
131 .snapshot()
132 .runnable_ranges(location.range.clone())
133 {
134 for (capture_name, value) in range.extra_captures {
135 variables.insert(VariableName::Custom(capture_name.into()), value);
136 }
137 }
138 variables
139 };
140 store.task_context_for_location(captured_variables, location, cx)
141 })?;
142 let task_context = context_task.await.unwrap_or_default();
143 Ok(proto::TaskContext {
144 project_env: task_context.project_env.into_iter().collect(),
145 cwd: task_context
146 .cwd
147 .map(|cwd| cwd.to_string_lossy().to_string()),
148 task_variables: task_context
149 .task_variables
150 .into_iter()
151 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
152 .collect(),
153 })
154 }
155
156 pub fn local(
157 fs: Arc<dyn Fs>,
158 buffer_store: WeakModel<BufferStore>,
159 worktree_store: Model<WorktreeStore>,
160 toolchain_store: Arc<dyn LanguageToolchainStore>,
161 environment: Model<ProjectEnvironment>,
162 cx: &mut ModelContext<'_, Self>,
163 ) -> Self {
164 Self::Functional(StoreState {
165 mode: StoreMode::Local {
166 downstream_client: None,
167 environment,
168 },
169 task_inventory: Inventory::new(cx),
170 buffer_store,
171 toolchain_store,
172 worktree_store,
173 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
174 })
175 }
176
177 pub fn remote(
178 fs: Arc<dyn Fs>,
179 buffer_store: WeakModel<BufferStore>,
180 worktree_store: Model<WorktreeStore>,
181 toolchain_store: Arc<dyn LanguageToolchainStore>,
182 upstream_client: AnyProtoClient,
183 project_id: u64,
184 cx: &mut ModelContext<'_, Self>,
185 ) -> Self {
186 Self::Functional(StoreState {
187 mode: StoreMode::Remote {
188 upstream_client,
189 project_id,
190 },
191 task_inventory: Inventory::new(cx),
192 buffer_store,
193 toolchain_store,
194 worktree_store,
195 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
196 })
197 }
198
199 pub fn task_context_for_location(
200 &self,
201 captured_variables: TaskVariables,
202 location: Location,
203 cx: &mut AppContext,
204 ) -> Task<Option<TaskContext>> {
205 match self {
206 TaskStore::Functional(state) => match &state.mode {
207 StoreMode::Local { environment, .. } => local_task_context_for_location(
208 state.worktree_store.clone(),
209 state.toolchain_store.clone(),
210 environment.clone(),
211 captured_variables,
212 location,
213 cx,
214 ),
215 StoreMode::Remote {
216 upstream_client,
217 project_id,
218 } => remote_task_context_for_location(
219 *project_id,
220 upstream_client.clone(),
221 state.worktree_store.clone(),
222 captured_variables,
223 location,
224 state.toolchain_store.clone(),
225 cx,
226 ),
227 },
228 TaskStore::Noop => Task::ready(None),
229 }
230 }
231
232 pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
233 match self {
234 TaskStore::Functional(state) => Some(&state.task_inventory),
235 TaskStore::Noop => None,
236 }
237 }
238
239 pub fn shared(
240 &mut self,
241 remote_id: u64,
242 new_downstream_client: AnyProtoClient,
243 _cx: &mut AppContext,
244 ) {
245 if let Self::Functional(StoreState {
246 mode: StoreMode::Local {
247 downstream_client, ..
248 },
249 ..
250 }) = self
251 {
252 *downstream_client = Some((new_downstream_client, remote_id));
253 }
254 }
255
256 pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
257 if let Self::Functional(StoreState {
258 mode: StoreMode::Local {
259 downstream_client, ..
260 },
261 ..
262 }) = self
263 {
264 *downstream_client = None;
265 }
266 }
267
268 pub(super) fn update_user_tasks(
269 &self,
270 location: Option<SettingsLocation<'_>>,
271 raw_tasks_json: Option<&str>,
272 cx: &mut ModelContext<'_, Self>,
273 ) -> anyhow::Result<()> {
274 let task_inventory = match self {
275 TaskStore::Functional(state) => &state.task_inventory,
276 TaskStore::Noop => return Ok(()),
277 };
278 let raw_tasks_json = raw_tasks_json
279 .map(|json| json.trim())
280 .filter(|json| !json.is_empty());
281
282 task_inventory.update(cx, |inventory, _| {
283 inventory.update_file_based_tasks(location, raw_tasks_json)
284 })
285 }
286
287 fn subscribe_to_global_task_file_changes(
288 fs: Arc<dyn Fs>,
289 cx: &mut ModelContext<'_, Self>,
290 ) -> Task<()> {
291 let mut user_tasks_file_rx =
292 watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
293 let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
294 cx.spawn(move |task_store, mut cx| async move {
295 if let Some(user_tasks_content) = user_tasks_content {
296 let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
297 task_store
298 .update_user_tasks(None, Some(&user_tasks_content), cx)
299 .log_err();
300 }) else {
301 return;
302 };
303 }
304 while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
305 let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
306 let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
307 if let Err(err) = &result {
308 log::error!("Failed to load user tasks: {err}");
309 cx.emit(crate::Event::Toast {
310 notification_id: "load-user-tasks".into(),
311 message: format!("Invalid global tasks file\n{err}"),
312 });
313 }
314 cx.refresh();
315 }) else {
316 break; // App dropped
317 };
318 }
319 })
320 }
321}
322
323fn local_task_context_for_location(
324 worktree_store: Model<WorktreeStore>,
325 toolchain_store: Arc<dyn LanguageToolchainStore>,
326 environment: Model<ProjectEnvironment>,
327 captured_variables: TaskVariables,
328 location: Location,
329 cx: &AppContext,
330) -> Task<Option<TaskContext>> {
331 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
332 let worktree_abs_path = worktree_id
333 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
334 .and_then(|worktree| worktree.read(cx).root_dir());
335
336 cx.spawn(|mut cx| async move {
337 let worktree_abs_path = worktree_abs_path.clone();
338 let project_env = environment
339 .update(&mut cx, |environment, cx| {
340 environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
341 })
342 .ok()?
343 .await;
344
345 let mut task_variables = cx
346 .update(|cx| {
347 combine_task_variables(
348 captured_variables,
349 location,
350 project_env.clone(),
351 BasicContextProvider::new(worktree_store),
352 toolchain_store,
353 cx,
354 )
355 })
356 .ok()?
357 .await
358 .log_err()?;
359 // Remove all custom entries starting with _, as they're not intended for use by the end user.
360 task_variables.sweep();
361
362 Some(TaskContext {
363 project_env: project_env.unwrap_or_default(),
364 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
365 task_variables,
366 })
367 })
368}
369
370fn remote_task_context_for_location(
371 project_id: u64,
372 upstream_client: AnyProtoClient,
373 worktree_store: Model<WorktreeStore>,
374 captured_variables: TaskVariables,
375 location: Location,
376 toolchain_store: Arc<dyn LanguageToolchainStore>,
377 cx: &mut AppContext,
378) -> Task<Option<TaskContext>> {
379 cx.spawn(|cx| async move {
380 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
381 let mut remote_context = cx
382 .update(|cx| {
383 BasicContextProvider::new(worktree_store).build_context(
384 &TaskVariables::default(),
385 &location,
386 None,
387 toolchain_store,
388 cx,
389 )
390 })
391 .ok()?
392 .await
393 .log_err()
394 .unwrap_or_default();
395 remote_context.extend(captured_variables);
396
397 let buffer_id = cx
398 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
399 .ok()?;
400 let context_task = upstream_client.request(proto::TaskContextForLocation {
401 project_id,
402 location: Some(proto::Location {
403 buffer_id,
404 start: Some(serialize_anchor(&location.range.start)),
405 end: Some(serialize_anchor(&location.range.end)),
406 }),
407 task_variables: remote_context
408 .into_iter()
409 .map(|(k, v)| (k.to_string(), v))
410 .collect(),
411 });
412 let task_context = context_task.await.log_err()?;
413 Some(TaskContext {
414 cwd: task_context.cwd.map(PathBuf::from),
415 task_variables: task_context
416 .task_variables
417 .into_iter()
418 .filter_map(
419 |(variable_name, variable_value)| match variable_name.parse() {
420 Ok(variable_name) => Some((variable_name, variable_value)),
421 Err(()) => {
422 log::error!("Unknown variable name: {variable_name}");
423 None
424 }
425 },
426 )
427 .collect(),
428 project_env: task_context.project_env.into_iter().collect(),
429 })
430 })
431}
432
433fn combine_task_variables(
434 mut captured_variables: TaskVariables,
435 location: Location,
436 project_env: Option<HashMap<String, String>>,
437 baseline: BasicContextProvider,
438 toolchain_store: Arc<dyn LanguageToolchainStore>,
439 cx: &mut AppContext,
440) -> Task<anyhow::Result<TaskVariables>> {
441 let language_context_provider = location
442 .buffer
443 .read(cx)
444 .language()
445 .and_then(|language| language.context_provider());
446 cx.spawn(move |cx| async move {
447 let baseline = cx
448 .update(|cx| {
449 baseline.build_context(
450 &captured_variables,
451 &location,
452 project_env.clone(),
453 toolchain_store.clone(),
454 cx,
455 )
456 })?
457 .await
458 .context("building basic default context")?;
459 captured_variables.extend(baseline);
460 if let Some(provider) = language_context_provider {
461 captured_variables.extend(
462 cx.update(|cx| {
463 provider.build_context(
464 &captured_variables,
465 &location,
466 project_env,
467 toolchain_store,
468 cx,
469 )
470 })?
471 .await
472 .context("building provider context")?,
473 );
474 }
475 Ok(captured_variables)
476 })
477}