task_template.rs

  1use std::path::PathBuf;
  2
  3use anyhow::{bail, Context};
  4use collections::{HashMap, HashSet};
  5use schemars::{gen::SchemaSettings, JsonSchema};
  6use serde::{Deserialize, Serialize};
  7use sha2::{Digest, Sha256};
  8use util::{truncate_and_remove_front, ResultExt};
  9
 10use crate::{
 11    ResolvedTask, SpawnInTerminal, TaskContext, TaskId, VariableName, ZED_VARIABLE_NAME_PREFIX,
 12};
 13
/// A template definition of a Zed task to run.
/// May use the [`VariableName`] to get the corresponding substitutions into its fields.
///
/// Template itself is not ready to spawn a task, it needs to be resolved with a [`TaskContext`] first, that
/// contains all relevant Zed state in task variables.
/// A single template may produce different tasks (or none) for different contexts.
// NOTE(review): `JsonSchema` is derived here, so the `///` comments on this struct and its fields presumably
// surface as descriptions in the generated schema — confirm before rewording any of them.
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct TaskTemplate {
    /// Human readable name of the task to display in the UI.
    pub label: String,
    /// Executable command to spawn.
    pub command: String,
    /// Arguments to the command.
    #[serde(default)]
    pub args: Vec<String>,
    /// Env overrides for the command, will be appended to the terminal's environment from the settings.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Current working directory to spawn the command into, defaults to current project root.
    #[serde(default)]
    pub cwd: Option<String>,
    /// Whether to use a new terminal tab or reuse the existing one to spawn the process.
    #[serde(default)]
    pub use_new_terminal: bool,
    /// Whether to allow multiple instances of the same task to be run, or rather wait for the existing ones to finish.
    #[serde(default)]
    pub allow_concurrent_runs: bool,
    // Tasks like "execute the selection" are better off with constant labels (to avoid polluting the history
    // with temporary tasks), and should always use the latest context with the latest selection.
    //
    // The current implementation always picks previously spawned tasks on a full label conflict in the tasks
    // modal and terminal tabs, so such tasks never get the latest selection.
    // This flag inverts that behavior, effectively removing all previously spawned tasks from history
    // if their full labels are the same as the labels of the newly resolved tasks.
    // Such tasks are still re-runnable, and will use the old context in that case (unless the rerun task forces this).
    //
    // The current approach is relatively hacky; a better way is to understand when a newly resolved task needs
    // a rerun, and replace the historic task accordingly.
    #[doc(hidden)]
    #[serde(default)]
    pub ignore_previously_resolved: bool,
    /// What to do with the terminal pane and tab, after the command was started:
    /// * `always` — always show the terminal pane, add and focus the corresponding task's tab in it (default)
    /// * `never` — avoid changing current terminal pane focus, but still add/reuse the task's tab there
    #[serde(default)]
    pub reveal: RevealStrategy,
}
 62
/// What to do with the terminal pane and tab, after the command was started.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
// Variant names are (de)serialized in snake_case, i.e. as `always` and `never`.
#[serde(rename_all = "snake_case")]
pub enum RevealStrategy {
    /// Always show the terminal pane, add and focus the corresponding task's tab in it.
    // Marked `#[default]`, so this is what `RevealStrategy::default()` yields.
    #[default]
    Always,
    /// Do not change terminal pane focus, but still add/reuse the task's tab there.
    Never,
}
 73
