1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod streaming_diff;
8
9pub use assistant_panel::AssistantPanel;
10use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
11use client::{proto, Client};
12use command_palette_hooks::CommandPaletteFilter;
13pub(crate) use completion_provider::*;
14use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
15pub(crate) use saved_conversation::*;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{
19 fmt::{self, Display},
20 sync::Arc,
21};
22
23actions!(
24 assistant,
25 [
26 Assist,
27 Split,
28 CycleMessageRole,
29 QuoteSelection,
30 ToggleFocus,
31 ResetKey,
32 InlineAssist,
33 ToggleIncludeConversation,
34 ToggleHistory,
35 ]
36);
37
38#[derive(
39 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
40)]
41struct MessageId(usize);
42
43#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
44#[serde(rename_all = "lowercase")]
45pub enum Role {
46 User,
47 Assistant,
48 System,
49}
50
51impl Role {
52 pub fn cycle(&mut self) {
53 *self = match self {
54 Role::User => Role::Assistant,
55 Role::Assistant => Role::System,
56 Role::System => Role::User,
57 }
58 }
59}
60
61impl Display for Role {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Role::User => write!(f, "user"),
65 Role::Assistant => write!(f, "assistant"),
66 Role::System => write!(f, "system"),
67 }
68 }
69}
70
71#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
72pub enum LanguageModel {
73 ZedDotDev(ZedDotDevModel),
74 OpenAi(OpenAiModel),
75}
76
77impl Default for LanguageModel {
78 fn default() -> Self {
79 LanguageModel::ZedDotDev(ZedDotDevModel::default())
80 }
81}
82
83impl LanguageModel {
84 pub fn telemetry_id(&self) -> String {
85 match self {
86 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
87 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
88 }
89 }
90
91 pub fn display_name(&self) -> String {
92 match self {
93 LanguageModel::OpenAi(model) => model.display_name().into(),
94 LanguageModel::ZedDotDev(model) => model.display_name().into(),
95 }
96 }
97
98 pub fn max_token_count(&self) -> usize {
99 match self {
100 LanguageModel::OpenAi(model) => model.max_token_count(),
101 LanguageModel::ZedDotDev(model) => model.max_token_count(),
102 }
103 }
104
105 pub fn id(&self) -> &str {
106 match self {
107 LanguageModel::OpenAi(model) => model.id(),
108 LanguageModel::ZedDotDev(model) => model.id(),
109 }
110 }
111}
112
113#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
114pub struct LanguageModelRequestMessage {
115 pub role: Role,
116 pub content: String,
117}
118
119impl LanguageModelRequestMessage {
120 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
121 proto::LanguageModelRequestMessage {
122 role: match self.role {
123 Role::User => proto::LanguageModelRole::LanguageModelUser,
124 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
125 Role::System => proto::LanguageModelRole::LanguageModelSystem,
126 } as i32,
127 content: self.content.clone(),
128 tool_calls: Vec::new(),
129 tool_call_id: None,
130 }
131 }
132}
133
134#[derive(Debug, Default, Serialize)]
135pub struct LanguageModelRequest {
136 pub model: LanguageModel,
137 pub messages: Vec<LanguageModelRequestMessage>,
138 pub stop: Vec<String>,
139 pub temperature: f32,
140}
141
142impl LanguageModelRequest {
143 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
144 proto::CompleteWithLanguageModel {
145 model: self.model.id().to_string(),
146 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
147 stop: self.stop.clone(),
148 temperature: self.temperature,
149 tool_choice: None,
150 tools: Vec::new(),
151 }
152 }
153}
154
155#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
156pub struct LanguageModelResponseMessage {
157 pub role: Option<Role>,
158 pub content: Option<String>,
159}
160
161#[derive(Deserialize, Debug)]
162pub struct LanguageModelUsage {
163 pub prompt_tokens: u32,
164 pub completion_tokens: u32,
165 pub total_tokens: u32,
166}
167
168#[derive(Deserialize, Debug)]
169pub struct LanguageModelChoiceDelta {
170 pub index: u32,
171 pub delta: LanguageModelResponseMessage,
172 pub finish_reason: Option<String>,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize)]
176struct MessageMetadata {
177 role: Role,
178 status: MessageStatus,
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize)]
182enum MessageStatus {
183 Pending,
184 Done,
185 Error(SharedString),
186}
187
188/// The state pertaining to the Assistant.
189#[derive(Default)]
190struct Assistant {
191 /// Whether the Assistant is enabled.
192 enabled: bool,
193}
194
195impl Global for Assistant {}
196
197impl Assistant {
198 const NAMESPACE: &'static str = "assistant";
199
200 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
201 if self.enabled == enabled {
202 return;
203 }
204
205 self.enabled = enabled;
206
207 if !enabled {
208 CommandPaletteFilter::update_global(cx, |filter, _cx| {
209 filter.hide_namespace(Self::NAMESPACE);
210 });
211
212 return;
213 }
214
215 CommandPaletteFilter::update_global(cx, |filter, _cx| {
216 filter.show_namespace(Self::NAMESPACE);
217 });
218 }
219}
220
221pub fn init(client: Arc<Client>, cx: &mut AppContext) {
222 cx.set_global(Assistant::default());
223 AssistantSettings::register(cx);
224 completion_provider::init(client, cx);
225 assistant_panel::init(cx);
226
227 CommandPaletteFilter::update_global(cx, |filter, _cx| {
228 filter.hide_namespace(Assistant::NAMESPACE);
229 });
230 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
231 let settings = AssistantSettings::get_global(cx);
232
233 assistant.set_enabled(settings.enabled, cx);
234 });
235 cx.observe_global::<SettingsStore>(|cx| {
236 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
237 let settings = AssistantSettings::get_global(cx);
238
239 assistant.set_enabled(settings.enabled, cx);
240 });
241 })
242 .detach();
243}
244
/// Test-build-only hook, registered with `#[ctor::ctor]`, that initializes
/// `env_logger` when `RUST_LOG` can be read from the environment (set and
/// valid Unicode). Otherwise logging is left uninitialized.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}