chat_with_functions.rs

  1//! This example creates a basic Chat UI with a function for rolling a die.
  2
  3use anyhow::{Context as _, Result};
  4use assets::Assets;
  5use assistant2::AssistantPanel;
  6use assistant_tooling::{LanguageModelTool, ToolRegistry};
  7use client::Client;
  8use gpui::{actions, AnyElement, App, AppContext, KeyBinding, Task, View, WindowOptions};
  9use language::LanguageRegistry;
 10use project::Project;
 11use rand::Rng;
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
 15use std::sync::Arc;
 16use theme::LoadThemes;
 17use ui::{div, prelude::*, Render};
 18use util::ResultExt as _;
 19
// Declares the `Quit` action in the `example` namespace. `main` binds it to `cmd-q` and
// registers a handler that quits the app when it is dispatched.
actions!(example, [Quit]);
 21
 22struct RollDiceTool {}
 23
 24impl RollDiceTool {
 25    fn new() -> Self {
 26        Self {}
 27    }
 28}
 29
 30#[derive(Serialize, Deserialize, JsonSchema, Clone)]
 31#[serde(rename_all = "snake_case")]
 32enum Die {
 33    D6 = 6,
 34    D20 = 20,
 35}
 36
 37impl Die {
 38    fn into_str(&self) -> &'static str {
 39        match self {
 40            Die::D6 => "d6",
 41            Die::D20 => "d20",
 42        }
 43    }
 44}
 45
// Arguments the model supplies when it calls the `roll_dice` tool (the tool's `Input` type).
// The `///` field docs below are kept verbatim on purpose: the `JsonSchema` derive turns them
// into the field descriptions of the tool schema, so editing them changes what the model sees.
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
struct DiceParams {
    /// The number of dice to roll.
    num_dice: u8,
    // `None` is resolved to `Die::D6` by `RollDiceTool::execute`.
    /// Which die to roll. Defaults to a d6 if not provided.
    die_type: Option<Die>,
}
 53
 54#[derive(Serialize, Deserialize)]
 55struct DieRoll {
 56    die: Die,
 57    roll: u8,
 58}
 59
 60impl DieRoll {
 61    fn render(&self) -> AnyElement {
 62        match self.die {
 63            Die::D6 => {
 64                let face = match self.roll {
 65                    6 => div().child(""),
 66                    5 => div().child(""),
 67                    4 => div().child(""),
 68                    3 => div().child(""),
 69                    2 => div().child(""),
 70                    1 => div().child(""),
 71                    _ => div().child("😅"),
 72                };
 73                face.text_3xl().into_any_element()
 74            }
 75            _ => div()
 76                .child(format!("{}", self.roll))
 77                .text_3xl()
 78                .into_any_element(),
 79        }
 80    }
 81}
 82
 83#[derive(Serialize, Deserialize)]
 84struct DiceRoll {
 85    rolls: Vec<DieRoll>,
 86}
 87
 88pub struct DiceView {
 89    result: Result<DiceRoll>,
 90}
 91
 92impl Render for DiceView {
 93    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
 94        let output = match &self.result {
 95            Ok(output) => output,
 96            Err(_) => return "Somehow dice failed 🎲".into_any_element(),
 97        };
 98
 99        h_flex()
100            .children(
101                output
102                    .rolls
103                    .iter()
104                    .map(|roll| div().p_2().child(roll.render())),
105            )
106            .into_any_element()
107    }
108}
109
110impl LanguageModelTool for RollDiceTool {
111    type Input = DiceParams;
112    type Output = DiceRoll;
113    type View = DiceView;
114
115    fn name(&self) -> String {
116        "roll_dice".to_string()
117    }
118
119    fn description(&self) -> String {
120        "Rolls N many dice and returns the results.".to_string()
121    }
122
123    fn execute(&self, input: &Self::Input, _cx: &AppContext) -> Task<gpui::Result<Self::Output>> {
124        let rolls = (0..input.num_dice)
125            .map(|_| {
126                let die_type = input.die_type.as_ref().unwrap_or(&Die::D6).clone();
127
128                DieRoll {
129                    die: die_type.clone(),
130                    roll: rand::thread_rng().gen_range(1..=die_type as u8),
131                }
132            })
133            .collect();
134
135        return Task::ready(Ok(DiceRoll { rolls }));
136    }
137
138    fn output_view(
139        _tool_call_id: String,
140        _input: Self::Input,
141        result: Result<Self::Output>,
142        cx: &mut WindowContext,
143    ) -> gpui::View<Self::View> {
144        cx.new_view(|_cx| DiceView { result })
145    }
146
147    fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
148        let output = match output {
149            Ok(output) => output,
150            Err(_) => return "Somehow dice failed 🎲".to_string(),
151        };
152
153        let mut result = String::new();
154        for roll in &output.rolls {
155            let die = &roll.die;
156            result.push_str(&format!("{}: {}\n", die.into_str(), roll.roll));
157        }
158        result
159    }
160}
161
162fn main() {
163    env_logger::init();
164    App::new().with_assets(Assets).run(|cx| {
165        cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
166        cx.on_action(|_: &Quit, cx: &mut AppContext| {
167            cx.quit();
168        });
169
170        settings::init(cx);
171        language::init(cx);
172        Project::init_settings(cx);
173        editor::init(cx);
174        theme::init(LoadThemes::JustBase, cx);
175        Assets.load_fonts(cx).unwrap();
176        KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
177        client::init_settings(cx);
178        release_channel::init("0.130.0", cx);
179
180        let client = Client::production(cx);
181        {
182            let client = client.clone();
183            cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
184                .detach_and_log_err(cx);
185        }
186        assistant2::init(client.clone(), cx);
187
188        let language_registry = Arc::new(LanguageRegistry::new(
189            Task::ready(()),
190            cx.background_executor().clone(),
191        ));
192        let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
193        languages::init(language_registry.clone(), node_runtime, cx);
194
195        cx.spawn(|cx| async move {
196            cx.update(|cx| {
197                cx.open_window(WindowOptions::default(), |cx| {
198                    let mut tool_registry = ToolRegistry::new();
199                    tool_registry
200                        .register(RollDiceTool::new(), cx)
201                        .context("failed to register DummyTool")
202                        .log_err();
203
204                    let tool_registry = Arc::new(tool_registry);
205
206                    println!("Tools registered");
207                    for definition in tool_registry.definitions() {
208                        println!("{}", definition);
209                    }
210
211                    cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
212                });
213                cx.activate(true);
214            })
215        })
216        .detach_and_log_err(cx);
217    })
218}
219
220struct Example {
221    assistant_panel: View<AssistantPanel>,
222}
223
224impl Example {
225    fn new(
226        language_registry: Arc<LanguageRegistry>,
227        tool_registry: Arc<ToolRegistry>,
228        cx: &mut ViewContext<Self>,
229    ) -> Self {
230        Self {
231            assistant_panel: cx
232                .new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)),
233        }
234    }
235}
236
237impl Render for Example {
238    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
239        div().size_full().child(self.assistant_panel.clone())
240    }
241}