// chat-with-functions.rs

  1use anyhow::Context as _;
  2use assets::Assets;
  3use assistant2::AssistantPanel;
  4use assistant_tooling::{LanguageModelTool, ToolRegistry};
  5use client::Client;
  6use gpui::{actions, AnyElement, App, AppContext, KeyBinding, Task, View, WindowOptions};
  7use language::LanguageRegistry;
  8use project::Project;
  9use rand::Rng;
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
 13use std::sync::Arc;
 14use theme::LoadThemes;
 15use ui::{div, prelude::*, Render};
 16use util::ResultExt as _;
 17
// Declares the `Quit` action (grouped under `example`); `main` binds it to
// cmd-q and handles it by quitting the application.
actions!(example, [Quit]);
 19
 20struct RollDiceTool {}
 21
 22impl RollDiceTool {
 23    fn new() -> Self {
 24        Self {}
 25    }
 26}
 27
// The kinds of die this example can roll. Each variant's discriminant is its
// number of sides, which `execute` relies on when casting the die to `u8`.
// serde renames the variants to snake_case.
//
// NOTE(review): plain `//` comments are used here on purpose. `///` doc
// comments would presumably be picked up by the `JsonSchema` derive as schema
// descriptions and change the generated schema — confirm before converting.
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
#[serde(rename_all = "snake_case")]
enum Die {
    D6 = 6,
    D20 = 20,
}
 34
 35impl Die {
 36    fn into_str(&self) -> &'static str {
 37        match self {
 38            Die::D6 => "d6",
 39            Die::D20 => "d20",
 40        }
 41    }
 42}
 43
// Input accepted by the dice-rolling tool (used as `RollDiceTool`'s
// `LanguageModelTool::Input`).
//
// NOTE(review): a plain `//` comment is used for the struct itself because a
// `///` doc comment would presumably be emitted into the derived JSON schema
// as a description — confirm before converting.
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
struct DiceParams {
    /// The number of dice to roll.
    num_dice: u8,
    /// Which die to roll. Defaults to a d6 if not provided.
    die_type: Option<Die>,
}
 51
/// The outcome of rolling a single die.
#[derive(Serialize, Deserialize)]
struct DieRoll {
    /// Which kind of die was rolled.
    die: Die,
    /// The value that came up. `RollDiceTool::execute` generates it in the
    /// range `1..=sides`, where `sides` is the die's discriminant.
    roll: u8,
}
 57
 58impl DieRoll {
 59    fn render(&self) -> AnyElement {
 60        match self.die {
 61            Die::D6 => {
 62                let face = match self.roll {
 63                    6 => div().child(""),
 64                    5 => div().child(""),
 65                    4 => div().child(""),
 66                    3 => div().child(""),
 67                    2 => div().child(""),
 68                    1 => div().child(""),
 69                    _ => div().child("😅"),
 70                };
 71                face.text_3xl().into_any_element()
 72            }
 73            _ => div()
 74                .child(format!("{}", self.roll))
 75                .text_3xl()
 76                .into_any_element(),
 77        }
 78    }
 79}
 80
/// The result of one `roll_dice` invocation (used as `RollDiceTool`'s
/// `LanguageModelTool::Output`).
#[derive(Serialize, Deserialize)]
struct DiceRoll {
    /// One entry per die rolled, in the order they were generated.
    rolls: Vec<DieRoll>,
}
 85
 86impl LanguageModelTool for RollDiceTool {
 87    type Input = DiceParams;
 88    type Output = DiceRoll;
 89
 90    fn name(&self) -> String {
 91        "roll_dice".to_string()
 92    }
 93
 94    fn description(&self) -> String {
 95        "Rolls N many dice and returns the results.".to_string()
 96    }
 97
 98    fn execute(&self, input: &Self::Input, _cx: &AppContext) -> Task<gpui::Result<Self::Output>> {
 99        let rolls = (0..input.num_dice)
100            .map(|_| {
101                let die_type = input.die_type.as_ref().unwrap_or(&Die::D6).clone();
102
103                DieRoll {
104                    die: die_type.clone(),
105                    roll: rand::thread_rng().gen_range(1..=die_type as u8),
106                }
107            })
108            .collect();
109
110        return Task::ready(Ok(DiceRoll { rolls }));
111    }
112
113    fn render(
114        _tool_call_id: &str,
115        _input: &Self::Input,
116        output: &Self::Output,
117        _cx: &mut WindowContext,
118    ) -> gpui::AnyElement {
119        h_flex()
120            .children(
121                output
122                    .rolls
123                    .iter()
124                    .map(|roll| div().p_2().child(roll.render())),
125            )
126            .into_any_element()
127    }
128
129    fn format(_input: &Self::Input, output: &Self::Output) -> String {
130        let mut result = String::new();
131        for roll in &output.rolls {
132            let die = &roll.die;
133            result.push_str(&format!("{}: {}\n", die.into_str(), roll.roll));
134        }
135        result
136    }
137}
138
139fn main() {
140    env_logger::init();
141    App::new().with_assets(Assets).run(|cx| {
142        cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
143        cx.on_action(|_: &Quit, cx: &mut AppContext| {
144            cx.quit();
145        });
146
147        settings::init(cx);
148        language::init(cx);
149        Project::init_settings(cx);
150        editor::init(cx);
151        theme::init(LoadThemes::JustBase, cx);
152        Assets.load_fonts(cx).unwrap();
153        KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
154        client::init_settings(cx);
155        release_channel::init("0.130.0", cx);
156
157        let client = Client::production(cx);
158        {
159            let client = client.clone();
160            cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
161                .detach_and_log_err(cx);
162        }
163        assistant2::init(client.clone(), cx);
164
165        let language_registry = Arc::new(LanguageRegistry::new(
166            Task::ready(()),
167            cx.background_executor().clone(),
168        ));
169        let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
170        languages::init(language_registry.clone(), node_runtime, cx);
171
172        cx.spawn(|cx| async move {
173            cx.update(|cx| {
174                let mut tool_registry = ToolRegistry::new();
175                tool_registry
176                    .register(RollDiceTool::new())
177                    .context("failed to register DummyTool")
178                    .log_err();
179
180                let tool_registry = Arc::new(tool_registry);
181
182                println!("Tools registered");
183                for definition in tool_registry.definitions() {
184                    println!("{}", definition);
185                }
186
187                cx.open_window(WindowOptions::default(), |cx| {
188                    cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
189                });
190                cx.activate(true);
191            })
192        })
193        .detach_and_log_err(cx);
194    })
195}
196
/// Root view of the example window; it holds the assistant panel it displays.
struct Example {
    /// Handle to the assistant panel view, created in `Example::new`.
    assistant_panel: View<AssistantPanel>,
}
200
201impl Example {
202    fn new(
203        language_registry: Arc<LanguageRegistry>,
204        tool_registry: Arc<ToolRegistry>,
205        cx: &mut ViewContext<Self>,
206    ) -> Self {
207        Self {
208            assistant_panel: cx
209                .new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)),
210        }
211    }
212}
213
214impl Render for Example {
215    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
216        div().size_full().child(self.assistant_panel.clone())
217    }
218}