use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

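/// A custom model entry supplied via settings. A sketch of how one might look
/// in the user's settings file (values are illustrative, not defaults):
///
/// ```json
/// {
///   "name": "llama3.1:latest",
///   "display_name": "Llama 3.1",
///   "max_tokens": 8192
/// }
/// ```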
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size, passed to Ollama as num_ctx (aka n_ctx).
    pub max_tokens: usize,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
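    // Ollama requires no credentials; a non-empty model list is our signal
    // that the server is reachable.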
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since the Ollama API provides no metadata indicating which
                // models are embedding models, simply filter out models with
                // "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
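                // Refresh the model list whenever global settings change
                // (e.g. the user points api_url at a different server).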
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: None,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

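    // Ask the Ollama server to preload the model so the first completion
    // doesn't pay the model's load cost.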
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
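    /// Convert Zed's LanguageModelRequest into Ollama's chat format. Zed's
    /// roles map one-to-one onto Ollama's user/assistant/system messages.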
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
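
    /// Issue a single chat completion and resolve with the full response.
    /// Used by use_any_tool, which inspects the completed assistant message
    /// for tool calls rather than consuming a stream.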
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
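        // Fall back to a rough estimate of ~4 characters per token, a common
        // approximation for English text.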
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
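            // Each streamed delta carries a full ChatMessage; surface its text
            // content regardless of role.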
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
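        // Tool calls are read off the completed response's assistant message
        // rather than off the stream.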
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

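/// Configuration UI for the Ollama provider: intro copy, download and model
/// library links, and the current connection status.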
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
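        // Kick off model discovery; once it settles (either way), clear the
        // loading state so the view re-renders.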
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

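        // Semi-transparent editor background for the inline `ollama run`
        // snippet below.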
        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a button only to keep the spacing consistent
                            // with the "Connect" state; it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}