/// A group of Tasks defined in a JSON file.
// A newtype over the list of templates; the inner `Vec` is public, so callers can reach the templates directly.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct TaskTemplates(pub Vec<TaskTemplate>);
 77
 78impl TaskTemplates {
 79    /// Generates JSON schema of Tasks JSON template format.
 80    pub fn generate_json_schema() -> serde_json_lenient::Value {
 81        let schema = SchemaSettings::draft07()
 82            .with(|settings| settings.option_add_null_type = false)
 83            .into_generator()
 84            .into_root_schema_for::<Self>();
 85
 86        serde_json_lenient::to_value(schema).unwrap()
 87    }
 88}
 89
 90impl TaskTemplate {
 91    /// Replaces all `VariableName` task variables in the task template string fields.
 92    /// If any replacement fails or the new string substitutions still have [`ZED_VARIABLE_NAME_PREFIX`],
 93    /// `None` is returned.
 94    ///
 95    /// Every [`ResolvedTask`] gets a [`TaskId`], based on the `id_base` (to avoid collision with various task sources),
 96    /// and hashes of its template and [`TaskContext`], see [`ResolvedTask`] fields' documentation for more details.
 97    pub fn resolve_task(&self, id_base: &str, cx: &TaskContext) -> Option<ResolvedTask> {
 98        if self.label.trim().is_empty() || self.command.trim().is_empty() {
 99            return None;
100        }
101
102        let mut variable_names = HashMap::default();
103        let mut substituted_variables = HashSet::default();
104        let task_variables = cx
105            .task_variables
106            .0
107            .iter()
108            .map(|(key, value)| {
109                let key_string = key.to_string();
110                if !variable_names.contains_key(&key_string) {
111                    variable_names.insert(key_string.clone(), key.clone());
112                }
113                (key_string, value.as_str())
114            })
115            .collect::<HashMap<_, _>>();
116        let truncated_variables = truncate_variables(&task_variables);
117        let cwd = match self.cwd.as_deref() {
118            Some(cwd) => {
119                let substitured_cwd = substitute_all_template_variables_in_str(
120                    cwd,
121                    &task_variables,
122                    &variable_names,
123                    &mut substituted_variables,
124                )?;
125                Some(substitured_cwd)
126            }
127            None => None,
128        }
129        .map(PathBuf::from)
130        .or(cx.cwd.clone());
131        let human_readable_label = substitute_all_template_variables_in_str(
132            &self.label,
133            &truncated_variables,
134            &variable_names,
135            &mut substituted_variables,
136        )?
137        .lines()
138        .fold(String::new(), |mut string, line| {
139            if string.is_empty() {
140                string.push_str(line);
141            } else {
142                string.push_str("\\n");
143                string.push_str(line);
144            }
145            string
146        });
147        let full_label = substitute_all_template_variables_in_str(
148            &self.label,
149            &task_variables,
150            &variable_names,
151            &mut substituted_variables,
152        )?;
153        let command = substitute_all_template_variables_in_str(
154            &self.command,
155            &task_variables,
156            &variable_names,
157            &mut substituted_variables,
158        )?;
159        let args_with_substitutions = substitute_all_template_variables_in_vec(
160            &self.args,
161            &task_variables,
162            &variable_names,
163            &mut substituted_variables,
164        )?;
165
166        let task_hash = to_hex_hash(&self)
167            .context("hashing task template")
168            .log_err()?;
169        let variables_hash = to_hex_hash(&task_variables)
170            .context("hashing task variables")
171            .log_err()?;
172        let id = TaskId(format!("{id_base}_{task_hash}_{variables_hash}"));
173        let mut env = substitute_all_template_variables_in_map(
174            &self.env,
175            &task_variables,
176            &variable_names,
177            &mut substituted_variables,
178        )?;
179        env.extend(task_variables.into_iter().map(|(k, v)| (k, v.to_owned())));
180        Some(ResolvedTask {
181            id: id.clone(),
182            substituted_variables,
183            original_task: self.clone(),
184            resolved_label: full_label.clone(),
185            resolved: Some(SpawnInTerminal {
186                id,
187                cwd,
188                full_label,
189                label: human_readable_label,
190                command_label: args_with_substitutions.iter().fold(
191                    command.clone(),
192                    |mut command_label, arg| {
193                        command_label.push(' ');
194                        command_label.push_str(arg);
195                        command_label
196                    },
197                ),
198                command,
199                args: self.args.clone(),
200                env,
201                use_new_terminal: self.use_new_terminal,
202                allow_concurrent_runs: self.allow_concurrent_runs,
203                reveal: self.reveal,
204            }),
205        })
206    }
207}
208
// Length limit handed to `truncate_and_remove_front` when shortening task variable values
// for the human-readable task label.
const MAX_DISPLAY_VARIABLE_LENGTH: usize = 15;
210
211fn truncate_variables(task_variables: &HashMap<String, &str>) -> HashMap<String, String> {
212    task_variables
213        .iter()
214        .map(|(key, value)| {
215            (
216                key.clone(),
217                truncate_and_remove_front(value, MAX_DISPLAY_VARIABLE_LENGTH),
218            )
219        })
220        .collect()
221}
222
223fn to_hex_hash(object: impl Serialize) -> anyhow::Result<String> {
224    let json = serde_json_lenient::to_string(&object).context("serializing the object")?;
225    let mut hasher = Sha256::new();
226    hasher.update(json.as_bytes());
227    Ok(hex::encode(hasher.finalize()))
228}
229
230fn substitute_all_template_variables_in_str<A: AsRef<str>>(
231    template_str: &str,
232    task_variables: &HashMap<String, A>,
233    variable_names: &HashMap<String, VariableName>,
234    substituted_variables: &mut HashSet<VariableName>,
235) -> Option<String> {
236    let substituted_string = shellexpand::env_with_context(template_str, |var| {
237        // Colons denote a default value in case the variable is not set. We want to preserve that default, as otherwise shellexpand will substitute it for us.
238        let colon_position = var.find(':').unwrap_or(var.len());
239        let (variable_name, default) = var.split_at(colon_position);
240        if let Some(name) = task_variables.get(variable_name) {
241            if let Some(substituted_variable) = variable_names.get(variable_name) {
242                substituted_variables.insert(substituted_variable.clone());
243            }
244
245            let mut name = name.as_ref().to_owned();
246            // Got a task variable hit
247            if !default.is_empty() {
248                name.push_str(default);
249            }
250            return Ok(Some(name));
251        } else if variable_name.starts_with(ZED_VARIABLE_NAME_PREFIX) {
252            bail!("Unknown variable name: {variable_name}");
253        }
254        // This is an unknown variable.
255        // We should not error out, as they may come from user environment (e.g. $PATH). That means that the variable substitution might not be perfect.
256        // If there's a default, we need to return the string verbatim as otherwise shellexpand will apply that default for us.
257        if !default.is_empty() {
258            return Ok(Some(format!("${{{var}}}")));
259        }
260        // Else we can just return None and that variable will be left as is.
261        Ok(None)
262    })
263    .ok()?;
264    Some(substituted_string.into_owned())
265}
266
267fn substitute_all_template_variables_in_vec(
268    template_strs: &[String],
269    task_variables: &HashMap<String, &str>,
270    variable_names: &HashMap<String, VariableName>,
271    substituted_variables: &mut HashSet<VariableName>,
272) -> Option<Vec<String>> {
273    let mut expanded = Vec::with_capacity(template_strs.len());
274    for variable in template_strs {
275        let new_value = substitute_all_template_variables_in_str(
276            variable,
277            task_variables,
278            variable_names,
279            substituted_variables,
280        )?;
281        expanded.push(new_value);
282    }
283    Some(expanded)
284}
285
286fn substitute_all_template_variables_in_map(
287    keys_and_values: &HashMap<String, String>,
288    task_variables: &HashMap<String, &str>,
289    variable_names: &HashMap<String, VariableName>,
290    substituted_variables: &mut HashSet<VariableName>,
291) -> Option<HashMap<String, String>> {
292    let mut new_map: HashMap<String, String> = Default::default();
293    for (key, value) in keys_and_values {
294        let new_value = substitute_all_template_variables_in_str(
295            &value,
296            task_variables,
297            variable_names,
298            substituted_variables,
299        )?;
300        let new_key = substitute_all_template_variables_in_str(
301            &key,
302            task_variables,
303            variable_names,
304            substituted_variables,
305        )?;
306        new_map.insert(new_key, new_value);
307    }
308    Some(new_map)
309}
310
311#[cfg(test)]
312mod tests {
313    use std::{borrow::Cow, path::Path};
314
315    use crate::{TaskVariables, VariableName};
316
317    use super::*;
318
    // The `id_base` passed to every `resolve_task` call in these tests.
    const TEST_ID_BASE: &str = "test_base";
320
321    #[test]
322    fn test_resolving_templates_with_blank_command_and_label() {
323        let task_with_all_properties = TaskTemplate {
324            label: "test_label".to_string(),
325            command: "test_command".to_string(),
326            args: vec!["test_arg".to_string()],
327            env: HashMap::from_iter([("test_env_key".to_string(), "test_env_var".to_string())]),
328            ..TaskTemplate::default()
329        };
330
331        for task_with_blank_property in &[
332            TaskTemplate {
333                label: "".to_string(),
334                ..task_with_all_properties.clone()
335            },
336            TaskTemplate {
337                command: "".to_string(),
338                ..task_with_all_properties.clone()
339            },
340            TaskTemplate {
341                label: "".to_string(),
342                command: "".to_string(),
343                ..task_with_all_properties.clone()
344            },
345        ] {
346            assert_eq!(
347                task_with_blank_property.resolve_task(TEST_ID_BASE, &TaskContext::default()),
348                None,
349                "should not resolve task with blank label and/or command: {task_with_blank_property:?}"
350            );
351        }
352    }
353
354    #[test]
355    fn test_template_cwd_resolution() {
356        let task_without_cwd = TaskTemplate {
357            cwd: None,
358            label: "test task".to_string(),
359            command: "echo 4".to_string(),
360            ..TaskTemplate::default()
361        };
362
363        let resolved_task = |task_template: &TaskTemplate, task_cx| {
364            let resolved_task = task_template
365                .resolve_task(TEST_ID_BASE, task_cx)
366                .unwrap_or_else(|| panic!("failed to resolve task {task_without_cwd:?}"));
367            assert_substituted_variables(&resolved_task, Vec::new());
368            resolved_task
369                .resolved
370                .clone()
371                .unwrap_or_else(|| {
372                    panic!("failed to get resolve data for resolved task. Template: {task_without_cwd:?} Resolved: {resolved_task:?}")
373                })
374        };
375
376        let cx = TaskContext {
377            cwd: None,
378            task_variables: TaskVariables::default(),
379        };
380        assert_eq!(
381            resolved_task(&task_without_cwd, &cx).cwd,
382            None,
383            "When neither task nor task context have cwd, it should be None"
384        );
385
386        let context_cwd = Path::new("a").join("b").join("c");
387        let cx = TaskContext {
388            cwd: Some(context_cwd.clone()),
389            task_variables: TaskVariables::default(),
390        };
391        assert_eq!(
392            resolved_task(&task_without_cwd, &cx).cwd.as_deref(),
393            Some(context_cwd.as_path()),
394            "TaskContext's cwd should be taken on resolve if task's cwd is None"
395        );
396
397        let task_cwd = Path::new("d").join("e").join("f");
398        let mut task_with_cwd = task_without_cwd.clone();
399        task_with_cwd.cwd = Some(task_cwd.display().to_string());
400        let task_with_cwd = task_with_cwd;
401
402        let cx = TaskContext {
403            cwd: None,
404            task_variables: TaskVariables::default(),
405        };
406        assert_eq!(
407            resolved_task(&task_with_cwd, &cx).cwd.as_deref(),
408            Some(task_cwd.as_path()),
409            "TaskTemplate's cwd should be taken on resolve if TaskContext's cwd is None"
410        );
411
412        let cx = TaskContext {
413            cwd: Some(context_cwd.clone()),
414            task_variables: TaskVariables::default(),
415        };
416        assert_eq!(
417            resolved_task(&task_with_cwd, &cx).cwd.as_deref(),
418            Some(task_cwd.as_path()),
419            "TaskTemplate's cwd should be taken on resolve if TaskContext's cwd is not None"
420        );
421    }
422
    /// Resolves a template that uses task variables in its label, command, args and env, and checks that:
    /// * resolving the same template with the same context repeatedly yields the same task id;
    /// * the full label, command, command label and env get the untruncated values, while the
    ///   human-readable label gets the long value shortened;
    /// * `args` are passed on without substitution;
    /// * removing any single variable from the context makes the resolution fail.
    #[test]
    fn test_template_variables_resolution() {
        let custom_variable_1 = VariableName::Custom(Cow::Borrowed("custom_variable_1"));
        let custom_variable_2 = VariableName::Custom(Cow::Borrowed("custom_variable_2"));
        // Four times longer than the display limit, so the human-readable label has to shorten it.
        let long_value = "01".repeat(MAX_DISPLAY_VARIABLE_LENGTH * 2);
        let all_variables = [
            (VariableName::Row, "1234".to_string()),
            (VariableName::Column, "5678".to_string()),
            (VariableName::File, "test_file".to_string()),
            (VariableName::SelectedText, "test_selected_text".to_string()),
            (VariableName::Symbol, long_value.clone()),
            (VariableName::WorktreeRoot, "/test_root/".to_string()),
            (
                custom_variable_1.clone(),
                "test_custom_variable_1".to_string(),
            ),
            (
                custom_variable_2.clone(),
                "test_custom_variable_2".to_string(),
            ),
        ];

        // Every variable from `all_variables` is referenced by at least one field of this template.
        let task_with_all_variables = TaskTemplate {
            label: format!(
                "test label for {} and {}",
                VariableName::Row.template_value(),
                VariableName::Symbol.template_value(),
            ),
            command: format!(
                "echo {} {}",
                VariableName::File.template_value(),
                VariableName::Symbol.template_value(),
            ),
            args: vec![
                format!("arg1 {}", VariableName::SelectedText.template_value()),
                format!("arg2 {}", VariableName::Column.template_value()),
                format!("arg3 {}", VariableName::Symbol.template_value()),
            ],
            env: HashMap::from_iter([
                ("test_env_key".to_string(), "test_env_var".to_string()),
                (
                    "env_key_1".to_string(),
                    VariableName::WorktreeRoot.template_value(),
                ),
                (
                    "env_key_2".to_string(),
                    format!(
                        "env_var_2_{}_{}",
                        custom_variable_1.template_value(),
                        custom_variable_2.template_value()
                    ),
                ),
                (
                    "env_key_3".to_string(),
                    format!("env_var_3_{}", VariableName::Symbol.template_value()),
                ),
            ]),
            ..TaskTemplate::default()
        };

        // Resolve several times in a row: the id has to stay the same for the same template and context.
        let mut first_resolved_id = None;
        for i in 0..15 {
            let resolved_task = task_with_all_variables.resolve_task(
                TEST_ID_BASE,
                &TaskContext {
                    cwd: None,
                    task_variables: TaskVariables::from_iter(all_variables.clone()),
                },
            ).unwrap_or_else(|| panic!("Should successfully resolve task {task_with_all_variables:?} with variables {all_variables:?}"));

            match &first_resolved_id {
                None => first_resolved_id = Some(resolved_task.id.clone()),
                Some(first_id) => assert_eq!(
                    &resolved_task.id, first_id,
                    "Step {i}, for the same task template and context, there should be the same resolved task id"
                ),
            }

            assert_eq!(
                resolved_task.original_task, task_with_all_variables,
                "Resolved task should store its template without changes"
            );
            assert_eq!(
                resolved_task.resolved_label,
                format!("test label for 1234 and {long_value}"),
                "Resolved task label should be substituted with variables and those should not be shortened"
            );
            assert_substituted_variables(
                &resolved_task,
                all_variables.iter().map(|(name, _)| name.clone()).collect(),
            );

            let spawn_in_terminal = resolved_task
                .resolved
                .as_ref()
                .expect("should have resolved a spawn in terminal task");
            // NOTE(review): the expectation slices the *front* of `long_value`, while the helper is called
            // `truncate_and_remove_front`; presumably this only matches because the value is a repeating
            // "01" pattern — confirm against the helper before changing `long_value`.
            assert_eq!(
                spawn_in_terminal.label,
                format!(
                    "test label for 1234 and …{}",
                    &long_value[..=MAX_DISPLAY_VARIABLE_LENGTH]
                ),
                "Human-readable label should have long substitutions trimmed"
            );
            assert_eq!(
                spawn_in_terminal.command,
                format!("echo test_file {long_value}"),
                "Command should be substituted with variables and those should not be shortened"
            );
            assert_eq!(
                spawn_in_terminal.args,
                &[
                    "arg1 $ZED_SELECTED_TEXT",
                    "arg2 $ZED_COLUMN",
                    "arg3 $ZED_SYMBOL",
                ],
                "Args should not be substituted with variables"
            );
            assert_eq!(
                spawn_in_terminal.command_label,
                format!("{} arg1 test_selected_text arg2 5678 arg3 {long_value}", spawn_in_terminal.command),
                "Command label args should be substituted with variables and those should not be shortened"
            );

            assert_eq!(
                spawn_in_terminal
                    .env
                    .get("test_env_key")
                    .map(|s| s.as_str()),
                Some("test_env_var")
            );
            assert_eq!(
                spawn_in_terminal.env.get("env_key_1").map(|s| s.as_str()),
                Some("/test_root/")
            );
            assert_eq!(
                spawn_in_terminal.env.get("env_key_2").map(|s| s.as_str()),
                Some("env_var_2_test_custom_variable_1_test_custom_variable_2")
            );
            assert_eq!(
                spawn_in_terminal.env.get("env_key_3"),
                Some(&format!("env_var_3_{long_value}")),
                "Env vars should be substituted with variables and those should not be shortened"
            );
        }

        // Drop the variables one at a time: the template needs all of them, so each attempt must fail.
        for i in 0..all_variables.len() {
            let mut not_all_variables = all_variables.to_vec();
            let removed_variable = not_all_variables.remove(i);
            let resolved_task_attempt = task_with_all_variables.resolve_task(
                TEST_ID_BASE,
                &TaskContext {
                    cwd: None,
                    task_variables: TaskVariables::from_iter(not_all_variables),
                },
            );
            assert_eq!(resolved_task_attempt, None, "If any of the Zed task variables is not substituted, the task should not be resolved, but got some resolution without the variable {removed_variable:?} (index {i})");
        }
    }
582
583    #[test]
584    fn test_can_resolve_free_variables() {
585        let task = TaskTemplate {
586            label: "My task".into(),
587            command: "echo".into(),
588            args: vec!["$PATH".into()],
589            ..Default::default()
590        };
591        let resolved_task = task
592            .resolve_task(TEST_ID_BASE, &TaskContext::default())
593            .unwrap();
594        assert_substituted_variables(&resolved_task, Vec::new());
595        let resolved = resolved_task.resolved.unwrap();
596        assert_eq!(resolved.label, task.label);
597        assert_eq!(resolved.command, task.command);
598        assert_eq!(resolved.args, task.args);
599    }
600
601    #[test]
602    fn test_errors_on_missing_zed_variable() {
603        let task = TaskTemplate {
604            label: "My task".into(),
605            command: "echo".into(),
606            args: vec!["$ZED_VARIABLE".into()],
607            ..Default::default()
608        };
609        assert!(task
610            .resolve_task(TEST_ID_BASE, &TaskContext::default())
611            .is_none());
612    }
613
614    #[test]
615    fn test_symbol_dependent_tasks() {
616        let task_with_all_properties = TaskTemplate {
617            label: "test_label".to_string(),
618            command: "test_command".to_string(),
619            args: vec!["test_arg".to_string()],
620            env: HashMap::from_iter([("test_env_key".to_string(), "test_env_var".to_string())]),
621            ..TaskTemplate::default()
622        };
623        let cx = TaskContext {
624            cwd: None,
625            task_variables: TaskVariables::from_iter(Some((
626                VariableName::Symbol,
627                "test_symbol".to_string(),
628            ))),
629        };
630
631        for (i, symbol_dependent_task) in [
632            TaskTemplate {
633                label: format!("test_label_{}", VariableName::Symbol.template_value()),
634                ..task_with_all_properties.clone()
635            },
636            TaskTemplate {
637                command: format!("test_command_{}", VariableName::Symbol.template_value()),
638                ..task_with_all_properties.clone()
639            },
640            TaskTemplate {
641                args: vec![format!(
642                    "test_arg_{}",
643                    VariableName::Symbol.template_value()
644                )],
645                ..task_with_all_properties.clone()
646            },
647            TaskTemplate {
648                env: HashMap::from_iter([(
649                    "test_env_key".to_string(),
650                    format!("test_env_var_{}", VariableName::Symbol.template_value()),
651                )]),
652                ..task_with_all_properties.clone()
653            },
654        ]
655        .into_iter()
656        .enumerate()
657        {
658            let resolved = symbol_dependent_task
659                .resolve_task(TEST_ID_BASE, &cx)
660                .unwrap_or_else(|| panic!("Failed to resolve task {symbol_dependent_task:?}"));
661            assert_eq!(
662                resolved.substituted_variables,
663                HashSet::from_iter(Some(VariableName::Symbol)),
664                "(index {i}) Expected the task to depend on symbol task variable: {resolved:?}"
665            )
666        }
667    }
668
669    #[track_caller]
670    fn assert_substituted_variables(resolved_task: &ResolvedTask, mut expected: Vec<VariableName>) {
671        let mut resolved_variables = resolved_task
672            .substituted_variables
673            .iter()
674            .cloned()
675            .collect::<Vec<_>>();
676        resolved_variables.sort_by_key(|var| var.to_string());
677        expected.sort_by_key(|var| var.to_string());
678        assert_eq!(resolved_variables, expected)
679    }
680}