From 53bf5b8f9100ae276581da735a13505c0c9b3f03 Mon Sep 17 00:00:00 2001
From: hems
Date: Tue, 15 Jul 2025 13:01:39 +0100
Subject: [PATCH] feat: add Ollama integration with automatic model discovery

- Add comprehensive Ollama HTTP API client with model discovery
- Implement automatic service management with process cleanup
- Add model loading and context window detection
- Include comprehensive test suite (16 tests across 4 modules)
- Add standalone debugging tool for manual testing
- Update README with auto-discovery documentation
- Integrate with existing provider system for seamless model switching

The implementation provides zero-configuration Ollama support with
automatic service startup, model discovery, and proper cleanup handling.
---
 README.md                                 |  49 ++++
 cmd/test-ollama/main.go                   | 111 ++++++++
 internal/config/load.go                   |  42 +++
 internal/llm/agent/agent.go               |   7 +-
 internal/ollama/client.go                 | 175 +++++++++++++
 internal/ollama/client_test.go            | 239 ++++++++++++++++++
 internal/ollama/ollama.go                 |  11 +
 internal/ollama/ollama_test.go            |  25 ++
 internal/ollama/process.go                |  60 +++++
 internal/ollama/process_test.go           | 110 ++++++++
 internal/ollama/service.go                | 117 +++++++++
 internal/ollama/service_test.go           | 193 ++++++++++++++
 internal/ollama/types.go                  |  60 +++++
 .../tui/components/dialogs/models/list.go |  24 ++
 internal/tui/tui.go                       |   8 +
 15 files changed, 1229 insertions(+), 2 deletions(-)
 create mode 100644 cmd/test-ollama/main.go
 create mode 100644 internal/ollama/client.go
 create mode 100644 internal/ollama/client_test.go
 create mode 100644 internal/ollama/ollama.go
 create mode 100644 internal/ollama/ollama_test.go
 create mode 100644 internal/ollama/process.go
 create mode 100644 internal/ollama/process_test.go
 create mode 100644 internal/ollama/service.go
 create mode 100644 internal/ollama/service_test.go
 create mode 100644 internal/ollama/types.go

diff --git a/README.md b/README.md
index f69a451eaba21f92da001f075ad630fa43ff3aba..a2d9ee8d5166f4ad98a6e1e47a0e4c8083c743e6 100644
--- a/README.md
+++ b/README.md
@@ -105,6 +105,55 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D
 }
 ```
 
+### Local AI Model Providers
+
+Crush is compatible with local AI model applications that implement OpenAI's API standard. This includes popular tools like Ollama, LM Studio, LocalAI, Jan.ai, and many others. Running models locally gives you complete privacy and control over your AI infrastructure.
+
+#### Ollama
+
+[Ollama](https://ollama.com) is a popular tool for running AI models locally. It packages models with all dependencies, making deployment simple and reliable.
+
+**Installation:**
+```bash
+# Install Ollama
+curl -fsSL https://ollama.com/install.sh | sh
+
+# Download and run a model
+ollama run llama3.2:3b
+```
+
+**Auto-Discovery:**
+Crush automatically detects Ollama installations and discovers available models without any configuration needed. Simply install Ollama and pull models; they'll appear in the model switcher automatically.
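+
+Once a model is pulled, it should show up the next time you open the model switcher. As a quick sanity check, you can query the same local endpoint Crush polls for discovery (the model name below is only an example; any pulled model is discovered):
+
+```bash
+# Pull a model so Crush can discover it
+ollama pull llama3.2:3b
+
+# List the models Ollama exposes on its local API (this is what Crush reads)
+curl http://localhost:11434/api/tags
+```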
+
+**Manual Configuration (Optional):**
+For advanced use cases or custom configurations, you can manually define Ollama providers:
+
+```json
+{
+  "providers": {
+    "ollama": {
+      "type": "openai",
+      "base_url": "http://localhost:11434/v1",
+      "api_key": "ollama",
+      "models": [
+        {
+          "id": "llama3.2:3b",
+          "model": "Llama 3.2 3B",
+          "context_window": 131072,
+          "default_max_tokens": 4096,
+          "cost_per_1m_in": 0,
+          "cost_per_1m_out": 0
+        }
+      ]
+    }
+  }
+}
+```
+
+#### Other Local AI Tools
+
+For other local AI applications (LM Studio, LocalAI, Jan.ai, etc.), you can configure them manually using the OpenAI-compatible API format shown above.
+
 ## Whatcha think?
 
 We’d love to hear your thoughts on this project. Feel free to drop us a note!
diff --git a/cmd/test-ollama/main.go b/cmd/test-ollama/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..0ac0a1bb095e5d47dabd2696feaed21816acbe65
--- /dev/null
+++ b/cmd/test-ollama/main.go
@@ -0,0 +1,111 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/ollama"
+)
+
+func main() {
+	fmt.Println("🧪 Ollama Test Suite")
+	fmt.Println("===================")
+
+	// Test 1: Check if Ollama is installed
+	fmt.Print("1. Checking if Ollama is installed... ")
+	if ollama.IsInstalled() {
+		fmt.Println("✅ PASS")
+	} else {
+		fmt.Println("❌ FAIL - Ollama is not installed")
+		fmt.Println("   Please install Ollama from https://ollama.com")
+		os.Exit(1)
+	}
+
+	// Test 2: Check if Ollama is running
+	fmt.Print("2. Checking if Ollama is running... ")
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	if ollama.IsRunning(ctx) {
+		fmt.Println("✅ PASS")
+	} else {
+		fmt.Println("❌ FAIL - Ollama is not running")
+
+		// Test 3: Try to start Ollama service
+		fmt.Print("3. Attempting to start Ollama service... ")
+		ctx2, cancel2 := context.WithTimeout(context.Background(), 15*time.Second)
+		defer cancel2()
+
+		if err := ollama.StartOllamaService(ctx2); err != nil {
+			fmt.Printf("❌ FAIL - %v\n", err)
+			os.Exit(1)
+		}
+		fmt.Println("✅ PASS")
+
+		// Verify it's now running
+		fmt.Print("4. Verifying Ollama is now running... ")
+		if ollama.IsRunning(ctx2) {
+			fmt.Println("✅ PASS")
+		} else {
+			fmt.Println("❌ FAIL - Service started but not responding")
+			os.Exit(1)
+		}
+	}
+
+	// Test 4: Get available models
+	fmt.Print("5. Getting available models... ")
+	ctx3, cancel3 := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel3()
+
+	models, err := ollama.GetModels(ctx3)
+	if err != nil {
+		fmt.Printf("❌ FAIL - %v\n", err)
+		os.Exit(1)
+	}
+	fmt.Printf("✅ PASS (%d models found)\n", len(models))
+
+	// Display models
+	if len(models) > 0 {
+		fmt.Println("\n📋 Available Models:")
+		for i, model := range models {
+			fmt.Printf("   %d. %s\n", i+1, model.ID)
+			fmt.Printf("      Context: %d tokens, Max: %d tokens\n",
+				model.ContextWindow, model.DefaultMaxTokens)
+		}
+	} else {
+		fmt.Println("\n⚠️ No models found. You may need to download some models first.")
+		fmt.Println("   Example: ollama pull llama3.2:3b")
+	}
+
+	// Test 5: Get provider
+	fmt.Print("\n6. Getting Ollama provider... ")
+	provider, err := ollama.GetProvider(ctx3)
+	if err != nil {
+		fmt.Printf("❌ FAIL - %v\n", err)
+		os.Exit(1)
+	}
+	fmt.Printf("✅ PASS (%s with %d models)\n", provider.Name, len(provider.Models))
+
+	// Test 6: Test model loading check
+	if len(models) > 0 {
+		testModel := models[0].ID
+		fmt.Printf("7. Checking if model '%s' is loaded... ", testModel)
+
+		loaded, err := ollama.IsModelLoaded(ctx3, testModel)
+		if err != nil {
+			fmt.Printf("❌ FAIL - %v\n", err)
+		} else if loaded {
+			fmt.Println("✅ PASS (model is loaded)")
+		} else {
+			fmt.Println("⚠️ PASS (model is not loaded)")
+		}
+	}
+
+	fmt.Println("\n🎉 All tests completed successfully!")
+	fmt.Println("\nTo run individual tests:")
+	fmt.Println("   go test ./internal/ollama -v")
+	fmt.Println("\nTo run benchmarks:")
+	fmt.Println("   go test ./internal/ollama -bench=.")
+}
diff --git a/internal/config/load.go b/internal/config/load.go
index 81cb4398e5b3a7a2147ab5388b37088788ea041b..9d662515af646f8dd37cf2f114f93b23382580ec 100644
--- a/internal/config/load.go
+++ b/internal/config/load.go
@@ -1,6 +1,7 @@
 package config
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -9,11 +10,13 @@ import (
 	"runtime"
 	"slices"
 	"strings"
+	"time"
 
 	"github.com/charmbracelet/crush/internal/env"
 	"github.com/charmbracelet/crush/internal/fur/client"
 	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/log"
+	"github.com/charmbracelet/crush/internal/ollama"
 	"golang.org/x/exp/slog"
 )
 
@@ -184,6 +187,45 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
 		cfg.Providers[string(p.ID)] = prepared
 	}
 
+	// Auto-detect Ollama if it's available and not already configured
+	if _, exists := cfg.Providers["ollama"]; !exists {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+
+		// First try to get provider if Ollama is already running
+		if ollamaProvider, err := ollama.GetProvider(ctx); err == nil {
+			slog.Debug("Auto-detected running Ollama provider", "models", len(ollamaProvider.Models))
+			cfg.Providers["ollama"] = ProviderConfig{
+				ID:      "ollama",
+				Name:    "Ollama",
+				BaseURL: "http://localhost:11434/v1",
+				Type:    provider.TypeOpenAI,
+				APIKey:  "ollama",
+				Models:  ollamaProvider.Models,
+			}
+		} else {
+			// If Ollama is not running, try to start it
+			if err := ollama.EnsureOllamaRunning(ctx); err == nil {
+				// Now try to get the provider again
+				if ollamaProvider, err := ollama.GetProvider(ctx); err == nil {
+					slog.Debug("Started Ollama service and detected provider", "models", len(ollamaProvider.Models))
+					cfg.Providers["ollama"] = ProviderConfig{
+						ID:      "ollama",
+						Name:    "Ollama",
+						BaseURL: "http://localhost:11434/v1",
+						Type:    provider.TypeOpenAI,
+						APIKey:  "ollama",
+						Models:  ollamaProvider.Models,
+					}
+				} else {
+					slog.Debug("Started Ollama service but failed to get provider", "error", err)
+				}
+			} else {
+				slog.Debug("Failed to start Ollama service", "error", err)
+			}
+		}
+	}
+
 	// validate the custom providers
 	for id, providerConfig := range cfg.Providers {
 		if knownProviderNames[id] {
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 56fb431b3b705a656cdfbf9df426b8ce8c7298c4..7f0f4efeb10d6492069a777a1b14080ecdd62ee9 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -168,7 +168,7 @@ func NewAgent(
 		}
 	}
 	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
-	if smallModel.ID == "" {
+	if smallModel == nil {
 		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 	}
 
@@ -817,6 +817,9 @@ func (a *agent) UpdateModel() error {
 
 	// Get current provider configuration
 	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
+	if currentProviderCfg == nil {
+		return fmt.Errorf("provider configuration for agent %s not found in config", a.agentCfg.Name)
+	}
 	if currentProviderCfg.ID == "" {
 		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 	}
@@ -825,7 +828,7 @@ func (a *agent) UpdateModel() error {
 	if string(currentProviderCfg.ID) != a.providerID {
 		// Provider changed, need to recreate the main provider
 		model := cfg.GetModelByType(a.agentCfg.Model)
-		if model.ID == "" {
+		if model == nil {
 			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 		}
 
diff --git a/internal/ollama/client.go b/internal/ollama/client.go
new file mode 100644
index 0000000000000000000000000000000000000000..f9aba8effb0232779ad50fd2fed8c45136457d86
--- /dev/null
+++ b/internal/ollama/client.go
@@ -0,0 +1,175 @@
+package ollama
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+const (
+	defaultOllamaURL = "http://localhost:11434"
+	requestTimeout   = 2 * time.Second
+)
+
+// IsRunning checks if Ollama is running by attempting to connect to its API
+func IsRunning(ctx context.Context) bool {
+	client := &http.Client{
+		Timeout: requestTimeout,
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/tags", nil)
+	if err != nil {
+		return false
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return false
+	}
+	defer resp.Body.Close()
+
+	return resp.StatusCode == http.StatusOK
+}
+
+// GetModels retrieves available models from Ollama
+func GetModels(ctx context.Context) ([]provider.Model, error) {
+	client := &http.Client{
+		Timeout: requestTimeout,
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/tags", nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
+	}
+
+	var tagsResponse OllamaTagsResponse
+	if err := json.NewDecoder(resp.Body).Decode(&tagsResponse); err != nil {
+		return nil, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	models := make([]provider.Model, len(tagsResponse.Models))
+	for i, ollamaModel := range tagsResponse.Models {
+		models[i] = provider.Model{
+			ID:                 ollamaModel.Name,
+			Model:              ollamaModel.Name,
+			CostPer1MIn:        0, // Local models have no cost
+			CostPer1MOut:       0,
+			CostPer1MInCached:  0,
+			CostPer1MOutCached: 0,
+			ContextWindow:      getContextWindow(ollamaModel.Details.Family),
+			DefaultMaxTokens:   4096,
+			CanReason:          false,
+			HasReasoningEffort: false,
+			SupportsImages:     supportsImages(ollamaModel.Details.Family),
+		}
+	}
+
+	return models, nil
+}
+
+// GetRunningModels returns models that are currently loaded in memory
+func GetRunningModels(ctx context.Context) ([]OllamaRunningModel, error) {
+	client := &http.Client{
+		Timeout: requestTimeout,
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/ps", nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
+	}
+
+	var psResponse OllamaRunningModelsResponse
+	if err := json.NewDecoder(resp.Body).Decode(&psResponse); err != nil {
+		return nil, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return psResponse.Models, nil
+}
+
+// IsModelLoaded checks if a specific model is currently loaded in memory
+func IsModelLoaded(ctx context.Context, modelName string) (bool, error) {
+	runningModels, err := GetRunningModels(ctx)
+	if err != nil {
+		return false, err
+	}
+
+	for _, model := range runningModels {
+		if model.Name == modelName {
+			return true, nil
+		}
+	}
+
+	return false, nil
+}
+
+// GetProvider returns a provider.Provider for Ollama if it's running
+func GetProvider(ctx context.Context) (*provider.Provider, error) {
+	if !IsRunning(ctx) {
+		return nil, fmt.Errorf("Ollama is not running")
+	}
+
+	models, err := GetModels(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get models: %w", err)
+	}
+
+	return &provider.Provider{
+		Name:   "Ollama",
+		ID:     "ollama",
+		Models: models,
+	}, nil
+}
+
+// getContextWindow returns an estimated context window based on model family
+func getContextWindow(family string) int64 {
+	switch family {
+	case "llama":
+		return 131072 // Llama 3.x context window
+	case "mistral":
+		return 32768
+	case "gemma":
+		return 8192
+	case "qwen", "qwen2":
+		return 131072
+	case "phi":
+		return 131072
+	case "codellama":
+		return 16384
+	default:
+		return 8192 // Conservative default
+	}
+}
+
+// supportsImages returns whether a model family supports image inputs
+func supportsImages(family string) bool {
+	switch family {
+	case "llama-vision", "llava":
+		return true
+	default:
+		return false
+	}
+}
diff --git a/internal/ollama/client_test.go b/internal/ollama/client_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a36fcaaccb40eacfb2bef1c1df28a67b62c1709e
--- /dev/null
+++ b/internal/ollama/client_test.go
@@ -0,0 +1,239 @@
+package ollama
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestIsRunning(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping IsRunning test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	running := IsRunning(ctx)
+
+	if running {
+		t.Log("✓ Ollama is running")
+	} else {
+		t.Log("✗ Ollama is not running")
+	}
+
+	// This test doesn't fail - it's informational
+	// The behavior depends on whether Ollama is actually running
+}
+
+func TestGetModels(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping GetModels test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Ensure Ollama is running
+	if !IsRunning(ctx) {
+		t.Log("Ollama is not running, attempting to start...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+		defer cleanupProcesses()
+	}
+
+	models, err := GetModels(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get models: %v", err)
+	}
+
+	t.Logf("✓ Found %d models:", len(models))
+	for _, model := range models {
+		t.Logf("  - %s (context: %d, max_tokens: %d)",
+			model.ID, model.ContextWindow, model.DefaultMaxTokens)
+	}
+}
+
+func TestGetRunningModels(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping GetRunningModels test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Ensure Ollama is running
+	if !IsRunning(ctx) {
+		t.Log("Ollama is not running, attempting to start...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+		defer cleanupProcesses()
+	}
+
+	runningModels, err := GetRunningModels(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get running models: %v", err)
+	}
+
+	t.Logf("✓ Found %d running models:", len(runningModels))
+	for _, model := range runningModels {
+		t.Logf("  - %s (size: %d bytes)", model.Name, model.Size)
+	}
+}
+
+func TestIsModelLoaded(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping IsModelLoaded test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Ensure Ollama is running
+	if !IsRunning(ctx) {
+		t.Log("Ollama is not running, attempting to start...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+		defer cleanupProcesses()
+	}
+
+	// Get available models first
+	models, err := GetModels(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get models: %v", err)
+	}
+
+	if len(models) == 0 {
+		t.Skip("No models available, skipping IsModelLoaded test")
+	}
+
+	testModel := models[0].ID
+	t.Logf("Testing model: %s", testModel)
+
+	loaded, err := IsModelLoaded(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is loaded: %v", err)
+	}
+
+	if loaded {
+		t.Logf("✓ Model %s is loaded", testModel)
+	} else {
+		t.Logf("✗ Model %s is not loaded", testModel)
+	}
+}
+
+func TestGetProvider(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping GetProvider test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Ensure Ollama is running
+	if !IsRunning(ctx) {
+		t.Log("Ollama is not running, attempting to start...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+		defer cleanupProcesses()
+	}
+
+	provider, err := GetProvider(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get provider: %v", err)
+	}
+
+	if provider.Name != "Ollama" {
+		t.Errorf("Expected provider name to be 'Ollama', got '%s'", provider.Name)
+	}
+
+	if provider.ID != "ollama" {
+		t.Errorf("Expected provider ID to be 'ollama', got '%s'", provider.ID)
+	}
+
+	t.Logf("✓ Provider: %s (ID: %s) with %d models",
+		provider.Name, provider.ID, len(provider.Models))
+}
+
+func TestGetContextWindow(t *testing.T) {
+	tests := []struct {
+		family   string
+		expected int64
+	}{
+		{"llama", 131072},
+		{"mistral", 32768},
+		{"gemma", 8192},
+		{"qwen", 131072},
+		{"qwen2", 131072},
+		{"phi", 131072},
+		{"codellama", 16384},
+		{"unknown", 8192},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.family, func(t *testing.T) {
+			result := getContextWindow(tt.family)
+			if result != tt.expected {
+				t.Errorf("getContextWindow(%s) = %d, expected %d",
+					tt.family, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestSupportsImages(t *testing.T) {
+	tests := []struct {
+		family   string
+		expected bool
+	}{
+		{"llama-vision", true},
+		{"llava", true},
+		{"llama", false},
+		{"mistral", false},
+		{"unknown", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.family, func(t *testing.T) {
+			result := supportsImages(tt.family)
+			if result != tt.expected {
+				t.Errorf("supportsImages(%s) = %v, expected %v",
+					tt.family, result, tt.expected)
+			}
+		})
+	}
+}
+
+// Benchmark tests for client functions
+func BenchmarkIsRunning(b *testing.B) {
+	if !IsInstalled() {
+		b.Skip("Ollama is not installed")
+	}
+
+	ctx := context.Background()
+
+	for i := 0; i < b.N; i++ {
+		IsRunning(ctx)
+	}
+}
+
+func BenchmarkGetModels(b *testing.B) {
+	if !IsInstalled() {
+		b.Skip("Ollama is not installed")
+	}
+
+	ctx := context.Background()
+
+	// Ensure Ollama is running for benchmark
+	if !IsRunning(ctx) {
+		b.Skip("Ollama is not running")
+	}
+
+	for i := 0; i < b.N; i++ {
+		GetModels(ctx)
+	}
+}
diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go
new file mode 100644
index 0000000000000000000000000000000000000000..31122c7497e13389c9e4d5d9d0a2426ad2fb471f
--- /dev/null
+++ b/internal/ollama/ollama.go
@@ -0,0 +1,11 @@
+package ollama
+
+import (
+	"os/exec"
+)
+
+// IsInstalled checks if Ollama is installed on the system
+func IsInstalled() bool {
+	_, err := exec.LookPath("ollama")
+	return err == nil
+}
diff --git a/internal/ollama/ollama_test.go b/internal/ollama/ollama_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..9b247b25416b32ac879ea6b3c0304300a12c7cf0
--- /dev/null
+++ b/internal/ollama/ollama_test.go
@@ -0,0 +1,25 @@
+package ollama
+
+import (
+	"testing"
+)
+
+func TestIsInstalled(t *testing.T) {
+	installed := IsInstalled()
+
+	if installed {
+		t.Log("✓ Ollama is installed on this system")
+	} else {
+		t.Log("✗ Ollama is not installed on this system")
+	}
+
+	// This test doesn't fail - it's informational
+	// In a real scenario, you might want to skip other tests if Ollama is not installed
+}
+
+// Benchmark test for IsInstalled
+func BenchmarkIsInstalled(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		IsInstalled()
+	}
+}
diff --git a/internal/ollama/process.go b/internal/ollama/process.go
new file mode 100644
index 0000000000000000000000000000000000000000..3067e8d15d2ce83a60d7feea6533d7a67f30dd2c
--- /dev/null
+++ b/internal/ollama/process.go
@@ -0,0 +1,60 @@
+package ollama
+
+import (
+	"os"
+	"os/exec"
+	"os/signal"
+	"syscall"
+	"time"
+)
+
+var processManager = &ProcessManager{
+	processes: make(map[string]*exec.Cmd),
+}
+
+// setupProcessCleanup sets up signal handlers to clean up processes on exit
+func setupProcessCleanup() {
+	c := make(chan os.Signal, 1)
+	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
+
+	go func() {
+		<-c
+		cleanupProcesses()
+		os.Exit(0)
+	}()
+}
+
+// cleanupProcesses terminates all Ollama processes started by Crush
+func cleanupProcesses() {
+	processManager.mu.Lock()
+	defer processManager.mu.Unlock()
+
+	// Clean up model processes
+	for modelName, cmd := range processManager.processes {
+		if cmd.Process != nil {
+			cmd.Process.Kill()
+			cmd.Wait() // Wait for the process to actually exit
+		}
+		delete(processManager.processes, modelName)
+	}
+
+	// Clean up Ollama server if Crush started it
+	if processManager.crushStartedOllama && processManager.ollamaServer != nil {
+		if processManager.ollamaServer.Process != nil {
+			// Kill the entire process group to ensure all children are terminated
+			syscall.Kill(-processManager.ollamaServer.Process.Pid, syscall.SIGTERM)
+
+			// Give it a moment to shut down gracefully
+			time.Sleep(2 * time.Second)
+
+			// Force kill the group in case it's still running. ProcessState is
+			// only set after Wait, so it can't be used as a liveness check
+			// here; SIGKILL on an already-exited group is a harmless no-op.
+			syscall.Kill(-processManager.ollamaServer.Process.Pid, syscall.SIGKILL)
+
+			processManager.ollamaServer.Wait() // Wait for the process to actually exit
+		}
+		processManager.ollamaServer = nil
+		processManager.crushStartedOllama = false
+	}
+}
diff --git a/internal/ollama/process_test.go b/internal/ollama/process_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..b3441ffdf879f0e9ae7f735d45d79854f91440d9
--- /dev/null
+++ b/internal/ollama/process_test.go
@@ -0,0 +1,110 @@
+package ollama
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestProcessManager(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping ProcessManager test")
+	}
+
+	// Test that processManager is initialized
+	if processManager == nil {
+		t.Fatal("processManager is nil")
+	}
+
+	if processManager.processes == nil {
+		t.Fatal("processManager.processes is nil")
+	}
+
+	t.Log("✓ ProcessManager is properly initialized")
+}
+
+func TestCleanupProcesses(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping cleanup test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	// Start Ollama service if not running
+	wasRunning := IsRunning(ctx)
+	if !wasRunning {
+		t.Log("Starting Ollama service for cleanup test...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+
+		// Verify it started
+		if !IsRunning(ctx) {
+			t.Fatal("Failed to start Ollama service")
+		}
+
+		// Test cleanup
+		t.Log("Testing cleanup...")
+		cleanupProcesses()
+
+		// Give some time for cleanup
+		time.Sleep(3 * time.Second)
+
+		// Verify cleanup worked (service should be stopped)
+		if IsRunning(ctx) {
+			t.Error("Ollama service is still running after cleanup")
+		} else {
+			t.Log("✓ Cleanup successfully stopped Ollama service")
+		}
+	} else {
+		t.Log("✓ Ollama was already running, skipping cleanup test to avoid disruption")
+	}
+}
+
+func TestSetupProcessCleanup(t *testing.T) {
+	// Test that setupProcessCleanup can be called without panicking
+	// Note: We can't easily test signal handling in unit tests
+	defer func() {
+		if r := recover(); r != nil {
+			t.Fatalf("setupProcessCleanup panicked: %v", r)
+		}
+	}()
+
+	// This should not panic and should be safe to call multiple times
+	setupProcessCleanup()
+	setupProcessCleanup() // Registers a second handler; harmless here, and real call sites guard with processManager.setupOnce
+
+	t.Log("✓ setupProcessCleanup completed without panic")
+}
+
+func TestProcessManagerThreadSafety(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping thread safety test")
+	}
+
+	// Test concurrent access to processManager
+	done := make(chan bool)
+
+	// Start multiple goroutines that access processManager
+	for i := 0; i < 10; i++ {
+		go func() {
+			processManager.mu.RLock()
+			_ = len(processManager.processes)
+			processManager.mu.RUnlock()
+			done <- true
+		}()
+	}
+
+	// Wait for all goroutines to complete
+	for i := 0; i < 10; i++ {
+		select {
+		case <-done:
+			// Success
+		case <-time.After(1 * time.Second):
+			t.Fatal("Thread safety test timed out")
+		}
+	}
+
+	t.Log("✓ ProcessManager thread safety test passed")
+}
diff --git a/internal/ollama/service.go b/internal/ollama/service.go
new file mode 100644
index 0000000000000000000000000000000000000000..a603d2e80974c58ec83fa6ace7c532710fb43cac
--- /dev/null
+++ b/internal/ollama/service.go
@@ -0,0 +1,117 @@
+package ollama
+
+import (
+	"context"
+	"fmt"
+	"os/exec"
+	"syscall"
+	"time"
+)
+
+// StartOllamaService starts the Ollama service if it's not already running
+func StartOllamaService(ctx context.Context) error {
+	if IsRunning(ctx) {
+		return nil // Already running
+	}
+
+	// Set up signal handling for cleanup
+	processManager.setupOnce.Do(func() {
+		setupProcessCleanup()
+	})
+
+	// Start ollama serve. Use exec.Command rather than CommandContext so the
+	// server isn't killed when the caller's (often short-lived) context is
+	// canceled; shutdown is handled by cleanupProcesses instead.
+	cmd := exec.Command("ollama", "serve")
+	cmd.Stdout = nil // Suppress output
+	cmd.Stderr = nil // Suppress errors
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		Setpgid: true, // Create new process group so we can kill it and all children
+	}
+
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("failed to start Ollama service: %w", err)
+	}
+
+	// Store the process for cleanup
+	processManager.mu.Lock()
+	processManager.ollamaServer = cmd
+	processManager.crushStartedOllama = true
+	processManager.mu.Unlock()
+
+	// Wait for Ollama to be ready (with timeout)
+	timeout := time.After(10 * time.Second)
+	ticker := time.NewTicker(500 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-timeout:
+			return fmt.Errorf("timeout waiting for Ollama service to start")
+		case <-ticker.C:
+			if IsRunning(ctx) {
+				return nil // Ollama is now running
+			}
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
+
+// StartModel starts a model using `ollama run` and keeps it loaded
+func StartModel(ctx context.Context, modelName string) error {
+	// Check if model is already running
+	if loaded, err := IsModelLoaded(ctx, modelName); err != nil {
+		return fmt.Errorf("failed to check if model is loaded: %w", err)
+	} else if loaded {
+		return nil // Model is already running
+	}
+
+	// Set up signal handling for cleanup
+	processManager.setupOnce.Do(func() {
+		setupProcessCleanup()
+	})
+
+	// Start the model in the background
+	cmd := exec.CommandContext(ctx, "ollama", "run", modelName)
+	cmd.Stdin = nil  // No interactive input
+	cmd.Stdout = nil // Suppress output
+	cmd.Stderr = nil // Suppress errors
+
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("failed to start model %s: %w", modelName, err)
+	}
+
+	// Store the process for cleanup
+	processManager.mu.Lock()
+	processManager.processes[modelName] = cmd
+	processManager.mu.Unlock()
+
+	// Wait for the model to be loaded (with timeout)
+	timeout := time.After(30 * time.Second)
+	ticker := time.NewTicker(1 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-timeout:
+			return fmt.Errorf("timeout waiting for model %s to load", modelName)
+		case <-ticker.C:
+			if loaded, err := IsModelLoaded(ctx, modelName); err != nil {
+				return fmt.Errorf("failed to check if model is loaded: %w", err)
+			} else if loaded {
+				return nil // Model is now running
+			}
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
+
+// EnsureOllamaRunning ensures Ollama service is running, starting it if necessary
+func EnsureOllamaRunning(ctx context.Context) error {
+	return StartOllamaService(ctx)
+}
+
+// EnsureModelRunning ensures a model is running, starting it if necessary
+func EnsureModelRunning(ctx context.Context, modelName string) error {
+	return StartModel(ctx, modelName)
+}
diff --git a/internal/ollama/service_test.go b/internal/ollama/service_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..5e458b10b65b44bbaa30238655350f91ec81a777
--- /dev/null
+++ b/internal/ollama/service_test.go
@@ -0,0 +1,193 @@
+package ollama
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestStartOllamaService(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping StartOllamaService test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	// First check if it's already running
+	if IsRunning(ctx) {
+		t.Log("✓ Ollama is already running, skipping start test")
+		return
+	}
+
+	t.Log("Starting Ollama service...")
+	err := StartOllamaService(ctx)
+	if err != nil {
+		t.Fatalf("Failed to start Ollama service: %v", err)
+	}
+
+	// Verify it's now running
+	if !IsRunning(ctx) {
+		t.Fatal("Ollama service was started but IsRunning still returns false")
+	}
+
+	t.Log("✓ Ollama service started successfully")
+
+	// Clean up - stop the service we started
+	cleanupProcesses()
+}
+
+func TestEnsureOllamaRunning(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping EnsureOllamaRunning test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	// Test that EnsureOllamaRunning works whether Ollama is running or not
+	err := EnsureOllamaRunning(ctx)
+	if err != nil {
+		t.Fatalf("EnsureOllamaRunning failed: %v", err)
+	}
+
+	// Verify Ollama is running
+	if !IsRunning(ctx) {
+		t.Fatal("EnsureOllamaRunning succeeded but Ollama is not running")
+	}
+
+	t.Log("✓ EnsureOllamaRunning succeeded")
+
+	// Test calling it again when already running
+	err = EnsureOllamaRunning(ctx)
+	if err != nil {
+		t.Fatalf("EnsureOllamaRunning failed on second call: %v", err)
+	}
+
+	t.Log("✓ EnsureOllamaRunning works when already running")
+}
+
+func TestStartModel(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping StartModel test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	// Ensure Ollama is running
+	if !IsRunning(ctx) {
+		t.Log("Starting Ollama service...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+		defer cleanupProcesses()
+	}
+
+	// Get available models
+	models, err := GetModels(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get models: %v", err)
+	}
+
+	if len(models) == 0 {
+		t.Skip("No models available, skipping StartModel test")
+	}
+
+	// Pick a smaller model if available, otherwise use the first one
+	testModel := models[0].ID
+	for _, model := range models {
+		if model.ID == "phi3:3.8b" || model.ID == "llama3.2:3b" {
+			testModel = model.ID
+			break
+		}
+	}
+
+	t.Logf("Testing with model: %s", testModel)
+
+	// Check if model is already loaded
+	loaded, err := IsModelLoaded(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is loaded: %v", err)
+	}
+
+	if loaded {
+		t.Log("✓ Model is already loaded, skipping start test")
+		return
+	}
+
+	t.Log("Starting model...")
+	err = StartModel(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to start model: %v", err)
+	}
+
+	// Verify model is now loaded
+	loaded, err = IsModelLoaded(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is loaded after start: %v", err)
+	}
+
+	if !loaded {
+		t.Fatal("StartModel succeeded but model is not loaded")
+	}
+
+	t.Log("✓ Model started successfully")
+}
+
+func TestEnsureModelRunning(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping EnsureModelRunning test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	// Ensure Ollama is running
+	if !IsRunning(ctx) {
+		t.Log("Starting Ollama service...")
+		if err := StartOllamaService(ctx); err != nil {
+			t.Fatalf("Failed to start Ollama service: %v", err)
+		}
+		defer cleanupProcesses()
+	}
+
+	// Get available models
+	models, err := GetModels(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get models: %v", err)
+	}
+
+	if len(models) == 0 {
+		t.Skip("No models available, skipping EnsureModelRunning test")
+	}
+
+	testModel := models[0].ID
+	t.Logf("Testing with model: %s", testModel)
+
+	// Test EnsureModelRunning
+	err = EnsureModelRunning(ctx, testModel)
+	if err != nil {
+		t.Fatalf("EnsureModelRunning failed: %v", err)
+	}
+
+	// Verify model is running
+	loaded, err := IsModelLoaded(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is loaded: %v", err)
+	}
+
+	if !loaded {
+		t.Fatal("EnsureModelRunning succeeded but model is not loaded")
+	}
+
+	t.Log("✓ EnsureModelRunning succeeded")
+
+	// Test calling it again when already running
+	err = EnsureModelRunning(ctx, testModel)
+	if err != nil {
+		t.Fatalf("EnsureModelRunning failed on second call: %v", err)
+	}
+
+	t.Log("✓ EnsureModelRunning works when model already running")
+}
diff --git a/internal/ollama/types.go b/internal/ollama/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..efd992dd43492dbe950ec0426ad2835925d67b68
--- /dev/null
+++ b/internal/ollama/types.go
@@ -0,0 +1,60 @@
+package ollama
+
+import (
+	"os/exec"
+	"sync"
+)
+
+// OllamaModel represents a model returned by Ollama's API
+type OllamaModel struct {
+	Name       string `json:"name"`
+	Model      string `json:"model"`
+	Size       int64  `json:"size"`
+	ModifiedAt string `json:"modified_at"`
+	Digest     string `json:"digest"`
+	Details    struct {
+		ParentModel       string   `json:"parent_model"`
+		Format            string   `json:"format"`
+		Family            string   `json:"family"`
+		Families          []string `json:"families"`
+		ParameterSize     string   `json:"parameter_size"`
+		QuantizationLevel string   `json:"quantization_level"`
+	} `json:"details"`
+}
+
+// OllamaTagsResponse represents the response from Ollama's /api/tags endpoint
+type OllamaTagsResponse struct {
+	Models []OllamaModel `json:"models"`
+}
+
+// OllamaRunningModel represents a model that is currently loaded in memory
+type OllamaRunningModel struct {
+	Name    string `json:"name"`
+	Model   string `json:"model"`
+	Size    int64  `json:"size"`
+	Digest  string `json:"digest"`
+	Details struct {
+		ParentModel       string   `json:"parent_model"`
+		Format            string   `json:"format"`
+		Family            string   `json:"family"`
+		Families          []string `json:"families"`
+		ParameterSize     string   `json:"parameter_size"`
+		QuantizationLevel string   `json:"quantization_level"`
+	} `json:"details"`
+	ExpiresAt string `json:"expires_at"`
+	SizeVRAM  int64  `json:"size_vram"`
+}
+
+// OllamaRunningModelsResponse represents the response from Ollama's /api/ps endpoint
+type OllamaRunningModelsResponse struct {
+	Models []OllamaRunningModel `json:"models"`
+}
+
+// ProcessManager manages Ollama processes started by Crush
+type ProcessManager struct {
+	mu                 sync.RWMutex
+	processes          map[string]*exec.Cmd
+	ollamaServer       *exec.Cmd // The main Ollama server process
+	setupOnce          sync.Once
+	crushStartedOllama bool // Track if Crush started the Ollama service
+}
diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go
index 8425b8f2c04569749a33867fb7e14e4b628d019e..b23c09cf61988f705b9572e6df82896fc99d0c8e 100644
--- a/internal/tui/components/dialogs/models/list.go
+++ b/internal/tui/components/dialogs/models/list.go
@@ -152,6 +152,30 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 		}
 	}
 
+	// Add Ollama provider if it's configured (auto-detected at config load)
+	if ollamaConfig, exists := cfg.Providers["ollama"]; exists && !ollamaConfig.Disable {
+		// Convert to provider.Provider for consistency
+		ollamaProvider := provider.Provider{
+			Name:   ollamaConfig.Name,
+			ID:     provider.InferenceProvider(ollamaConfig.ID),
+			Models: ollamaConfig.Models,
+		}
+
+		section := commands.NewItemSection("Ollama")
+		section.SetInfo("Ollama")
+		modelItems = append(modelItems, section)
+
+		for _, model := range ollamaProvider.Models {
+			modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
+				Provider: ollamaProvider,
+				Model:    model,
+			}))
+			if model.ID == currentModel.Model && currentModel.Provider == "ollama" {
+				selectIndex = len(modelItems) - 1
+			}
+		}
+	}
+
 	// Then add the known providers from the predefined list
 	for _, provider := range m.providers {
 		// Skip if we already added this provider as an unknown provider
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
index 0b10b74792c5cc6c91dc285d42a7d9a6736c2b90..8fe4d125557fccc8fab7436df1830c6f6498b7d5 100644
--- a/internal/tui/tui.go
+++ b/internal/tui/tui.go
@@ -9,6 +9,7 @@ import (
 	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/llm/agent"
+	"github.com/charmbracelet/crush/internal/ollama"
 	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/charmbracelet/crush/internal/pubsub"
 	cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat"
@@ -176,6 +177,13 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case models.ModelSelectedMsg:
 		config.Get().UpdatePreferredModel(msg.ModelType, msg.Model)
 
+		// If this is an Ollama model, ensure it's running (this blocks until the model is loaded)
+		if msg.Model.Provider == "ollama" {
+			if err := ollama.EnsureModelRunning(context.Background(), msg.Model.Model); err != nil {
+				return a, util.ReportError(fmt.Errorf("failed to start Ollama model %s: %v", msg.Model.Model, err))
+			}
+		}
+
 		// Update the agent with the new model/provider configuration
 		if err := a.app.UpdateAgentModel(); err != nil {
 			return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err))