.opencode.json 🔗
@@ -1,8 +1,4 @@
{
- "model": {
- "coder": "claude-3.7-sonnet",
- "coderMaxTokens": 20000
- },
"lsp": {
"gopls": {
"command": "gopls"
Kujtim Hoxha created
.opencode.json | 4
README.md | 34
cmd/root.go | 24
go.mod | 8
go.sum | 14
internal/app/app.go | 17
internal/app/lsp.go | 19
internal/config/config.go | 108 +
internal/db/files.sql.go | 4
internal/db/sql/files.sql | 4
internal/diff/diff.go | 99 ++
internal/llm/agent/agent-tool.go | 18
internal/llm/agent/agent.go | 861 +++++++--------------
internal/llm/agent/coder.go | 63 -
internal/llm/agent/mcp-tools.go | 4
internal/llm/agent/task.go | 47 -
internal/llm/agent/tools.go | 50 +
internal/llm/models/anthropic.go | 71 +
internal/llm/models/models.go | 190 ++--
internal/llm/prompt/coder.go | 28
internal/llm/prompt/prompt.go | 19
internal/llm/prompt/task.go | 5
internal/llm/prompt/title.go | 4
internal/llm/provider/anthropic.go | 531 +++++++------
internal/llm/provider/bedrock.go | 101 +-
internal/llm/provider/gemini.go | 533 ++++++++----
internal/llm/provider/openai.go | 401 ++++++----
internal/llm/provider/provider.go | 169 +++
internal/llm/tools/bash.go | 7
internal/llm/tools/bash_test.go | 31
internal/llm/tools/edit.go | 75 +
internal/llm/tools/edit_test.go | 30
internal/llm/tools/file.go | 10
internal/llm/tools/glob.go | 4
internal/llm/tools/grep.go | 4
internal/llm/tools/ls.go | 4
internal/llm/tools/mocks_test.go | 246 ++++++
internal/llm/tools/shell/shell.go | 12
internal/llm/tools/sourcegraph.go | 2
internal/llm/tools/tools.go | 9
internal/llm/tools/write.go | 27
internal/llm/tools/write_test.go | 22
internal/logging/logger.go | 41 +
internal/lsp/client.go | 13
internal/lsp/handlers.go | 2
internal/lsp/transport.go | 28
internal/lsp/watcher/watcher.go | 18
internal/message/content.go | 30
internal/pubsub/broker.go | 2
internal/session/session.go | 15
internal/tui/components/chat/chat.go | 2
internal/tui/components/chat/editor.go | 22
internal/tui/components/chat/messages.go | 205 ++++
internal/tui/components/chat/sidebar.go | 176 ++++
internal/tui/components/core/dialog.go | 117 --
internal/tui/components/core/help.go | 119 ---
internal/tui/components/core/status.go | 90 +
internal/tui/components/dialog/help.go | 182 ++++
internal/tui/components/dialog/permission.go | 682 ++++++++--------
internal/tui/components/dialog/quit.go | 156 ++-
internal/tui/components/logs/details.go | 2
internal/tui/components/logs/table.go | 22
internal/tui/components/repl/editor.go | 201 -----
internal/tui/components/repl/messages.go | 513 -------------
internal/tui/components/repl/sessions.go | 249 ------
internal/tui/layout/overlay.go | 11
internal/tui/layout/split.go | 1
internal/tui/page/chat.go | 32
internal/tui/page/init.go | 308 -------
internal/tui/page/logs.go | 17
internal/tui/page/repl.go | 21
internal/tui/tui.go | 277 +++---
main.go | 7
73 files changed, 3,595 insertions(+), 3,879 deletions(-)
@@ -1,8 +1,4 @@
{
- "model": {
- "coder": "claude-3.7-sonnet",
- "coderMaxTokens": 20000
- },
"lsp": {
"gopls": {
"command": "gopls"
@@ -1,14 +1,14 @@
-# TermAI
+# OpenCode
> **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk.
A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal.
-[![TermAI demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy)
+[![OpenCode demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy)
## Overview
-TermAI is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more.
+OpenCode is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more.
## Features
@@ -23,16 +23,16 @@ TermAI is a Go-based CLI application that brings AI assistance to your terminal.
```bash
# Coming soon
-go install github.com/kujtimiihoxha/termai@latest
+go install github.com/kujtimiihoxha/opencode@latest
```
## Configuration
-TermAI looks for configuration in the following locations:
+OpenCode looks for configuration in the following locations:
-- `$HOME/.termai.json`
-- `$XDG_CONFIG_HOME/termai/.termai.json`
-- `./.termai.json` (local directory)
+- `$HOME/.opencode.json`
+- `$XDG_CONFIG_HOME/opencode/.opencode.json`
+- `./.opencode.json` (local directory)
You can also use environment variables:
@@ -43,11 +43,11 @@ You can also use environment variables:
## Usage
```bash
-# Start TermAI
-termai
+# Start OpenCode
+opencode
# Start with debug logging
-termai -d
+opencode -d
```
### Keyboard Shortcuts
@@ -81,7 +81,7 @@ termai -d
## Architecture
-TermAI is built with a modular architecture:
+OpenCode is built with a modular architecture:
- **cmd**: Command-line interface using Cobra
- **internal/app**: Core application services
@@ -103,22 +103,22 @@ TermAI is built with a modular architecture:
```bash
# Clone the repository
-git clone https://github.com/kujtimiihoxha/termai.git
-cd termai
+git clone https://github.com/kujtimiihoxha/opencode.git
+cd opencode
# Build the diff script first
go run cmd/diff/main.go
# Build
-go build -o termai
+go build -o opencode
# Run
-./termai
+./opencode
```
## Acknowledgments
-TermAI builds upon the work of several open source projects and developers:
+OpenCode builds upon the work of several open source projects and developers:
- [@isaacphi](https://github.com/isaacphi) - LSP client implementation
@@ -20,7 +20,7 @@ import (
)
var rootCmd = &cobra.Command{
- Use: "termai",
+ Use: "opencode",
Short: "A terminal ai assistant",
Long: `A terminal ai assistant`,
RunE: func(cmd *cobra.Command, args []string) error {
@@ -89,12 +89,9 @@ var rootCmd = &cobra.Command{
// Set up message handling for the TUI
go func() {
defer tuiWg.Done()
- defer func() {
- if r := recover(); r != nil {
- logging.Error("Panic in TUI message handling: %v", r)
- attemptTUIRecovery(program)
- }
- }()
+ defer logging.RecoverPanic("TUI-message-handler", func() {
+ attemptTUIRecovery(program)
+ })
for {
select {
@@ -153,11 +150,7 @@ func attemptTUIRecovery(program *tea.Program) {
func initMCPTools(ctx context.Context, app *app.App) {
go func() {
- defer func() {
- if r := recover(); r != nil {
- logging.Error("Panic in MCP goroutine: %v", r)
- }
- }()
+ defer logging.RecoverPanic("MCP-goroutine", nil)
// Create a context with timeout for the initial MCP tools fetch
ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -179,11 +172,7 @@ func setupSubscriber[T any](
wg.Add(1)
go func() {
defer wg.Done()
- defer func() {
- if r := recover(); r != nil {
- logging.Error("Panic in %s subscription goroutine: %v", name, r)
- }
- }()
+ defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
for {
select {
@@ -232,6 +221,7 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
// Wait with a timeout for all goroutines to complete
waitCh := make(chan struct{})
go func() {
+ defer logging.RecoverPanic("subscription-cleanup", nil)
wg.Wait()
close(waitCh)
}()
@@ -23,7 +23,6 @@ require (
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/generative-ai-go v0.19.0
github.com/google/uuid v1.6.0
- github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231
github.com/mark3labs/mcp-go v0.17.0
github.com/mattn/go-runewidth v0.0.16
@@ -36,7 +35,6 @@ require (
github.com/spf13/cobra v1.9.1
github.com/spf13/viper v1.20.0
github.com/stretchr/testify v1.10.0
- golang.org/x/net v0.39.0
google.golang.org/api v0.215.0
)
@@ -106,7 +104,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
- github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/skeema/knownhosts v1.3.1 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect
@@ -129,11 +126,8 @@ require (
go.opentelemetry.io/otel/trace v1.29.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
- golang.design/x/clipboard v0.7.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
- golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect
- golang.org/x/image v0.14.0 // indirect
- golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect
+ golang.org/x/net v0.39.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
@@ -180,10 +180,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 h1:xYfCLI8KUwmXDFp1pOpNX+XsQczQw9VbEuju1pQF5/A=
-github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9/go.mod h1:Ye+kIkTmPO5xuqCQ+PPHDTGIViRRoSpSIlcYgma8YlA=
-github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
-github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 h1:9rjt7AfnrXKNSZhp36A3/4QAZAwGGCGD/p8Bse26zms=
@@ -235,8 +231,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
-github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
-github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y=
github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
@@ -302,8 +296,6 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
-golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo=
-golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
@@ -314,12 +306,6 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
-golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 h1:bFYqOIMdeiCEdzPJkLiOoMDzW/v3tjW4AA/RmUZYsL8=
-golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8=
-golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
-golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
-golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg=
-golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -7,6 +7,7 @@ import (
"sync"
"time"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/history"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
@@ -20,7 +21,7 @@ import (
type App struct {
Sessions session.Service
Messages message.Service
- Files history.Service
+ History history.Service
Permissions permission.Service
CoderAgent agent.Service
@@ -43,7 +44,7 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
app := &App{
Sessions: sessions,
Messages: messages,
- Files: files,
+ History: files,
Permissions: permission.NewPermissionService(),
LSPClients: make(map[string]*lsp.Client),
}
@@ -51,11 +52,17 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
app.initLSPClients(ctx)
var err error
- app.CoderAgent, err = agent.NewCoderAgent(
- app.Permissions,
+ app.CoderAgent, err = agent.NewAgent(
+ config.AgentCoder,
app.Sessions,
app.Messages,
- app.LSPClients,
+ agent.CoderAgentTools(
+ app.Permissions,
+ app.Sessions,
+ app.Messages,
+ app.History,
+ app.LSPClients,
+ ),
)
if err != nil {
logging.Error("Failed to create coder agent", err)
@@ -22,16 +22,17 @@ func (app *App) initLSPClients(ctx context.Context) {
// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher
func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) {
// Create a specific context for initialization with a timeout
- initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer initCancel()
// Create the LSP client
- lspClient, err := lsp.NewClient(initCtx, command, args...)
+ lspClient, err := lsp.NewClient(ctx, command, args...)
if err != nil {
logging.Error("Failed to create LSP client for", name, err)
return
+
}
+ initCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
+ defer cancel()
// Initialize with the initialization context
_, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
if err != nil {
@@ -64,14 +65,10 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
// runWorkspaceWatcher executes the workspace watcher for an LSP client
func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) {
defer app.watcherWG.Done()
- defer func() {
- if r := recover(); r != nil {
- logging.Error("LSP client crashed", "client", name, "panic", r)
-
- // Try to restart the client
- app.restartLSPClient(ctx, name)
- }
- }()
+ defer logging.RecoverPanic("LSP-"+name, func() {
+ // Try to restart the client
+ app.restartLSPClient(ctx, name)
+ })
workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
logging.Info("Workspace watcher stopped", "client", name)
@@ -31,12 +31,18 @@ type MCPServer struct {
Headers map[string]string `json:"headers"`
}
-// Model defines configuration for different LLM models and their token limits.
-type Model struct {
- Coder models.ModelID `json:"coder"`
- CoderMaxTokens int64 `json:"coderMaxTokens"`
- Task models.ModelID `json:"task"`
- TaskMaxTokens int64 `json:"taskMaxTokens"`
+type AgentName string
+
+const (
+ AgentCoder AgentName = "coder"
+ AgentTask AgentName = "task"
+ AgentTitle AgentName = "title"
+)
+
+// Agent defines configuration for different LLM models and their token limits.
+type Agent struct {
+ Model models.ModelID `json:"model"`
+ MaxTokens int64 `json:"maxTokens"`
}
// Provider defines configuration for an LLM provider.
@@ -65,8 +71,9 @@ type Config struct {
MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
LSP map[string]LSPConfig `json:"lsp,omitempty"`
- Model Model `json:"model"`
+ Agents map[AgentName]Agent `json:"agents"`
Debug bool `json:"debug,omitempty"`
+ DebugLSP bool `json:"debugLSP,omitempty"`
}
// Application constants
@@ -118,11 +125,42 @@ func Load(workingDir string, debug bool) (*Config, error) {
if cfg.Debug {
defaultLevel = slog.LevelDebug
}
- // Configure logger
- logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
- Level: defaultLevel,
- }))
- slog.SetDefault(logger)
+ // if we are in debug mode make the writer a file
+ if cfg.Debug {
+ loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log")
+
+ // if file does not exist create it
+ if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
+ if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil {
+ return cfg, fmt.Errorf("failed to create directory: %w", err)
+ }
+ if _, err := os.Create(loggingFile); err != nil {
+ return cfg, fmt.Errorf("failed to create log file: %w", err)
+ }
+ }
+
+ sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
+ if err != nil {
+ return cfg, fmt.Errorf("failed to open log file: %w", err)
+ }
+ // Configure logger
+ logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
+ Level: defaultLevel,
+ }))
+ slog.SetDefault(logger)
+ } else {
+ // Configure logger
+ logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
+ Level: defaultLevel,
+ }))
+ slog.SetDefault(logger)
+ }
+
+ // Override the max tokens for title agent
+ cfg.Agents[AgentTitle] = Agent{
+ Model: cfg.Agents[AgentTitle].Model,
+ MaxTokens: 80,
+ }
return cfg, nil
}
@@ -159,44 +197,50 @@ func setProviderDefaults() {
// Groq configuration
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
viper.SetDefault("providers.groq.apiKey", apiKey)
- viper.SetDefault("model.coder", models.QWENQwq)
- viper.SetDefault("model.coderMaxTokens", defaultMaxTokens)
- viper.SetDefault("model.task", models.QWENQwq)
- viper.SetDefault("model.taskMaxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.coder.model", models.QWENQwq)
+ viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.task.model", models.QWENQwq)
+ viper.SetDefault("agents.task.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.title.model", models.QWENQwq)
}
// Google Gemini configuration
if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
viper.SetDefault("providers.gemini.apiKey", apiKey)
- viper.SetDefault("model.coder", models.GRMINI20Flash)
- viper.SetDefault("model.coderMaxTokens", defaultMaxTokens)
- viper.SetDefault("model.task", models.GRMINI20Flash)
- viper.SetDefault("model.taskMaxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.coder.model", models.GRMINI20Flash)
+ viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.task.model", models.GRMINI20Flash)
+ viper.SetDefault("agents.task.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.title.model", models.GRMINI20Flash)
}
// OpenAI configuration
if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
viper.SetDefault("providers.openai.apiKey", apiKey)
- viper.SetDefault("model.coder", models.GPT4o)
- viper.SetDefault("model.coderMaxTokens", defaultMaxTokens)
- viper.SetDefault("model.task", models.GPT4o)
- viper.SetDefault("model.taskMaxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.coder.model", models.GPT4o)
+ viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.task.model", models.GPT4o)
+ viper.SetDefault("agents.task.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.title.model", models.GPT4o)
+
}
// Anthropic configuration
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
viper.SetDefault("providers.anthropic.apiKey", apiKey)
- viper.SetDefault("model.coder", models.Claude37Sonnet)
- viper.SetDefault("model.coderMaxTokens", defaultMaxTokens)
- viper.SetDefault("model.task", models.Claude37Sonnet)
- viper.SetDefault("model.taskMaxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.coder.model", models.Claude37Sonnet)
+ viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.task.model", models.Claude37Sonnet)
+ viper.SetDefault("agents.task.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.title.model", models.Claude37Sonnet)
}
if hasAWSCredentials() {
- viper.SetDefault("model.coder", models.BedrockClaude37Sonnet)
- viper.SetDefault("model.coderMaxTokens", defaultMaxTokens)
- viper.SetDefault("model.task", models.BedrockClaude37Sonnet)
- viper.SetDefault("model.taskMaxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
+ viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet)
+ viper.SetDefault("agents.task.maxTokens", defaultMaxTokens)
+ viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet)
}
}
@@ -97,7 +97,9 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) {
const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one
SELECT id, session_id, path, content, version, created_at, updated_at
FROM files
-WHERE path = ? AND session_id = ? LIMIT 1
+WHERE path = ? AND session_id = ?
+ORDER BY created_at DESC
+LIMIT 1
`
type GetFileByPathAndSessionParams struct {
@@ -6,7 +6,9 @@ WHERE id = ? LIMIT 1;
-- name: GetFileByPathAndSession :one
SELECT *
FROM files
-WHERE path = ? AND session_id = ? LIMIT 1;
+WHERE path = ? AND session_id = ?
+ORDER BY created_at DESC
+LIMIT 1;
-- name: ListFilesBySession :many
SELECT *
@@ -19,6 +19,8 @@ import (
"github.com/charmbracelet/x/ansi"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/object"
+ "github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/sergi/go-diff/diffmatchpatch"
)
@@ -77,6 +79,8 @@ type linePair struct {
// StyleConfig defines styling for diff rendering
type StyleConfig struct {
+ ShowHeader bool
+ FileNameFg lipgloss.Color
// Background colors
RemovedLineBg lipgloss.Color
AddedLineBg lipgloss.Color
@@ -106,11 +110,13 @@ type StyleOption func(*StyleConfig)
func NewStyleConfig(opts ...StyleOption) StyleConfig {
// Default color scheme
config := StyleConfig{
+ ShowHeader: true,
+ FileNameFg: lipgloss.Color("#fab283"),
RemovedLineBg: lipgloss.Color("#3A3030"),
AddedLineBg: lipgloss.Color("#303A30"),
ContextLineBg: lipgloss.Color("#212121"),
- HunkLineBg: lipgloss.Color("#23252D"),
- HunkLineFg: lipgloss.Color("#8CA3B4"),
+ HunkLineBg: lipgloss.Color("#212121"),
+ HunkLineFg: lipgloss.Color("#a0a0a0"),
RemovedFg: lipgloss.Color("#7C4444"),
AddedFg: lipgloss.Color("#478247"),
LineNumberFg: lipgloss.Color("#888888"),
@@ -132,6 +138,10 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig {
}
// Style option functions
+func WithFileNameFg(color lipgloss.Color) StyleOption {
+ return func(s *StyleConfig) { s.FileNameFg = color }
+}
+
func WithRemovedLineBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.RemovedLineBg = color }
}
@@ -190,6 +200,10 @@ func WithHunkLineFg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.HunkLineFg = color }
}
+func WithShowHeader(show bool) StyleOption {
+ return func(s *StyleConfig) { s.ShowHeader = show }
+}
+
// -------------------------------------------------------------------------
// Parse Configuration
// -------------------------------------------------------------------------
@@ -841,10 +855,12 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
// Calculate column width
colWidth := config.TotalWidth / 2
+ leftWidth := colWidth
+ rightWidth := config.TotalWidth - colWidth
var sb strings.Builder
for _, p := range pairs {
- leftStr := renderLeftColumn(fileName, p.left, colWidth, config.Style)
- rightStr := renderRightColumn(fileName, p.right, colWidth, config.Style)
+ leftStr := renderLeftColumn(fileName, p.left, leftWidth, config.Style)
+ rightStr := renderRightColumn(fileName, p.right, rightWidth, config.Style)
sb.WriteString(leftStr + rightStr + "\n")
}
@@ -861,17 +877,50 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) {
var sb strings.Builder
config := NewSideBySideConfig(opts...)
- for i, h := range diffResult.Hunks {
- if i > 0 {
- // Render hunk header
- sb.WriteString(
- lipgloss.NewStyle().
- Background(config.Style.HunkLineBg).
- Foreground(config.Style.HunkLineFg).
- Width(config.TotalWidth).
- Render(h.Header) + "\n",
- )
- }
+ if config.Style.ShowHeader {
+ removeIcon := lipgloss.NewStyle().
+ Background(config.Style.RemovedLineBg).
+ Foreground(config.Style.RemovedFg).
+ Render("⏹")
+ addIcon := lipgloss.NewStyle().
+ Background(config.Style.AddedLineBg).
+ Foreground(config.Style.AddedFg).
+ Render("⏹")
+
+ fileName := lipgloss.NewStyle().
+ Background(config.Style.ContextLineBg).
+ Foreground(config.Style.FileNameFg).
+ Render(" " + diffResult.OldFile)
+ sb.WriteString(
+ lipgloss.NewStyle().
+ Background(config.Style.ContextLineBg).
+ Padding(0, 1, 0, 1).
+ Foreground(config.Style.FileNameFg).
+ BorderStyle(lipgloss.NormalBorder()).
+ BorderTop(true).
+ BorderBottom(true).
+ BorderForeground(config.Style.FileNameFg).
+ BorderBackground(config.Style.ContextLineBg).
+ Width(config.TotalWidth).
+ Render(
+ lipgloss.JoinHorizontal(lipgloss.Top,
+ removeIcon,
+ addIcon,
+ fileName,
+ ),
+ ) + "\n",
+ )
+ }
+
+ for _, h := range diffResult.Hunks {
+ // Render hunk header
+ sb.WriteString(
+ lipgloss.NewStyle().
+ Background(config.Style.HunkLineBg).
+ Foreground(config.Style.HunkLineFg).
+ Width(config.TotalWidth).
+ Render(h.Header) + "\n",
+ )
sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...))
}
@@ -880,9 +929,15 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) {
// GenerateDiff creates a unified diff from two file contents
func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) {
+ // remove the cwd prefix and ensure consistent path format
+ // this prevents issues with absolute paths in different environments
+ cwd := config.WorkingDirectory()
+ fileName = strings.TrimPrefix(fileName, cwd)
+ fileName = strings.TrimPrefix(fileName, "/")
// Create temporary directory for git operations
- tempDir, err := os.MkdirTemp("", "git-diff-temp")
+ tempDir, err := os.MkdirTemp("", fmt.Sprintf("git-diff-%d", time.Now().UnixNano()))
if err != nil {
+ logging.Error("Failed to create temp directory for git diff", "error", err)
return "", 0, 0
}
defer os.RemoveAll(tempDir)
@@ -890,25 +945,30 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in
// Initialize git repo
repo, err := git.PlainInit(tempDir, false)
if err != nil {
+ logging.Error("Failed to initialize git repository", "error", err)
return "", 0, 0
}
wt, err := repo.Worktree()
if err != nil {
+ logging.Error("Failed to get git worktree", "error", err)
return "", 0, 0
}
// Write the "before" content and commit it
fullPath := filepath.Join(tempDir, fileName)
if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil {
+ logging.Error("Failed to create directory for file", "error", err)
return "", 0, 0
}
if err = os.WriteFile(fullPath, []byte(beforeContent), 0o644); err != nil {
+ logging.Error("Failed to write before content to file", "error", err)
return "", 0, 0
}
_, err = wt.Add(fileName)
if err != nil {
+ logging.Error("Failed to add file to git", "error", err)
return "", 0, 0
}
@@ -920,16 +980,19 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in
},
})
if err != nil {
+ logging.Error("Failed to commit before content", "error", err)
return "", 0, 0
}
// Write the "after" content and commit it
if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil {
+ logging.Error("Failed to write after content to file", "error", err)
return "", 0, 0
}
_, err = wt.Add(fileName)
if err != nil {
+ logging.Error("Failed to add file to git", "error", err)
return "", 0, 0
}
@@ -941,22 +1004,26 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in
},
})
if err != nil {
+ logging.Error("Failed to commit after content", "error", err)
return "", 0, 0
}
// Get the diff between the two commits
beforeCommitObj, err := repo.CommitObject(beforeCommit)
if err != nil {
+ logging.Error("Failed to get before commit object", "error", err)
return "", 0, 0
}
afterCommitObj, err := repo.CommitObject(afterCommit)
if err != nil {
+ logging.Error("Failed to get after commit object", "error", err)
return "", 0, 0
}
patch, err := beforeCommitObj.Patch(afterCommitObj)
if err != nil {
+ logging.Error("Failed to create git diff patch", "error", err)
return "", 0, 0
}
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/message"
@@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
}
- agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients)
+ agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
}
@@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
}
- err = agent.Generate(ctx, session.ID, params.Prompt)
+ done, err := agent.Run(ctx, session.ID, params.Prompt)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
}
-
- messages, err := b.messages.List(ctx, session.ID)
- if err != nil {
- return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err)
- }
-
- if len(messages) == 0 {
- return tools.NewTextErrorResponse("no response"), nil
+ result := <-done
+ if result.Err() != nil {
+ return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err())
}
- response := messages[len(messages)-1]
+ response := result.Response()
if response.Role != message.Assistant {
return tools.NewTextErrorResponse("no response"), nil
}
@@ -4,8 +4,6 @@ import (
"context"
"errors"
"fmt"
- "os"
- "runtime/debug"
"strings"
"sync"
@@ -16,133 +14,101 @@ import (
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/session"
)
// Common errors
var (
- ErrProviderNotEnabled = errors.New("provider is not enabled")
- ErrRequestCancelled = errors.New("request cancelled by user")
- ErrSessionBusy = errors.New("session is currently processing another request")
+ ErrRequestCancelled = errors.New("request cancelled by user")
+ ErrSessionBusy = errors.New("session is currently processing another request")
)
-// Service defines the interface for generating responses
+type AgentEvent struct {
+ message message.Message
+ err error
+}
+
+func (e *AgentEvent) Err() error {
+ return e.err
+}
+
+func (e *AgentEvent) Response() message.Message {
+ return e.message
+}
+
type Service interface {
- Generate(ctx context.Context, sessionID string, content string) error
- Cancel(sessionID string) error
+ Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
+ Cancel(sessionID string)
+ IsSessionBusy(sessionID string) bool
}
type agent struct {
- sessions session.Service
- messages message.Service
- model models.Model
- tools []tools.BaseTool
- agent provider.Provider
- titleGenerator provider.Provider
- activeRequests sync.Map // map[sessionID]context.CancelFunc
+ sessions session.Service
+ messages message.Service
+
+ tools []tools.BaseTool
+ provider provider.Provider
+
+ titleProvider provider.Provider
+
+ activeRequests sync.Map
}
-// NewAgent creates a new agent instance with the given model and tools
-func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
- agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
+func NewAgent(
+ agentName config.AgentName,
+ sessions session.Service,
+ messages message.Service,
+ agentTools []tools.BaseTool,
+) (Service, error) {
+ agentProvider, err := createAgentProvider(agentName)
if err != nil {
- return nil, fmt.Errorf("failed to initialize providers: %w", err)
+ return nil, err
+ }
+ var titleProvider provider.Provider
+ // Only generate titles for the coder agent
+ if agentName == config.AgentCoder {
+ titleProvider, err = createAgentProvider(config.AgentTitle)
+ if err != nil {
+ return nil, err
+ }
}
- return &agent{
- model: model,
- tools: tools,
- sessions: sessions,
+ agent := &agent{
+ provider: agentProvider,
messages: messages,
- agent: agentProvider,
- titleGenerator: titleGenerator,
+ sessions: sessions,
+ tools: agentTools,
+ titleProvider: titleProvider,
activeRequests: sync.Map{},
- }, nil
+ }
+
+ return agent, nil
}
-// Cancel cancels an active request by session ID
-func (a *agent) Cancel(sessionID string) error {
+func (a *agent) Cancel(sessionID string) {
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
cancel()
- return nil
}
}
- return errors.New("no active request found for this session")
}
-// Generate starts the generation process
-func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
- // Check if this session already has an active request
- if _, busy := a.activeRequests.Load(sessionID); busy {
- return ErrSessionBusy
- }
-
- // Create a cancellable context
- genCtx, cancel := context.WithCancel(ctx)
-
- // Store cancel function to allow user cancellation
- a.activeRequests.Store(sessionID, cancel)
-
- // Launch the generation in a goroutine
- go func() {
- defer func() {
- if r := recover(); r != nil {
- logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
-
- // dump stack trace into a file
- file, err := os.Create("panic.log")
- if err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
- return
- }
-
- defer file.Close()
-
- stackTrace := debug.Stack()
- if _, err := file.Write(stackTrace); err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err))
- }
-
- }
- }()
- defer a.activeRequests.Delete(sessionID)
- defer cancel()
-
- if err := a.generate(genCtx, sessionID, content); err != nil {
- if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
- // Log the error (avoid logging cancellations as they're expected)
- logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
-
- // You may want to create an error message in the chat
- bgCtx := context.Background()
- errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
- _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
- Role: message.System,
- Parts: []message.ContentPart{
- message.TextContent{
- Text: errorMsg,
- },
- },
- })
- if createErr != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
- }
- }
- }
- }()
-
- return nil
-}
-
-// IsSessionBusy checks if a session currently has an active request
func (a *agent) IsSessionBusy(sessionID string) bool {
_, busy := a.activeRequests.Load(sessionID)
return busy
-} // handleTitleGeneration asynchronously generates a title for new sessions
-func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
- response, err := a.titleGenerator.SendMessages(
+}
+
+func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
+ if a.titleProvider == nil {
+ return nil
+ }
+ session, err := a.sessions.Get(ctx, sessionID)
+ if err != nil {
+ return err
+ }
+ response, err := a.titleProvider.SendMessages(
ctx,
[]message.Message{
{
@@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
},
},
},
- nil,
+ make([]tools.BaseTool, 0),
)
if err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
- return
+ return err
}
- session, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
- return
+ title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
+ if title == "" {
+ return nil
}
- if response.Content != "" {
- session.Title = strings.TrimSpace(response.Content)
- session.Title = strings.ReplaceAll(session.Title, "\n", " ")
- if _, err := a.sessions.Save(ctx, session); err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
- }
+ session.Title = title
+ _, err = a.sessions.Save(ctx, session)
+ return err
+}
+
+func (a *agent) err(err error) AgentEvent {
+ return AgentEvent{
+ err: err,
}
}
-// TrackUsage updates token usage statistics for the session
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
- session, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- return fmt.Errorf("failed to get session: %w", err)
+func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
+ events := make(chan AgentEvent)
+ if a.IsSessionBusy(sessionID) {
+ return nil, ErrSessionBusy
}
- cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
- model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
- model.CostPer1MIn/1e6*float64(usage.InputTokens) +
- model.CostPer1MOut/1e6*float64(usage.OutputTokens)
+ genCtx, cancel := context.WithCancel(ctx)
+
+ a.activeRequests.Store(sessionID, cancel)
+ go func() {
+ logging.Debug("Request started", "sessionID", sessionID)
+ defer logging.RecoverPanic("agent.Run", func() {
+ events <- a.err(fmt.Errorf("panic while running the agent"))
+ })
- session.Cost += cost
- session.CompletionTokens += usage.OutputTokens
- session.PromptTokens += usage.InputTokens
+ result := a.processGeneration(genCtx, sessionID, content)
+		result := a.processGeneration(genCtx, sessionID, content)
+		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
+			logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result.Err()))
+ }
+ logging.Debug("Request completed", "sessionID", sessionID)
+ a.activeRequests.Delete(sessionID)
+ cancel()
+ events <- result
+ close(events)
+ }()
+ return events, nil
+}
- _, err = a.sessions.Save(ctx, session)
+func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
+ // List existing messages; if none, start title generation asynchronously.
+ msgs, err := a.messages.List(ctx, sessionID)
if err != nil {
- return fmt.Errorf("failed to save session: %w", err)
+ return a.err(fmt.Errorf("failed to list messages: %w", err))
+ }
+ if len(msgs) == 0 {
+ go func() {
+			defer logging.RecoverPanic("agent.generateTitle", func() {
+				logging.ErrorPersist("panic while generating title")
+ })
+ titleErr := a.generateTitle(context.Background(), sessionID, content)
+ if titleErr != nil {
+ logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
+ }
+ }()
}
- return nil
-}
-// processEvent handles different types of events during generation
-func (a *agent) processEvent(
- ctx context.Context,
- sessionID string,
- assistantMsg *message.Message,
- event provider.ProviderEvent,
-) error {
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- // Continue processing
+ userMsg, err := a.createUserMessage(ctx, sessionID, content)
+ if err != nil {
+ return a.err(fmt.Errorf("failed to create user message: %w", err))
}
- switch event.Type {
- case provider.EventThinkingDelta:
- assistantMsg.AppendReasoningContent(event.Content)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventContentDelta:
- assistantMsg.AppendContent(event.Content)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventError:
- if errors.Is(event.Error, context.Canceled) {
- logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
- return context.Canceled
+ // Append the new user message to the conversation history.
+ msgHistory := append(msgs, userMsg)
+ for {
+ // Check for cancellation before each iteration
+ select {
+ case <-ctx.Done():
+ return a.err(ctx.Err())
+ default:
+ // Continue processing
}
- logging.ErrorPersist(event.Error.Error())
- return event.Error
- case provider.EventWarning:
- logging.WarnPersist(event.Info)
- case provider.EventInfo:
- logging.InfoPersist(event.Info)
- case provider.EventComplete:
- assistantMsg.SetToolCalls(event.Response.ToolCalls)
- assistantMsg.AddFinish(event.Response.FinishReason)
- if err := a.messages.Update(ctx, *assistantMsg); err != nil {
- return fmt.Errorf("failed to update message: %w", err)
+ agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return a.err(ErrRequestCancelled)
+ }
+ return a.err(fmt.Errorf("failed to process events: %w", err))
+ }
+ logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
+ if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
+ // We are not done, we need to respond with the tool response
+ msgHistory = append(msgHistory, agentMessage, *toolResults)
+ continue
+ }
+ return AgentEvent{
+ message: agentMessage,
}
- return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
}
+}
- return nil
+func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
+ return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
+ Role: message.User,
+ Parts: []message.ContentPart{
+ message.TextContent{Text: content},
+ },
+ })
}
-// ExecuteTools runs all tool calls sequentially and returns the results
-func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
- toolResults := make([]message.ToolResult, len(toolCalls))
+func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
+
+ assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: a.provider.Model().ID,
+ })
+ if err != nil {
+ return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
+ }
- // Create a child context that can be canceled
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
+ // Add the session and message ID into the context if needed by tools.
+ ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
+ ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- // Check if already canceled before starting any execution
- if ctx.Err() != nil {
- // Mark all tools as canceled
- for i, toolCall := range toolCalls {
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: "Tool execution canceled by user",
- IsError: true,
- }
+ // Process each event in the stream.
+ for event := range eventChan {
+ if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
+ return assistantMsg, nil, processErr
+ }
+ if ctx.Err() != nil {
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
+ return assistantMsg, nil, ctx.Err()
}
- return toolResults, ctx.Err()
}
+ toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
+ toolCalls := assistantMsg.ToolCalls()
for i, toolCall := range toolCalls {
- // Check for cancellation before executing each tool
select {
case <-ctx.Done():
- // Mark this and all remaining tools as canceled
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
+			// Mark this and all remaining tool calls as cancelled
for j := i; j < len(toolCalls); j++ {
toolResults[j] = message.ToolResult{
ToolCallID: toolCalls[j].ID,
@@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
IsError: true,
}
}
- return toolResults, ctx.Err()
+ goto out
default:
// Continue processing
- }
-
- response := ""
- isError := false
- found := false
-
- // Find and execute the appropriate tool
- for _, tool := range tls {
- if tool.Info().Name == toolCall.Name {
- found = true
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Name,
- Input: toolCall.Input,
- })
-
- if toolErr != nil {
- if errors.Is(toolErr, context.Canceled) {
- response = "Tool execution canceled by user"
- } else {
- response = fmt.Sprintf("Error running tool: %s", toolErr)
- }
- isError = true
- } else {
- response = toolResult.Content
- isError = toolResult.IsError
+ var tool tools.BaseTool
+ for _, availableTools := range a.tools {
+ if availableTools.Info().Name == toolCall.Name {
+ tool = availableTools
}
- break
}
- }
-
- if !found {
- response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
- isError = true
- }
-
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: response,
- IsError: isError,
- }
- }
- return toolResults, nil
-}
-
-// handleToolExecution processes tool calls and creates tool result messages
-func (a *agent) handleToolExecution(
- ctx context.Context,
- assistantMsg message.Message,
-) (*message.Message, error) {
- select {
- case <-ctx.Done():
- // If cancelled, create tool results that indicate cancellation
- if len(assistantMsg.ToolCalls()) > 0 {
- toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
- for _, tc := range assistantMsg.ToolCalls() {
- toolResults = append(toolResults, message.ToolResult{
- ToolCallID: tc.ID,
- Content: "Tool execution canceled by user",
+ // Tool not found
+ if tool == nil {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
IsError: true,
- })
+ }
+ continue
}
- // Use background context to ensure the message is created even if original context is cancelled
- bgCtx := context.Background()
- parts := make([]message.ContentPart, 0)
- for _, toolResult := range toolResults {
- parts = append(parts, toolResult)
- }
- msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
+ toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
+ ID: toolCall.ID,
+ Name: toolCall.Name,
+ Input: toolCall.Input,
})
- if err != nil {
- return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
- }
- return &msg, ctx.Err()
- }
- return nil, ctx.Err()
- default:
- // Continue processing
- }
-
- if len(assistantMsg.ToolCalls()) == 0 {
- return nil, nil
- }
-
- toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
- if err != nil {
- // If error is from cancellation, still return the partial results we have
- if errors.Is(err, context.Canceled) {
- // Use background context to ensure the message is created even if original context is cancelled
- bgCtx := context.Background()
- parts := make([]message.ContentPart, 0)
- for _, toolResult := range toolResults {
- parts = append(parts, toolResult)
+ if toolErr != nil {
+ if errors.Is(toolErr, permission.ErrorPermissionDenied) {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: "Permission denied",
+ IsError: true,
+ }
+ for j := i + 1; j < len(toolCalls); j++ {
+ toolResults[j] = message.ToolResult{
+ ToolCallID: toolCalls[j].ID,
+ Content: "Tool execution canceled by user",
+ IsError: true,
+ }
+ }
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
+ } else {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: toolErr.Error(),
+ IsError: true,
+ }
+ for j := i; j < len(toolCalls); j++ {
+ toolResults[j] = message.ToolResult{
+ ToolCallID: toolCalls[j].ID,
+ Content: "Previous tool failed",
+ IsError: true,
+ }
+ }
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonError)
+ }
+ // If permission is denied or an error happens we cancel all the following tools
+ break
}
-
- msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
- })
- if createErr != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
- return nil, err
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: toolResult.Content,
+ Metadata: toolResult.Metadata,
+ IsError: toolResult.IsError,
}
- return &msg, err
}
- return nil, err
}
-
- parts := make([]message.ContentPart, 0, len(toolResults))
- for _, toolResult := range toolResults {
- parts = append(parts, toolResult)
+out:
+ if len(toolResults) == 0 {
+ return assistantMsg, nil, nil
}
-
- msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
+ parts := make([]message.ContentPart, 0)
+ for _, tr := range toolResults {
+ parts = append(parts, tr)
+ }
+ msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
})
if err != nil {
- return nil, fmt.Errorf("failed to create tool message: %w", err)
+		return assistantMsg, nil, fmt.Errorf("failed to create tool results message: %w", err)
}
- return &msg, nil
+ return assistantMsg, &msg, err
}
-// generate handles the main generation workflow
-func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
- ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
+func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason) {
+	msg.AddFinish(finishReason)
+ _ = a.messages.Update(ctx, *msg)
+}
- // Handle context cancellation at any point
- if err := ctx.Err(); err != nil {
- return ErrRequestCancelled
+func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ // Continue processing.
}
- messages, err := a.messages.List(ctx, sessionID)
- if err != nil {
- return fmt.Errorf("failed to list messages: %w", err)
+ switch event.Type {
+ case provider.EventThinkingDelta:
+ assistantMsg.AppendReasoningContent(event.Content)
+ return a.messages.Update(ctx, *assistantMsg)
+ case provider.EventContentDelta:
+ assistantMsg.AppendContent(event.Content)
+ return a.messages.Update(ctx, *assistantMsg)
+ case provider.EventError:
+ if errors.Is(event.Error, context.Canceled) {
+ logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
+ return context.Canceled
+ }
+ logging.ErrorPersist(event.Error.Error())
+ return event.Error
+ case provider.EventComplete:
+ assistantMsg.SetToolCalls(event.Response.ToolCalls)
+ assistantMsg.AddFinish(event.Response.FinishReason)
+ if err := a.messages.Update(ctx, *assistantMsg); err != nil {
+ return fmt.Errorf("failed to update message: %w", err)
+ }
+ return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
}
- if len(messages) == 0 {
- titleCtx := context.Background()
- go a.handleTitleGeneration(titleCtx, sessionID, content)
- }
+ return nil
+}
- userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.User,
- Parts: []message.ContentPart{
- message.TextContent{
- Text: content,
- },
- },
- })
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+ sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
- return fmt.Errorf("failed to create user message: %w", err)
+ return fmt.Errorf("failed to get session: %w", err)
}
- messages = append(messages, userMsg)
-
- for {
- // Check for cancellation before each iteration
- select {
- case <-ctx.Done():
- return ErrRequestCancelled
- default:
- // Continue processing
- }
-
- eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- return ErrRequestCancelled
- }
- return fmt.Errorf("failed to stream response: %w", err)
- }
-
- assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: a.model.ID,
- })
- if err != nil {
- return fmt.Errorf("failed to create assistant message: %w", err)
- }
-
- ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
-
- // Process events from the LLM provider
- for event := range eventChan {
- if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
- if errors.Is(err, context.Canceled) {
- // Mark as canceled but don't create separate message
- assistantMsg.AddFinish("canceled")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- }
- assistantMsg.AddFinish("error:" + err.Error())
- _ = a.messages.Update(ctx, assistantMsg)
- return fmt.Errorf("event processing error: %w", err)
- }
-
- // Check for cancellation during event processing
- select {
- case <-ctx.Done():
- // Mark as canceled
- assistantMsg.AddFinish("canceled")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- default:
- }
- }
-
- // Check for cancellation before tool execution
- select {
- case <-ctx.Done():
- assistantMsg.AddFinish("canceled_by_user")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- default:
- }
-
- // Execute any tool calls
- toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- assistantMsg.AddFinish("canceled_by_user")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- }
- return fmt.Errorf("tool execution error: %w", err)
- }
-
- if err := a.messages.Update(ctx, assistantMsg); err != nil {
- return fmt.Errorf("failed to update assistant message: %w", err)
- }
-
- // If no tool calls, we're done
- if len(assistantMsg.ToolCalls()) == 0 {
- break
- }
+ cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
+ model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
+ model.CostPer1MIn/1e6*float64(usage.InputTokens) +
+ model.CostPer1MOut/1e6*float64(usage.OutputTokens)
- // Add messages for next iteration
- messages = append(messages, assistantMsg)
- if toolMsg != nil {
- messages = append(messages, *toolMsg)
- }
+ sess.Cost += cost
+ sess.CompletionTokens += usage.OutputTokens
+ sess.PromptTokens += usage.InputTokens
- // Check for cancellation after tool execution
- select {
- case <-ctx.Done():
- return ErrRequestCancelled
- default:
- }
+ _, err = a.sessions.Save(ctx, sess)
+ if err != nil {
+ return fmt.Errorf("failed to save session: %w", err)
}
-
return nil
}
-// getAgentProviders initializes the LLM providers based on the chosen model
-func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
- maxTokens := config.Get().Model.CoderMaxTokens
-
- providerConfig, ok := config.Get().Providers[model.Provider]
- if !ok || providerConfig.Disabled {
- return nil, nil, ErrProviderNotEnabled
+func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
+ cfg := config.Get()
+ agentConfig, ok := cfg.Agents[agentName]
+ if !ok {
+ return nil, fmt.Errorf("agent %s not found", agentName)
+ }
+ model, ok := models.SupportedModels[agentConfig.Model]
+ if !ok {
+ return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
}
- var agentProvider provider.Provider
- var titleGenerator provider.Provider
- var err error
-
- switch model.Provider {
- case models.ProviderOpenAI:
- agentProvider, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.CoderOpenAISystemPrompt(),
- ),
- provider.WithOpenAIMaxTokens(maxTokens),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithOpenAIMaxTokens(80),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
- }
-
- case models.ProviderAnthropic:
- agentProvider, err = provider.NewAnthropicProvider(
- provider.WithAnthropicSystemMessage(
- prompt.CoderAnthropicSystemPrompt(),
- ),
- provider.WithAnthropicMaxTokens(maxTokens),
- provider.WithAnthropicKey(providerConfig.APIKey),
- provider.WithAnthropicModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewAnthropicProvider(
- provider.WithAnthropicSystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithAnthropicMaxTokens(80),
- provider.WithAnthropicKey(providerConfig.APIKey),
- provider.WithAnthropicModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
- }
-
- case models.ProviderGemini:
- agentProvider, err = provider.NewGeminiProvider(
- ctx,
- provider.WithGeminiSystemMessage(
- prompt.CoderOpenAISystemPrompt(),
- ),
- provider.WithGeminiMaxTokens(int32(maxTokens)),
- provider.WithGeminiKey(providerConfig.APIKey),
- provider.WithGeminiModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewGeminiProvider(
- ctx,
- provider.WithGeminiSystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithGeminiMaxTokens(80),
- provider.WithGeminiKey(providerConfig.APIKey),
- provider.WithGeminiModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
- }
-
- case models.ProviderGROQ:
- agentProvider, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.CoderAnthropicSystemPrompt(),
- ),
- provider.WithOpenAIMaxTokens(maxTokens),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithOpenAIMaxTokens(80),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
- }
-
- case models.ProviderBedrock:
- agentProvider, err = provider.NewBedrockProvider(
- provider.WithBedrockSystemMessage(
- prompt.CoderAnthropicSystemPrompt(),
- ),
- provider.WithBedrockMaxTokens(maxTokens),
- provider.WithBedrockModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewBedrockProvider(
- provider.WithBedrockSystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithBedrockMaxTokens(80),
- provider.WithBedrockModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
- }
- default:
- return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
+ providerCfg, ok := cfg.Providers[model.Provider]
+ if !ok {
+ return nil, fmt.Errorf("provider %s not supported", model.Provider)
+ }
+ if providerCfg.Disabled {
+ return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
+ }
+ agentProvider, err := provider.NewProvider(
+ model.Provider,
+ provider.WithAPIKey(providerCfg.APIKey),
+ provider.WithModel(model),
+ provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
+ provider.WithMaxTokens(agentConfig.MaxTokens),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("could not create provider: %v", err)
}
- return agentProvider, titleGenerator, nil
+ return agentProvider, nil
}
@@ -1,63 +0,0 @@
-package agent
-
-import (
- "context"
- "errors"
-
- "github.com/kujtimiihoxha/termai/internal/config"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
- "github.com/kujtimiihoxha/termai/internal/llm/tools"
- "github.com/kujtimiihoxha/termai/internal/lsp"
- "github.com/kujtimiihoxha/termai/internal/message"
- "github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/kujtimiihoxha/termai/internal/session"
-)
-
-type coderAgent struct {
- Service
-}
-
-func NewCoderAgent(
- permissions permission.Service,
- sessions session.Service,
- messages message.Service,
- lspClients map[string]*lsp.Client,
-) (Service, error) {
- model, ok := models.SupportedModels[config.Get().Model.Coder]
- if !ok {
- return nil, errors.New("model not supported")
- }
-
- ctx := context.Background()
- otherTools := GetMcpTools(ctx, permissions)
- if len(lspClients) > 0 {
- otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
- }
- agent, err := NewAgent(
- ctx,
- sessions,
- messages,
- model,
- append(
- []tools.BaseTool{
- tools.NewBashTool(permissions),
- tools.NewEditTool(lspClients, permissions),
- tools.NewFetchTool(permissions),
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- tools.NewWriteTool(lspClients, permissions),
- NewAgentTool(sessions, messages, lspClients),
- }, otherTools...,
- ),
- )
- if err != nil {
- return nil, err
- }
-
- return &coderAgent{
- agent,
- }, nil
-}
@@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "termai",
+ Name: "OpenCode",
Version: version.Version,
}
@@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "termai",
+ Name: "OpenCode",
Version: version.Version,
}
@@ -1,47 +0,0 @@
-package agent
-
-import (
- "context"
- "errors"
-
- "github.com/kujtimiihoxha/termai/internal/config"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
- "github.com/kujtimiihoxha/termai/internal/llm/tools"
- "github.com/kujtimiihoxha/termai/internal/lsp"
- "github.com/kujtimiihoxha/termai/internal/message"
- "github.com/kujtimiihoxha/termai/internal/session"
-)
-
-type taskAgent struct {
- Service
-}
-
-func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) {
- model, ok := models.SupportedModels[config.Get().Model.Coder]
- if !ok {
- return nil, errors.New("model not supported")
- }
-
- ctx := context.Background()
-
- agent, err := NewAgent(
- ctx,
- sessions,
- messages,
- model,
- []tools.BaseTool{
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- },
- )
- if err != nil {
- return nil, err
- }
-
- return &taskAgent{
- agent,
- }, nil
-}
@@ -0,0 +1,50 @@
+package agent
+
+import (
+ "context"
+
+ "github.com/kujtimiihoxha/termai/internal/history"
+ "github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/kujtimiihoxha/termai/internal/session"
+)
+
+func CoderAgentTools(
+ permissions permission.Service,
+ sessions session.Service,
+ messages message.Service,
+ history history.Service,
+ lspClients map[string]*lsp.Client,
+) []tools.BaseTool {
+ ctx := context.Background()
+ otherTools := GetMcpTools(ctx, permissions)
+ if len(lspClients) > 0 {
+ otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+ }
+ return append(
+ []tools.BaseTool{
+ tools.NewBashTool(permissions),
+ tools.NewEditTool(lspClients, permissions, history),
+ tools.NewFetchTool(permissions),
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ tools.NewWriteTool(lspClients, permissions, history),
+ NewAgentTool(sessions, messages, lspClients),
+ }, otherTools...,
+ )
+}
+
+func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
+ return []tools.BaseTool{
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ }
+}
@@ -0,0 +1,71 @@
+package models
+
+const (
+ ProviderAnthropic ModelProvider = "anthropic"
+
+ // Models
+ Claude35Sonnet ModelID = "claude-3.5-sonnet"
+ Claude3Haiku ModelID = "claude-3-haiku"
+ Claude37Sonnet ModelID = "claude-3.7-sonnet"
+ Claude35Haiku ModelID = "claude-3.5-haiku"
+ Claude3Opus ModelID = "claude-3-opus"
+)
+
+var AnthropicModels = map[ModelID]Model{
+ // Anthropic
+ Claude35Sonnet: {
+ ID: Claude35Sonnet,
+ Name: "Claude 3.5 Sonnet",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-5-sonnet-latest",
+ CostPer1MIn: 3.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.30,
+ CostPer1MOut: 15.0,
+ ContextWindow: 200000,
+ },
+ Claude3Haiku: {
+ ID: Claude3Haiku,
+ Name: "Claude 3 Haiku",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-haiku-latest",
+ CostPer1MIn: 0.25,
+ CostPer1MInCached: 0.30,
+ CostPer1MOutCached: 0.03,
+ CostPer1MOut: 1.25,
+ ContextWindow: 200000,
+ },
+ Claude37Sonnet: {
+ ID: Claude37Sonnet,
+ Name: "Claude 3.7 Sonnet",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-7-sonnet-latest",
+ CostPer1MIn: 3.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.30,
+ CostPer1MOut: 15.0,
+ ContextWindow: 200000,
+ },
+ Claude35Haiku: {
+ ID: Claude35Haiku,
+ Name: "Claude 3.5 Haiku",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-5-haiku-latest",
+ CostPer1MIn: 0.80,
+ CostPer1MInCached: 1.0,
+ CostPer1MOutCached: 0.08,
+ CostPer1MOut: 4.0,
+ ContextWindow: 200000,
+ },
+ Claude3Opus: {
+ ID: Claude3Opus,
+ Name: "Claude 3 Opus",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-opus-latest",
+ CostPer1MIn: 15.0,
+ CostPer1MInCached: 18.75,
+ CostPer1MOutCached: 1.50,
+ CostPer1MOut: 75.0,
+ ContextWindow: 200000,
+ },
+}
@@ -1,5 +1,7 @@
package models
+import "maps"
+
type (
ModelID string
ModelProvider string
@@ -14,15 +16,13 @@ type Model struct {
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
}
// Model IDs
const (
- // Anthropic
- Claude35Sonnet ModelID = "claude-3.5-sonnet"
- Claude3Haiku ModelID = "claude-3-haiku"
- Claude37Sonnet ModelID = "claude-3.7-sonnet"
// OpenAI
+ GPT4o ModelID = "gpt-4o"
GPT41 ModelID = "gpt-4.1"
// GEMINI
@@ -37,47 +37,59 @@ const (
)
const (
- ProviderOpenAI ModelProvider = "openai"
- ProviderAnthropic ModelProvider = "anthropic"
- ProviderBedrock ModelProvider = "bedrock"
- ProviderGemini ModelProvider = "gemini"
- ProviderGROQ ModelProvider = "groq"
+ ProviderOpenAI ModelProvider = "openai"
+ ProviderBedrock ModelProvider = "bedrock"
+ ProviderGemini ModelProvider = "gemini"
+ ProviderGROQ ModelProvider = "groq"
+
+ // ForTests
+ ProviderMock ModelProvider = "__mock"
)
var SupportedModels = map[ModelID]Model{
- // Anthropic
- Claude35Sonnet: {
- ID: Claude35Sonnet,
- Name: "Claude 3.5 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-5-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- },
- Claude3Haiku: {
- ID: Claude3Haiku,
- Name: "Claude 3 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-haiku-latest",
- CostPer1MIn: 0.80,
- CostPer1MInCached: 1,
- CostPer1MOutCached: 0.08,
- CostPer1MOut: 4,
- },
- Claude37Sonnet: {
- ID: Claude37Sonnet,
- Name: "Claude 3.7 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-7-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
+ // // Anthropic
+ // Claude35Sonnet: {
+ // ID: Claude35Sonnet,
+ // Name: "Claude 3.5 Sonnet",
+ // Provider: ProviderAnthropic,
+ // APIModel: "claude-3-5-sonnet-latest",
+ // CostPer1MIn: 3.0,
+ // CostPer1MInCached: 3.75,
+ // CostPer1MOutCached: 0.30,
+ // CostPer1MOut: 15.0,
+ // },
+ // Claude3Haiku: {
+ // ID: Claude3Haiku,
+ // Name: "Claude 3 Haiku",
+ // Provider: ProviderAnthropic,
+ // APIModel: "claude-3-haiku-latest",
+ // CostPer1MIn: 0.80,
+ // CostPer1MInCached: 1,
+ // CostPer1MOutCached: 0.08,
+ // CostPer1MOut: 4,
+ // },
+ // Claude37Sonnet: {
+ // ID: Claude37Sonnet,
+ // Name: "Claude 3.7 Sonnet",
+ // Provider: ProviderAnthropic,
+ // APIModel: "claude-3-7-sonnet-latest",
+ // CostPer1MIn: 3.0,
+ // CostPer1MInCached: 3.75,
+ // CostPer1MOutCached: 0.30,
+ // CostPer1MOut: 15.0,
+ // },
+ //
+ // // OpenAI
+ GPT4o: {
+ ID: GPT4o,
+ Name: "GPT-4o",
+ Provider: ProviderOpenAI,
+		APIModel:           "gpt-4o",
+		CostPer1MIn:        2.50,
+		CostPer1MInCached:  1.25,
+		CostPer1MOutCached: 0,
+		CostPer1MOut:       10.00,
},
-
- // OpenAI
GPT41: {
ID: GPT41,
Name: "GPT-4.1",
@@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{
CostPer1MOutCached: 0,
CostPer1MOut: 8.00,
},
+ //
+ // // GEMINI
+ // GEMINI25: {
+ // ID: GEMINI25,
+ // Name: "Gemini 2.5 Pro",
+ // Provider: ProviderGemini,
+ // APIModel: "gemini-2.5-pro-exp-03-25",
+ // CostPer1MIn: 0,
+ // CostPer1MInCached: 0,
+ // CostPer1MOutCached: 0,
+ // CostPer1MOut: 0,
+ // },
+ //
+ // GRMINI20Flash: {
+ // ID: GRMINI20Flash,
+ // Name: "Gemini 2.0 Flash",
+ // Provider: ProviderGemini,
+ // APIModel: "gemini-2.0-flash",
+ // CostPer1MIn: 0.1,
+ // CostPer1MInCached: 0,
+ // CostPer1MOutCached: 0.025,
+ // CostPer1MOut: 0.4,
+ // },
+ //
+ // // GROQ
+ // QWENQwq: {
+ // ID: QWENQwq,
+ // Name: "Qwen Qwq",
+ // Provider: ProviderGROQ,
+ // APIModel: "qwen-qwq-32b",
+ // CostPer1MIn: 0,
+ // CostPer1MInCached: 0,
+ // CostPer1MOutCached: 0,
+ // CostPer1MOut: 0,
+ // },
+ //
+ // // Bedrock
+ // BedrockClaude37Sonnet: {
+ // ID: BedrockClaude37Sonnet,
+ // Name: "Bedrock: Claude 3.7 Sonnet",
+ // Provider: ProviderBedrock,
+ // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
+ // CostPer1MIn: 3.0,
+ // CostPer1MInCached: 3.75,
+ // CostPer1MOutCached: 0.30,
+ // CostPer1MOut: 15.0,
+ // },
+}
- // GEMINI
- GEMINI25: {
- ID: GEMINI25,
- Name: "Gemini 2.5 Pro",
- Provider: ProviderGemini,
- APIModel: "gemini-2.5-pro-exp-03-25",
- CostPer1MIn: 0,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0,
- },
-
- GRMINI20Flash: {
- ID: GRMINI20Flash,
- Name: "Gemini 2.0 Flash",
- Provider: ProviderGemini,
- APIModel: "gemini-2.0-flash",
- CostPer1MIn: 0.1,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0.025,
- CostPer1MOut: 0.4,
- },
-
- // GROQ
- QWENQwq: {
- ID: QWENQwq,
- Name: "Qwen Qwq",
- Provider: ProviderGROQ,
- APIModel: "qwen-qwq-32b",
- CostPer1MIn: 0,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0,
- },
-
- // Bedrock
- BedrockClaude37Sonnet: {
- ID: BedrockClaude37Sonnet,
- Name: "Bedrock: Claude 3.7 Sonnet",
- Provider: ProviderBedrock,
- APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- },
+func init() {
+ maps.Copy(SupportedModels, AnthropicModels)
}
@@ -9,11 +9,22 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
)
-func CoderOpenAISystemPrompt() string {
- basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
+func CoderPrompt(provider models.ModelProvider) string {
+ basePrompt := baseAnthropicCoderPrompt
+ switch provider {
+ case models.ProviderOpenAI:
+ basePrompt = baseOpenAICoderPrompt
+ }
+ envInfo := getEnvironmentInfo()
+
+ return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+}
+
+const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
# Your mindset
Act like a competent, efficient software engineer who is familiar with large codebases. You should:
@@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines]
Never commit changes unless the user explicitly asks you to.`
- envInfo := getEnvironmentInfo()
-
- return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
-}
-
-func CoderAnthropicSystemPrompt() string {
- basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
+const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.
@@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
- envInfo := getEnvironmentInfo()
-
- return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
-}
-
func getEnvironmentInfo() string {
cwd := config.WorkingDirectory()
isGit := isGitRepo(cwd)
@@ -0,0 +1,19 @@
+package prompt
+
+import (
+ "github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
+)
+
+func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
+ switch agentName {
+ case config.AgentCoder:
+ return CoderPrompt(provider)
+ case config.AgentTitle:
+ return TitlePrompt(provider)
+ case config.AgentTask:
+ return TaskPrompt(provider)
+ default:
+ return "You are a helpful assistant"
+ }
+}
@@ -2,11 +2,12 @@ package prompt
import (
"fmt"
+
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
)
-func TaskAgentSystemPrompt() string {
+func TaskPrompt(_ models.ModelProvider) string {
agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question.
-
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query
@@ -1,6 +1,8 @@
package prompt
-func TitlePrompt() string {
+import "github.com/kujtimiihoxha/termai/internal/llm/models"
+
+func TitlePrompt(_ models.ModelProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
@@ -12,187 +12,257 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
)
-type anthropicProvider struct {
- client anthropic.Client
- model models.Model
- maxTokens int64
- apiKey string
- systemMessage string
- useBedrock bool
- disableCache bool
+type anthropicOptions struct {
+ useBedrock bool
+ disableCache bool
+ shouldThink func(userMessage string) bool
}
-type AnthropicOption func(*anthropicProvider)
+type AnthropicOption func(*anthropicOptions)
-func WithAnthropicSystemMessage(message string) AnthropicOption {
- return func(a *anthropicProvider) {
- a.systemMessage = message
- }
+type anthropicClient struct {
+ providerOptions providerClientOptions
+ options anthropicOptions
+ client anthropic.Client
}
-func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
- return func(a *anthropicProvider) {
- a.maxTokens = maxTokens
- }
-}
+type AnthropicClient ProviderClient
-func WithAnthropicModel(model models.Model) AnthropicOption {
- return func(a *anthropicProvider) {
- a.model = model
+func newAnthropicClient(opts providerClientOptions) AnthropicClient {
+ anthropicOpts := anthropicOptions{}
+ for _, o := range opts.anthropicOptions {
+ o(&anthropicOpts)
}
-}
-func WithAnthropicKey(apiKey string) AnthropicOption {
- return func(a *anthropicProvider) {
- a.apiKey = apiKey
+ anthropicClientOptions := []option.RequestOption{}
+ if opts.apiKey != "" {
+ anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
}
-}
-
-func WithAnthropicBedrock() AnthropicOption {
- return func(a *anthropicProvider) {
- a.useBedrock = true
+ if anthropicOpts.useBedrock {
+ anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}
-}
-func WithAnthropicDisableCache() AnthropicOption {
- return func(a *anthropicProvider) {
- a.disableCache = true
+ client := anthropic.NewClient(anthropicClientOptions...)
+ return &anthropicClient{
+ providerOptions: opts,
+ options: anthropicOpts,
+ client: client,
}
}
-func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
- provider := &anthropicProvider{
- maxTokens: 1024,
- }
+func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
+ cachedBlocks := 0
+ for _, msg := range messages {
+ switch msg.Role {
+ case message.User:
+ content := anthropic.NewTextBlock(msg.Content().String())
+ if cachedBlocks < 2 && !a.options.disableCache {
+ content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ }
+ cachedBlocks++
+ }
+ anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
- for _, opt := range opts {
- opt(provider)
- }
+ case message.Assistant:
+ blocks := []anthropic.ContentBlockParamUnion{}
+ if msg.Content().String() != "" {
+ content := anthropic.NewTextBlock(msg.Content().String())
+ if cachedBlocks < 2 && !a.options.disableCache {
+ content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ }
+ cachedBlocks++
+ }
+ blocks = append(blocks, content)
+ }
- if provider.systemMessage == "" {
- return nil, errors.New("system message is required")
- }
+ for _, toolCall := range msg.ToolCalls() {
+ var inputMap map[string]any
+ err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
+ if err != nil {
+ continue
+ }
+ blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
+ }
- anthropicOptions := []option.RequestOption{}
+ if len(blocks) == 0 {
+ logging.Warn("There is a message without content, investigate")
+				// This should never happen, but we log this because we might have a bug in our cleanup method
+ continue
+ }
+ anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
- if provider.apiKey != "" {
- anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
- }
- if provider.useBedrock {
- anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
+ case message.Tool:
+ results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
+ for i, toolResult := range msg.ToolResults() {
+ results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
+ }
+ anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
+ }
}
-
- provider.client = anthropic.NewClient(anthropicOptions...)
- return provider, nil
+ return
}
-func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = cleanupMessages(messages)
- anthropicMessages := a.convertToAnthropicMessages(messages)
- anthropicTools := a.convertToAnthropicTools(tools)
-
- response, err := a.client.Messages.New(
- ctx,
- anthropic.MessageNewParams{
- Model: anthropic.Model(a.model.APIModel),
- MaxTokens: a.maxTokens,
- Temperature: anthropic.Float(0),
- Messages: anthropicMessages,
- Tools: anthropicTools,
- System: []anthropic.TextBlockParam{
- {
- Text: a.systemMessage,
- CacheControl: anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- },
- },
+func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
+ anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
+
+ for i, tool := range tools {
+ info := tool.Info()
+ toolParam := anthropic.ToolParam{
+ Name: info.Name,
+ Description: anthropic.String(info.Description),
+ InputSchema: anthropic.ToolInputSchemaParam{
+ Properties: info.Parameters,
+ // TODO: figure out how we can tell claude the required fields?
},
- },
- )
- if err != nil {
- return nil, err
- }
+ }
- content := ""
- for _, block := range response.Content {
- if text, ok := block.AsAny().(anthropic.TextBlock); ok {
- content += text.Text
+ if i == len(tools)-1 && !a.options.disableCache {
+ toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ }
}
- }
- toolCalls := a.extractToolCalls(response.Content)
- tokenUsage := a.extractTokenUsage(response.Usage)
+ anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
+ }
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- }, nil
+ return anthropicTools
}
-func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- messages = cleanupMessages(messages)
- anthropicMessages := a.convertToAnthropicMessages(messages)
- anthropicTools := a.convertToAnthropicTools(tools)
+func (a *anthropicClient) finishReason(reason string) message.FinishReason {
+ switch reason {
+ case "end_turn":
+ return message.FinishReasonEndTurn
+ case "max_tokens":
+ return message.FinishReasonMaxTokens
+ case "tool_use":
+ return message.FinishReasonToolUse
+ case "stop_sequence":
+ return message.FinishReasonEndTurn
+ default:
+ return message.FinishReasonUnknown
+ }
+}
+func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
+ isUser := lastMessage.Role == anthropic.MessageParamRoleUser
+ messageContent := ""
temperature := anthropic.Float(0)
- if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
- thinkingParam = anthropic.ThinkingConfigParamUnion{
- OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
- BudgetTokens: int64(float64(a.maxTokens) * 0.8),
- Type: "enabled",
- },
+ if isUser {
+ for _, m := range lastMessage.Content {
+ if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" {
+ messageContent = m.OfRequestTextBlock.Text
+ }
+ }
+ if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
+ thinkingParam = anthropic.ThinkingConfigParamUnion{
+ OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
+ BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8),
+ Type: "enabled",
+ },
+ }
+ temperature = anthropic.Float(1)
}
- temperature = anthropic.Float(1)
}
- eventChan := make(chan ProviderEvent)
+ return anthropic.MessageNewParams{
+ Model: anthropic.Model(a.providerOptions.model.APIModel),
+ MaxTokens: a.providerOptions.maxTokens,
+ Temperature: temperature,
+ Messages: messages,
+ Tools: tools,
+ Thinking: thinkingParam,
+ System: []anthropic.TextBlockParam{
+ {
+ Text: a.providerOptions.systemMessage,
+ CacheControl: anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ },
+ },
+ },
+ }
+}
- go func() {
- defer close(eventChan)
+func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
+ preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(preparedMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
+ for {
+ attempts++
+ anthropicResponse, err := a.client.Messages.New(
+ ctx,
+ preparedMessages,
+ )
+ // If there is an error we are going to see if we can retry the call
+ if err != nil {
+ retry, after, retryErr := a.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
+ }
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ return nil, retryErr
+ }
- const maxRetries = 8
- attempts := 0
+ content := ""
+ for _, block := range anthropicResponse.Content {
+ if text, ok := block.AsAny().(anthropic.TextBlock); ok {
+ content += text.Text
+ }
+ }
- for {
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: a.toolCalls(*anthropicResponse),
+ Usage: a.usage(*anthropicResponse),
+ }, nil
+ }
+}
+func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(preparedMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
+ eventChan := make(chan ProviderEvent)
+ go func() {
+ for {
attempts++
-
- stream := a.client.Messages.NewStreaming(
+ anthropicStream := a.client.Messages.NewStreaming(
ctx,
- anthropic.MessageNewParams{
- Model: anthropic.Model(a.model.APIModel),
- MaxTokens: a.maxTokens,
- Temperature: temperature,
- Messages: anthropicMessages,
- Tools: anthropicTools,
- Thinking: thinkingParam,
- System: []anthropic.TextBlockParam{
- {
- Text: a.systemMessage,
- CacheControl: anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- },
- },
- },
- },
+ preparedMessages,
)
-
accumulatedMessage := anthropic.Message{}
- for stream.Next() {
- event := stream.Current()
+ for anthropicStream.Next() {
+ event := anthropicStream.Current()
err := accumulatedMessage.Accumulate(event)
if err != nil {
eventChan <- ProviderEvent{Type: EventError, Error: err}
- return // Don't retry on accumulation errors
+ continue
}
switch event := event.AsAny().(type) {
@@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
Content: event.Delta.Text,
}
}
+ // TODO: check if we can somehow stream tool calls
case anthropic.ContentBlockStopEvent:
eventChan <- ProviderEvent{Type: EventContentStop}
@@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
}
}
- toolCalls := a.extractToolCalls(accumulatedMessage.Content)
- tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
-
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- FinishReason: string(accumulatedMessage.StopReason),
+ ToolCalls: a.toolCalls(accumulatedMessage),
+ Usage: a.usage(accumulatedMessage),
+ FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
},
}
}
}
- err := stream.Err()
+ err := anthropicStream.Err()
if err == nil || errors.Is(err, io.EOF) {
+ close(eventChan)
return
}
-
- var apierr *anthropic.Error
- if !errors.As(err, &apierr) {
- eventChan <- ProviderEvent{Type: EventError, Error: err}
- return
- }
-
- if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
- eventChan <- ProviderEvent{Type: EventError, Error: err}
+ // If there is an error we are going to see if we can retry the call
+ retry, after, retryErr := a.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ close(eventChan)
return
}
-
- if attempts > maxRetries {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: errors.New("maximum retry attempts reached for rate limit (429)"),
- }
- return
- }
-
- retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
- if len(retryAfterValues) > 0 {
- var retryAfterSec int
- if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
- retryMs = retryAfterSec * 1000
- eventChan <- ProviderEvent{
- Type: EventWarning,
- Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ // context cancelled
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
}
+ close(eventChan)
+ return
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
}
- } else {
- eventChan <- ProviderEvent{
- Type: EventWarning,
- Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
- }
-
- backoffMs := 2000 * (1 << (attempts - 1))
- jitterMs := int(float64(backoffMs) * 0.2)
- retryMs = backoffMs + jitterMs
}
- select {
- case <-ctx.Done():
+ if ctx.Err() != nil {
eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
- return
- case <-time.After(time.Duration(retryMs) * time.Millisecond):
- continue
}
+ close(eventChan)
+ return
}
}()
+ return eventChan
+}
- return eventChan, nil
+func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ var apierr *anthropic.Error
+ if !errors.As(err, &apierr) {
+ return false, 0, err
+ }
+
+ if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
+ return false, 0, err
+ }
+
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
+
+ retryMs := 0
+ retryAfterValues := apierr.Response.Header.Values("Retry-After")
+
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs = backoffMs + jitterMs
+ if len(retryAfterValues) > 0 {
+ if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+ retryMs = retryMs * 1000
+ }
+ }
+ return true, int64(retryMs), nil
}
-func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
+func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
var toolCalls []message.ToolCall
- for _, block := range content {
+ for _, block := range msg.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
toolCall := message.ToolCall{
@@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni
return toolCalls
}
-func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
+func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
return TokenUsage{
- InputTokens: usage.InputTokens,
- OutputTokens: usage.OutputTokens,
- CacheCreationTokens: usage.CacheCreationInputTokens,
- CacheReadTokens: usage.CacheReadInputTokens,
+ InputTokens: msg.Usage.InputTokens,
+ OutputTokens: msg.Usage.OutputTokens,
+ CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
+ CacheReadTokens: msg.Usage.CacheReadInputTokens,
}
}
-func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
- anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
-
- for i, tool := range tools {
- info := tool.Info()
- toolParam := anthropic.ToolParam{
- Name: info.Name,
- Description: anthropic.String(info.Description),
- InputSchema: anthropic.ToolInputSchemaParam{
- Properties: info.Parameters,
- },
- }
-
- if i == len(tools)-1 && !a.disableCache {
- toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- }
- }
-
- anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
+func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
+ return func(options *anthropicOptions) {
+ options.useBedrock = useBedrock
}
-
- return anthropicTools
}
-func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
- anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
- cachedBlocks := 0
-
- for _, msg := range messages {
- switch msg.Role {
- case message.User:
- content := anthropic.NewTextBlock(msg.Content().String())
- if cachedBlocks < 2 && !a.disableCache {
- content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- }
- cachedBlocks++
- }
- anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
-
- case message.Assistant:
- blocks := []anthropic.ContentBlockParamUnion{}
- if msg.Content().String() != "" {
- content := anthropic.NewTextBlock(msg.Content().String())
- if cachedBlocks < 2 && !a.disableCache {
- content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- }
- cachedBlocks++
- }
- blocks = append(blocks, content)
- }
-
- for _, toolCall := range msg.ToolCalls() {
- var inputMap map[string]any
- err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
- if err != nil {
- continue
- }
- blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
- }
+func WithAnthropicDisableCache() AnthropicOption {
+ return func(options *anthropicOptions) {
+ options.disableCache = true
+ }
+}
- if len(blocks) > 0 {
- anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
- }
+func DefaultShouldThinkFn(s string) bool {
+ return strings.Contains(strings.ToLower(s), "think")
+}
- case message.Tool:
- results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
- for i, toolResult := range msg.ToolResults() {
- results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
- }
- anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
- }
+func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
+ return func(options *anthropicOptions) {
+ options.shouldThink = fn
}
-
- return anthropicMessages
}
@@ -7,33 +7,29 @@ import (
"os"
"strings"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
-type bedrockProvider struct {
- childProvider Provider
- model models.Model
- maxTokens int64
- systemMessage string
+type bedrockOptions struct {
+ // Bedrock specific options can be added here
}
-func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- return b.childProvider.SendMessages(ctx, messages, tools)
-}
+type BedrockOption func(*bedrockOptions)
-func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- return b.childProvider.StreamResponse(ctx, messages, tools)
+type bedrockClient struct {
+ providerOptions providerClientOptions
+ options bedrockOptions
+ childProvider ProviderClient
}
-func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
- provider := &bedrockProvider{}
- for _, opt := range opts {
- opt(provider)
- }
+type BedrockClient ProviderClient
+
+func newBedrockClient(opts providerClientOptions) BedrockClient {
+ bedrockOpts := bedrockOptions{}
+ // Apply bedrock specific options if they are added in the future
- // based on the AWS region prefix the model name with, us, eu, ap, sa, etc.
+ // Get AWS region from environment
region := os.Getenv("AWS_REGION")
if region == "" {
region = os.Getenv("AWS_DEFAULT_REGION")
@@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
region = "us-east-1" // default region
}
if len(region) < 2 {
- return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
+ return &bedrockClient{
+ providerOptions: opts,
+ options: bedrockOpts,
+ childProvider: nil, // Will cause an error when used
+ }
}
+
+ // Prefix the model name with region
regionPrefix := region[:2]
- provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)
+ modelName := opts.model.APIModel
+ opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
- if strings.Contains(string(provider.model.APIModel), "anthropic") {
- anthropic, err := NewAnthropicProvider(
- WithAnthropicModel(provider.model),
- WithAnthropicMaxTokens(provider.maxTokens),
- WithAnthropicSystemMessage(provider.systemMessage),
- WithAnthropicBedrock(),
+ // Determine which provider to use based on the model
+ if strings.Contains(string(opts.model.APIModel), "anthropic") {
+ // Create Anthropic client with Bedrock configuration
+ anthropicOpts := opts
+ anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
+ WithAnthropicBedrock(true),
WithAnthropicDisableCache(),
)
- provider.childProvider = anthropic
- if err != nil {
- return nil, err
+ return &bedrockClient{
+ providerOptions: opts,
+ options: bedrockOpts,
+ childProvider: newAnthropicClient(anthropicOpts),
}
- } else {
- return nil, errors.New("unsupported model for bedrock provider")
}
- return provider, nil
-}
-
-type BedrockOption func(*bedrockProvider)
-func WithBedrockSystemMessage(message string) BedrockOption {
- return func(a *bedrockProvider) {
- a.systemMessage = message
+ // Return client with nil childProvider if model is not supported
+ // This will cause an error when used
+ return &bedrockClient{
+ providerOptions: opts,
+ options: bedrockOpts,
+ childProvider: nil,
}
}
-func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
- return func(a *bedrockProvider) {
- a.maxTokens = maxTokens
+func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ if b.childProvider == nil {
+ return nil, errors.New("unsupported model for bedrock provider")
}
+ return b.childProvider.send(ctx, messages, tools)
}
-func WithBedrockModel(model models.Model) BedrockOption {
- return func(a *bedrockProvider) {
- a.model = model
+func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ eventChan := make(chan ProviderEvent)
+
+ if b.childProvider == nil {
+ go func() {
+ eventChan <- ProviderEvent{
+ Type: EventError,
+ Error: errors.New("unsupported model for bedrock provider"),
+ }
+ close(eventChan)
+ }()
+ return eventChan
}
-}
+
+ return b.childProvider.stream(ctx, messages, tools)
+}
@@ -4,80 +4,68 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
+ "io"
+ "strings"
+ "time"
"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
-type geminiProvider struct {
- client *genai.Client
- model models.Model
- maxTokens int32
- apiKey string
- systemMessage string
+type geminiOptions struct {
+ disableCache bool
}
-type GeminiOption func(*geminiProvider)
+type GeminiOption func(*geminiOptions)
-func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
- provider := &geminiProvider{
- maxTokens: 5000,
- }
+type geminiClient struct {
+ providerOptions providerClientOptions
+ options geminiOptions
+ client *genai.Client
+}
- for _, opt := range opts {
- opt(provider)
- }
+type GeminiClient ProviderClient
- if provider.systemMessage == "" {
- return nil, errors.New("system message is required")
+func newGeminiClient(opts providerClientOptions) GeminiClient {
+ geminiOpts := geminiOptions{}
+ for _, o := range opts.geminiOptions {
+ o(&geminiOpts)
}
- client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
+ client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
if err != nil {
- return nil, err
- }
- provider.client = client
-
- return provider, nil
-}
-
-func WithGeminiSystemMessage(message string) GeminiOption {
- return func(p *geminiProvider) {
- p.systemMessage = message
+ logging.Error("Failed to create Gemini client", "error", err)
+ return nil
}
-}
-func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
- return func(p *geminiProvider) {
- p.maxTokens = maxTokens
+ return &geminiClient{
+ providerOptions: opts,
+ options: geminiOpts,
+ client: client,
}
}
-func WithGeminiModel(model models.Model) GeminiOption {
- return func(p *geminiProvider) {
- p.model = model
- }
-}
-
-func WithGeminiKey(apiKey string) GeminiOption {
- return func(p *geminiProvider) {
- p.apiKey = apiKey
- }
-}
+func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
+ var history []*genai.Content
-func (p *geminiProvider) Close() {
- if p.client != nil {
- p.client.Close()
- }
-}
+ // Add system message first
+ history = append(history, &genai.Content{
+ Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
+ Role: "user",
+ })
-func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
- var history []*genai.Content
+ // Add a system response to acknowledge the system message
+ history = append(history, &genai.Content{
+ Parts: []genai.Part{genai.Text("I'll help you with that.")},
+ Role: "model",
+ })
for _, msg := range messages {
switch msg.Role {
@@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
Parts: []genai.Part{genai.Text(msg.Content().String())},
Role: "user",
})
+
case message.Assistant:
content := &genai.Content{
Role: "model",
@@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
}
history = append(history, content)
+
case message.Tool:
for _, result := range msg.ToolResults() {
response := map[string]interface{}{"result": result.Content}
@@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
if err == nil {
response = parsed
}
+
var toolCall message.ToolCall
- for _, msg := range messages {
- if msg.Role == message.Assistant {
- for _, call := range msg.ToolCalls() {
+ for _, m := range messages {
+ if m.Role == message.Assistant {
+ for _, call := range m.ToolCalls() {
if call.ID == result.ToolCallID {
toolCall = call
break
@@ -140,186 +131,358 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
return history
}
-func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
- if resp == nil || resp.UsageMetadata == nil {
- return TokenUsage{}
- }
+func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
+ geminiTools := make([]*genai.Tool, 0, len(tools))
- return TokenUsage{
- InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
- OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
- CacheCreationTokens: 0, // Not directly provided by Gemini
- CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
+ for _, tool := range tools {
+ info := tool.Info()
+ declaration := &genai.FunctionDeclaration{
+ Name: info.Name,
+ Description: info.Description,
+ Parameters: &genai.Schema{
+ Type: genai.TypeObject,
+ Properties: convertSchemaProperties(info.Parameters),
+ Required: info.Required,
+ },
+ }
+
+ geminiTools = append(geminiTools, &genai.Tool{
+ FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
+ })
}
+
+ return geminiTools
}
-func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = cleanupMessages(messages)
- model := p.client.GenerativeModel(p.model.APIModel)
- model.SetMaxOutputTokens(p.maxTokens)
+func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
+ reasonStr := reason.String()
+ switch {
+ case reasonStr == "STOP":
+ return message.FinishReasonEndTurn
+ case reasonStr == "MAX_TOKENS":
+ return message.FinishReasonMaxTokens
+ case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
+ return message.FinishReasonToolUse
+ default:
+ return message.FinishReasonUnknown
+ }
+}
- model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
+func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
+ model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
+ // Convert tools
if len(tools) > 0 {
- declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
- for _, declaration := range declarations {
- model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
- }
+ model.Tools = g.convertTools(tools)
}
- chat := model.StartChat()
- chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
+ // Convert messages
+ geminiMessages := g.convertMessages(messages)
- lastUserMsg := messages[len(messages)-1]
- resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
- if err != nil {
- return nil, err
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(geminiMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
}
- var content string
- var toolCalls []message.ToolCall
+ attempts := 0
+ for {
+ attempts++
+ chat := model.StartChat()
+ chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
+
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ var lastText string
+ for _, part := range lastMsg.Parts {
+ if text, ok := part.(genai.Text); ok {
+ lastText = string(text)
+ break
+ }
+ }
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- content = string(p)
- case genai.FunctionCall:
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
- toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: p.Name,
- Input: string(args),
- Type: "function",
- })
+ resp, err := chat.SendMessage(ctx, genai.Text(lastText))
+ // If there is an error we are going to see if we can retry the call
+ if err != nil {
+ retry, after, retryErr := g.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
}
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+			return nil, err
}
- }
- tokenUsage := p.extractTokenUsage(resp)
+ content := ""
+ var toolCalls []message.ToolCall
+
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+ for _, part := range resp.Candidates[0].Content.Parts {
+ switch p := part.(type) {
+ case genai.Text:
+ content = string(p)
+ case genai.FunctionCall:
+ id := "call_" + uuid.New().String()
+ args, _ := json.Marshal(p.Args)
+ toolCalls = append(toolCalls, message.ToolCall{
+ ID: id,
+ Name: p.Name,
+ Input: string(args),
+ Type: "function",
+ })
+ }
+ }
+ }
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- }, nil
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: toolCalls,
+ Usage: g.usage(resp),
+ FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
+ }, nil
+ }
}
-func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- messages = cleanupMessages(messages)
- model := p.client.GenerativeModel(p.model.APIModel)
- model.SetMaxOutputTokens(p.maxTokens)
-
- model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
+func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
+ model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
+ // Convert tools
if len(tools) > 0 {
- declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
- for _, declaration := range declarations {
- model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
- }
+ model.Tools = g.convertTools(tools)
}
- chat := model.StartChat()
- chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
+ // Convert messages
+ geminiMessages := g.convertMessages(messages)
- lastUserMsg := messages[len(messages)-1]
-
- iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(geminiMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
eventChan := make(chan ProviderEvent)
go func() {
defer close(eventChan)
- var finalResp *genai.GenerateContentResponse
- currentContent := ""
- toolCalls := []message.ToolCall{}
-
for {
- resp, err := iter.Next()
- if err == iterator.Done {
- break
- }
- if err != nil {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: err,
+ attempts++
+ chat := model.StartChat()
+ chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
+
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ var lastText string
+ for _, part := range lastMsg.Parts {
+ if text, ok := part.(genai.Text); ok {
+ lastText = string(text)
+ break
}
- return
}
- finalResp = resp
+ iter := chat.SendMessageStream(ctx, genai.Text(lastText))
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- newText := string(p)
- eventChan <- ProviderEvent{
- Type: EventContentDelta,
- Content: newText,
- }
- currentContent += newText
- case genai.FunctionCall:
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
- newCall := message.ToolCall{
- ID: id,
- Name: p.Name,
- Input: string(args),
- Type: "function",
- }
+ currentContent := ""
+ toolCalls := []message.ToolCall{}
+ var finalResp *genai.GenerateContentResponse
- isNew := true
- for _, existing := range toolCalls {
- if existing.Name == newCall.Name && existing.Input == newCall.Input {
- isNew = false
- break
+ eventChan <- ProviderEvent{Type: EventContentStart}
+
+ for {
+ resp, err := iter.Next()
+ if err == iterator.Done {
+ break
+ }
+ if err != nil {
+ retry, after, retryErr := g.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ return
+ }
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
}
+
+ return
+					case <-time.After(time.Duration(after) * time.Millisecond):
+					}
+					break // exit the iterator loop; finalResp stays nil so the outer loop backs off and retries
+ } else {
+ eventChan <- ProviderEvent{Type: EventError, Error: err}
+ return
+ }
+ }
+
+ finalResp = resp
+
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+ for _, part := range resp.Candidates[0].Content.Parts {
+ switch p := part.(type) {
+ case genai.Text:
+ newText := string(p)
+						delta := newText[len(currentContent):] // NOTE(review): assumes chunks carry cumulative text; panics if the SDK streams incremental parts — confirm genai semantics
+ if delta != "" {
+ eventChan <- ProviderEvent{
+ Type: EventContentDelta,
+ Content: delta,
+ }
+ currentContent = newText
+ }
+ case genai.FunctionCall:
+ id := "call_" + uuid.New().String()
+ args, _ := json.Marshal(p.Args)
+ newCall := message.ToolCall{
+ ID: id,
+ Name: p.Name,
+ Input: string(args),
+ Type: "function",
+ }
- if isNew {
- toolCalls = append(toolCalls, newCall)
+ isNew := true
+ for _, existing := range toolCalls {
+ if existing.Name == newCall.Name && existing.Input == newCall.Input {
+ isNew = false
+ break
+ }
+ }
+
+ if isNew {
+ toolCalls = append(toolCalls, newCall)
+ }
}
}
}
}
- }
- tokenUsage := p.extractTokenUsage(finalResp)
+ eventChan <- ProviderEvent{Type: EventContentStop}
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
- },
+ if finalResp != nil {
+ eventChan <- ProviderEvent{
+ Type: EventComplete,
+ Response: &ProviderResponse{
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: g.usage(finalResp),
+ FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
+ },
+ }
+ return
+ }
+
+ // If we get here, we need to retry
+ if attempts > maxRetries {
+ eventChan <- ProviderEvent{
+ Type: EventError,
+ Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
+ }
+ return
+ }
+
+ // Wait before retrying
+ select {
+ case <-ctx.Done():
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+ }
+ return
+ case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
+ continue
+ }
}
}()
- return eventChan, nil
+ return eventChan
}
-func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
- declarations := make([]*genai.FunctionDeclaration, len(tools))
+func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ // Check if error is a rate limit error
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
- for i, tool := range tools {
- info := tool.Info()
- declarations[i] = &genai.FunctionDeclaration{
- Name: info.Name,
- Description: info.Description,
- Parameters: &genai.Schema{
- Type: genai.TypeObject,
- Properties: convertSchemaProperties(info.Parameters),
- Required: info.Required,
- },
+ // Gemini doesn't have a standard error type we can check against
+ // So we'll check the error message for rate limit indicators
+ if errors.Is(err, io.EOF) {
+ return false, 0, err
+ }
+
+ errMsg := err.Error()
+ isRateLimit := false
+
+ // Check for common rate limit error messages
+ if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
+ isRateLimit = true
+ }
+
+ if !isRateLimit {
+ return false, 0, err
+ }
+
+ // Calculate backoff with jitter
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs := backoffMs + jitterMs
+
+ return true, int64(retryMs), nil
+}
+
+func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
+ var toolCalls []message.ToolCall
+
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+ for _, part := range resp.Candidates[0].Content.Parts {
+ if funcCall, ok := part.(genai.FunctionCall); ok {
+ id := "call_" + uuid.New().String()
+ args, _ := json.Marshal(funcCall.Args)
+ toolCalls = append(toolCalls, message.ToolCall{
+ ID: id,
+ Name: funcCall.Name,
+ Input: string(args),
+ Type: "function",
+ })
+ }
}
}
- return declarations
+ return toolCalls
+}
+
+func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
+ if resp == nil || resp.UsageMetadata == nil {
+ return TokenUsage{}
+ }
+
+ return TokenUsage{
+ InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
+ OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
+ CacheCreationTokens: 0, // Not directly provided by Gemini
+ CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
+ }
+}
+
+func WithGeminiDisableCache() GeminiOption {
+ return func(options *geminiOptions) {
+ options.disableCache = true
+ }
+}
+
+// Helper functions
+func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
+ var result map[string]interface{}
+ err := json.Unmarshal([]byte(jsonStr), &result)
+ return result, err
}
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
@@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type {
}
}
-func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
- var result map[string]interface{}
- err := json.Unmarshal([]byte(jsonStr), &result)
- return result, err
+func contains(s string, substrs ...string) bool {
+ for _, substr := range substrs {
+ if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
+ return true
+ }
+ }
+ return false
}
+
@@ -2,89 +2,65 @@ package provider
import (
"context"
+ "encoding/json"
"errors"
+ "fmt"
+ "io"
+ "time"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
)
-type openaiProvider struct {
- client openai.Client
- model models.Model
- maxTokens int64
- baseURL string
- apiKey string
- systemMessage string
+type openaiOptions struct {
+ baseURL string
+ disableCache bool
}
-type OpenAIOption func(*openaiProvider)
+type OpenAIOption func(*openaiOptions)
-func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
- provider := &openaiProvider{
- maxTokens: 5000,
- }
-
- for _, opt := range opts {
- opt(provider)
- }
-
- clientOpts := []option.RequestOption{
- option.WithAPIKey(provider.apiKey),
- }
- if provider.baseURL != "" {
- clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
- }
-
- provider.client = openai.NewClient(clientOpts...)
- if provider.systemMessage == "" {
- return nil, errors.New("system message is required")
- }
-
- return provider, nil
+type openaiClient struct {
+ providerOptions providerClientOptions
+ options openaiOptions
+ client openai.Client
}
-func WithOpenAISystemMessage(message string) OpenAIOption {
- return func(p *openaiProvider) {
- p.systemMessage = message
- }
-}
+type OpenAIClient ProviderClient
-func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
- return func(p *openaiProvider) {
- p.maxTokens = maxTokens
+func newOpenAIClient(opts providerClientOptions) OpenAIClient {
+ openaiOpts := openaiOptions{}
+ for _, o := range opts.openaiOptions {
+ o(&openaiOpts)
}
-}
-func WithOpenAIModel(model models.Model) OpenAIOption {
- return func(p *openaiProvider) {
- p.model = model
+ openaiClientOptions := []option.RequestOption{}
+ if opts.apiKey != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
-}
-
-func WithOpenAIBaseURL(baseURL string) OpenAIOption {
- return func(p *openaiProvider) {
- p.baseURL = baseURL
+ if openaiOpts.baseURL != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
}
-}
-func WithOpenAIKey(apiKey string) OpenAIOption {
- return func(p *openaiProvider) {
- p.apiKey = apiKey
+ client := openai.NewClient(openaiClientOptions...)
+ return &openaiClient{
+ providerOptions: opts,
+ options: openaiOpts,
+ client: client,
}
}
-func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
- var chatMessages []openai.ChatCompletionMessageParamUnion
-
- chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))
+func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+ // Add system message first
+ openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
for _, msg := range messages {
switch msg.Role {
case message.User:
- chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
+ openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
@@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
}
}
- chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
+ openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
OfAssistant: &assistantMsg,
})
case message.Tool:
for _, result := range msg.ToolResults() {
- chatMessages = append(chatMessages,
+ openaiMessages = append(openaiMessages,
openai.ToolMessage(result.Content, result.ToolCallID),
)
}
}
}
- return chatMessages
+ return
}
-func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
+func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
for i, tool := range tools {
@@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C
return openaiTools
}
-func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
- cachedTokens := int64(0)
-
- cachedTokens = usage.PromptTokensDetails.CachedTokens
- inputTokens := usage.PromptTokens - cachedTokens
-
- return TokenUsage{
- InputTokens: inputTokens,
- OutputTokens: usage.CompletionTokens,
- CacheCreationTokens: 0, // OpenAI doesn't provide this directly
- CacheReadTokens: cachedTokens,
+func (o *openaiClient) finishReason(reason string) message.FinishReason {
+ switch reason {
+ case "stop":
+ return message.FinishReasonEndTurn
+ case "length":
+ return message.FinishReasonMaxTokens
+ case "tool_calls":
+ return message.FinishReasonToolUse
+ default:
+ return message.FinishReasonUnknown
}
}
-func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = cleanupMessages(messages)
- chatMessages := p.convertToOpenAIMessages(messages)
- openaiTools := p.convertToOpenAITools(tools)
-
- params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(p.model.APIModel),
- Messages: chatMessages,
- MaxTokens: openai.Int(p.maxTokens),
- Tools: openaiTools,
- }
-
- response, err := p.client.Chat.Completions.New(ctx, params)
- if err != nil {
- return nil, err
+func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+ return openai.ChatCompletionNewParams{
+ Model: openai.ChatModel(o.providerOptions.model.APIModel),
+ Messages: messages,
+ MaxTokens: openai.Int(o.providerOptions.maxTokens),
+ Tools: tools,
}
+}
- content := ""
- if response.Choices[0].Message.Content != "" {
- content = response.Choices[0].Message.Content
+func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
+ params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(params)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
}
-
- var toolCalls []message.ToolCall
- if len(response.Choices[0].Message.ToolCalls) > 0 {
- toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
- for i, call := range response.Choices[0].Message.ToolCalls {
- toolCalls[i] = message.ToolCall{
- ID: call.ID,
- Name: call.Function.Name,
- Input: call.Function.Arguments,
- Type: "function",
+ attempts := 0
+ for {
+ attempts++
+ openaiResponse, err := o.client.Chat.Completions.New(
+ ctx,
+ params,
+ )
		// If there is an error, check whether the call can be retried
+ if err != nil {
+ retry, after, retryErr := o.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
}
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ return nil, retryErr
}
- }
- tokenUsage := p.extractTokenUsage(response.Usage)
+ content := ""
+ if openaiResponse.Choices[0].Message.Content != "" {
+ content = openaiResponse.Choices[0].Message.Content
+ }
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- }, nil
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: o.toolCalls(*openaiResponse),
+ Usage: o.usage(*openaiResponse),
+ FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
+ }, nil
+ }
}
-func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- messages = cleanupMessages(messages)
- chatMessages := p.convertToOpenAIMessages(messages)
- openaiTools := p.convertToOpenAITools(tools)
-
- params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(p.model.APIModel),
- Messages: chatMessages,
- MaxTokens: openai.Int(p.maxTokens),
- Tools: openaiTools,
- StreamOptions: openai.ChatCompletionStreamOptionsParam{
- IncludeUsage: openai.Bool(true),
- },
+func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
+ params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+ IncludeUsage: openai.Bool(true),
}
- stream := p.client.Chat.Completions.NewStreaming(ctx, params)
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(params)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
eventChan := make(chan ProviderEvent)
- toolCalls := make([]message.ToolCall, 0)
go func() {
- defer close(eventChan)
-
- acc := openai.ChatCompletionAccumulator{}
- currentContent := ""
-
- for stream.Next() {
- chunk := stream.Current()
- acc.AddChunk(chunk)
-
- if tool, ok := acc.JustFinishedToolCall(); ok {
- toolCalls = append(toolCalls, message.ToolCall{
- ID: tool.Id,
- Name: tool.Name,
- Input: tool.Arguments,
- Type: "function",
- })
- }
+ for {
+ attempts++
+ openaiStream := o.client.Chat.Completions.NewStreaming(
+ ctx,
+ params,
+ )
+
+ acc := openai.ChatCompletionAccumulator{}
+ currentContent := ""
+ toolCalls := make([]message.ToolCall, 0)
+
+ for openaiStream.Next() {
+ chunk := openaiStream.Current()
+ acc.AddChunk(chunk)
+
+ if tool, ok := acc.JustFinishedToolCall(); ok {
+ toolCalls = append(toolCalls, message.ToolCall{
+ ID: tool.Id,
+ Name: tool.Name,
+ Input: tool.Arguments,
+ Type: "function",
+ })
+ }
- for _, choice := range chunk.Choices {
- if choice.Delta.Content != "" {
- eventChan <- ProviderEvent{
- Type: EventContentDelta,
- Content: choice.Delta.Content,
+ for _, choice := range chunk.Choices {
+ if choice.Delta.Content != "" {
+ eventChan <- ProviderEvent{
+ Type: EventContentDelta,
+ Content: choice.Delta.Content,
+ }
+ currentContent += choice.Delta.Content
}
- currentContent += choice.Delta.Content
}
}
- }
- if err := stream.Err(); err != nil {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: err,
+ err := openaiStream.Err()
+ if err == nil || errors.Is(err, io.EOF) {
+ // Stream completed successfully
+ eventChan <- ProviderEvent{
+ Type: EventComplete,
+ Response: &ProviderResponse{
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: o.usage(acc.ChatCompletion),
+ FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)),
+ },
+ }
+ close(eventChan)
+ return
}
+
+			// If there is an error, check whether the call can be retried
+ retry, after, retryErr := o.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ close(eventChan)
+ return
+ }
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+					// context cancelled; NOTE(review): ctx.Err() is non-nil here, so the "== nil" guard below can never send the error event — confirm intent
+ if ctx.Err() == nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+ }
+ close(eventChan)
+ return
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ close(eventChan)
return
}
+ }()
- tokenUsage := p.extractTokenUsage(acc.Usage)
+ return eventChan
+}
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- },
+func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ var apierr *openai.Error
+ if !errors.As(err, &apierr) {
+ return false, 0, err
+ }
+
+ if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
+ return false, 0, err
+ }
+
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
+
+ retryMs := 0
+ retryAfterValues := apierr.Response.Header.Values("Retry-After")
+
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs = backoffMs + jitterMs
+ if len(retryAfterValues) > 0 {
+ if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+ retryMs = retryMs * 1000
}
- }()
+ }
+ return true, int64(retryMs), nil
+}
- return eventChan, nil
+func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
+ var toolCalls []message.ToolCall
+
+ if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
+ for _, call := range completion.Choices[0].Message.ToolCalls {
+ toolCall := message.ToolCall{
+ ID: call.ID,
+ Name: call.Function.Name,
+ Input: call.Function.Arguments,
+ Type: "function",
+ }
+ toolCalls = append(toolCalls, toolCall)
+ }
+ }
+
+ return toolCalls
}
+
+func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
+ cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
+ inputTokens := completion.Usage.PromptTokens - cachedTokens
+
+ return TokenUsage{
+ InputTokens: inputTokens,
+ OutputTokens: completion.Usage.CompletionTokens,
+ CacheCreationTokens: 0, // OpenAI doesn't provide this directly
+ CacheReadTokens: cachedTokens,
+ }
+}
+
+func WithOpenAIBaseURL(baseURL string) OpenAIOption {
+ return func(options *openaiOptions) {
+ options.baseURL = baseURL
+ }
+}
+
+func WithOpenAIDisableCache() OpenAIOption {
+ return func(options *openaiOptions) {
+ options.disableCache = true
+ }
+}
+
@@ -2,14 +2,17 @@ package provider
import (
"context"
+ "fmt"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
-// EventType represents the type of streaming event
type EventType string
+const maxRetries = 8
+
const (
EventContentStart EventType = "content_start"
EventContentDelta EventType = "content_delta"
@@ -18,7 +21,6 @@ const (
EventComplete EventType = "complete"
EventError EventType = "error"
EventWarning EventType = "warning"
- EventInfo EventType = "info"
)
type TokenUsage struct {
@@ -32,61 +34,152 @@ type ProviderResponse struct {
Content string
ToolCalls []message.ToolCall
Usage TokenUsage
- FinishReason string
+ FinishReason message.FinishReason
}
type ProviderEvent struct {
- Type EventType
+ Type EventType
+
Content string
Thinking string
- ToolCall *message.ToolCall
- Error error
Response *ProviderResponse
- // Used for giving users info on e.x retry
- Info string
+ Error error
}
-
type Provider interface {
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
- StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
+ StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+
+ Model() models.Model
+}
+
+type providerClientOptions struct {
+ apiKey string
+ model models.Model
+ maxTokens int64
+ systemMessage string
+
+ anthropicOptions []AnthropicOption
+ openaiOptions []OpenAIOption
+ geminiOptions []GeminiOption
+ bedrockOptions []BedrockOption
+}
+
+type ProviderClientOption func(*providerClientOptions)
+
+type ProviderClient interface {
+ send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
+ stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+}
+
+type baseProvider[C ProviderClient] struct {
+ options providerClientOptions
+ client C
+}
+
+func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
+ clientOptions := providerClientOptions{}
+ for _, o := range opts {
+ o(&clientOptions)
+ }
+ switch providerName {
+ case models.ProviderAnthropic:
+ return &baseProvider[AnthropicClient]{
+ options: clientOptions,
+ client: newAnthropicClient(clientOptions),
+ }, nil
+ case models.ProviderOpenAI:
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
+ case models.ProviderGemini:
+ return &baseProvider[GeminiClient]{
+ options: clientOptions,
+ client: newGeminiClient(clientOptions),
+ }, nil
+ case models.ProviderBedrock:
+ return &baseProvider[BedrockClient]{
+ options: clientOptions,
+ client: newBedrockClient(clientOptions),
+ }, nil
+ case models.ProviderMock:
+ // TODO: implement mock client for test
+ panic("not implemented")
+ }
+ return nil, fmt.Errorf("provider not supported: %s", providerName)
}
-func cleanupMessages(messages []message.Message) []message.Message {
- // First pass: filter out canceled messages
- var cleanedMessages []message.Message
+func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
for _, msg := range messages {
- if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
- // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
- // cancelled
- cleanedMessages = append(cleanedMessages, msg)
+ // The message has no content
+ if len(msg.Parts) == 0 {
+ continue
}
+ cleaned = append(cleaned, msg)
}
+ return
+}
- // Second pass: filter out tool messages without a corresponding tool call
- var result []message.Message
- toolMessageIDs := make(map[string]bool)
+func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ messages = p.cleanMessages(messages)
+ return p.client.send(ctx, messages, tools)
+}
- for _, msg := range cleanedMessages {
- if msg.Role == message.Assistant {
- for _, toolCall := range msg.ToolCalls() {
- toolMessageIDs[toolCall.ID] = true // Mark as referenced
- }
- }
+func (p *baseProvider[C]) Model() models.Model {
+ return p.options.model
+}
+
+func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ messages = p.cleanMessages(messages)
+ return p.client.stream(ctx, messages, tools)
+}
+
+func WithAPIKey(apiKey string) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.apiKey = apiKey
}
+}
- // Keep only messages that aren't unreferenced tool messages
- for _, msg := range cleanedMessages {
- if msg.Role == message.Tool {
- for _, toolCall := range msg.ToolResults() {
- if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
- result = append(result, msg)
- }
- }
- } else {
- result = append(result, msg)
- }
+func WithModel(model models.Model) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.model = model
+ }
+}
+
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.maxTokens = maxTokens
+ }
+}
+
+func WithSystemMessage(systemMessage string) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.systemMessage = systemMessage
+ }
+}
+
+func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.anthropicOptions = anthropicOptions
+ }
+}
+
+func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.openaiOptions = openaiOptions
+ }
+}
+
+func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.geminiOptions = geminiOptions
+ }
+}
+
+func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.bedrockOptions = bedrockOptions
}
- return result
}
@@ -23,7 +23,8 @@ type BashPermissionsParams struct {
}
type BashResponseMetadata struct {
- Took int64 `json:"took"`
+ StartTime int64 `json:"start_time"`
+ EndTime int64 `json:"end_time"`
}
type bashTool struct {
permissions permission.Service
@@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
if err != nil {
return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
}
- took := time.Since(startTime).Milliseconds()
stdout = truncateOutput(stdout)
stderr = truncateOutput(stderr)
@@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
}
metadata := BashResponseMetadata{
- Took: took,
+ StartTime: startTime.UnixMilli(),
+ EndTime: time.Now().UnixMilli(),
}
if stdout == "" {
return WithResponseMetadata(NewTextResponse("no output"), metadata), nil
@@ -8,8 +8,6 @@ import (
"testing"
"time"
- "github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) {
})
}
}
-
-// Mock permission service for testing
-type mockPermissionService struct {
- *pubsub.Broker[permission.PermissionRequest]
- allow bool
-}
-
-func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
- // Not needed for tests
-}
-
-func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
- // Not needed for tests
-}
-
-func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
- // Not needed for tests
-}
-
-func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
- return m.allow
-}
-
-func newMockPermissionService(allow bool) permission.Service {
- return &mockPermissionService{
- Broker: pubsub.NewBroker[permission.PermissionRequest](),
- allow: allow,
- }
-}
@@ -11,6 +11,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/diff"
+ "github.com/kujtimiihoxha/termai/internal/history"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -35,6 +36,7 @@ type EditResponseMetadata struct {
type editTool struct {
lspClients map[string]*lsp.Client
permissions permission.Service
+ files history.Service
}
const (
@@ -88,10 +90,11 @@ When making edits:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
)
-func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool {
+func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
return &editTool{
lspClients: lspClients,
permissions: permissions,
+ files: files,
}
}
@@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
if err != nil {
return response, nil
}
+ if response.IsError {
+ // Return early if there was an error during content replacement
+ // This prevents unnecessary LSP diagnostics processing
+ return response, nil
+ }
waitForLspDiagnostics(ctx, params.FilePath, e.lspClients)
text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
@@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
+	// The file is newly created, so it can't already exist in history; create a new file history record
+ _, err = e.files.Create(ctx, sessionID, filePath, "")
+ if err != nil {
+		// History bookkeeping failed; abort the operation and surface the error
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+
+ // Add the new content to the file history
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
+ if err != nil {
+ // Log error but don't fail the operation
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
@@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
if err != nil {
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
+
+ // Check if file exists in history
+ file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
+ if err != nil {
+ _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+			// History bookkeeping failed; abort the operation and surface the error
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+ }
+ if file.Content != oldContent {
+		// The user manually changed the content; store an intermediate version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+ }
+ // Store the new version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
@@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
+ if oldContent == newContent {
+ return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
+ }
sessionID, messageID := GetContextValues(ctx)
if sessionID == "" || messageID == "" {
@@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
Description: fmt.Sprintf("Replace content in file %s", filePath),
Params: EditPermissionsParams{
FilePath: filePath,
-
- Diff: diff,
+ Diff: diff,
},
},
)
@@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
+ // Check if file exists in history
+ file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
+ if err != nil {
+ _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+			// History bookkeeping failed; abort the operation and surface the error
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+ }
+ if file.Content != oldContent {
+		// The user manually changed the content; store an intermediate version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+ }
+ // Store the new version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
@@ -14,7 +14,7 @@ import (
)
func TestEditTool_Info(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
info := tool.Info()
assert.Equal(t, EditToolName, info.Name)
@@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) {
defer os.RemoveAll(tempDir)
t.Run("creates a new file successfully", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
@@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("creates file with nested directories", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
@@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("fails to create file that already exists", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
@@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("fails to create file when path is a directory", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
@@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("replaces content successfully", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "replace_content.txt")
@@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("deletes content successfully", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "delete_content.txt")
@@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles invalid parameters", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
call := ToolCall{
Name: EditToolName,
@@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles missing file_path", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
params := EditParams{
FilePath: "",
@@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles file not found", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "non_existent_file.txt")
params := EditParams{
@@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles old_string not found in file", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "content_not_found.txt")
@@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles multiple occurrences of old_string", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file with duplicate content
filePath := filepath.Join(tempDir, "duplicate_content.txt")
@@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles file modified since last read", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
@@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles file not read before editing", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "not_read_file.txt")
@@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles permission denied", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "permission_denied.txt")
@@ -3,8 +3,6 @@ package tools
import (
"sync"
"time"
-
- "github.com/kujtimiihoxha/termai/internal/config"
)
// File record to track when files were read/written
@@ -19,14 +17,6 @@ var (
fileRecordMutex sync.RWMutex
)
-func removeWorkingDirectoryPrefix(path string) string {
- wd := config.WorkingDirectory()
- if len(path) > len(wd) && path[:len(wd)] == wd {
- return path[len(wd)+1:]
- }
- return path
-}
-
func recordFileRead(path string) {
fileRecordMutex.Lock()
defer fileRecordMutex.Unlock()
@@ -63,7 +63,7 @@ type GlobParams struct {
Path string `json:"path"`
}
-type GlobMetadata struct {
+type GlobResponseMetadata struct {
NumberOfFiles int `json:"number_of_files"`
Truncated bool `json:"truncated"`
}
@@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return WithResponseMetadata(
NewTextResponse(output),
- GlobMetadata{
+ GlobResponseMetadata{
NumberOfFiles: len(files),
Truncated: truncated,
},
@@ -27,7 +27,7 @@ type grepMatch struct {
modTime time.Time
}
-type GrepMetadata struct {
+type GrepResponseMetadata struct {
NumberOfMatches int `json:"number_of_matches"`
Truncated bool `json:"truncated"`
}
@@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return WithResponseMetadata(
NewTextResponse(output),
- GrepMetadata{
+ GrepResponseMetadata{
NumberOfMatches: len(matches),
Truncated: truncated,
},
@@ -23,7 +23,7 @@ type TreeNode struct {
Children []*TreeNode `json:"children,omitempty"`
}
-type LSMetadata struct {
+type LSResponseMetadata struct {
NumberOfFiles int `json:"number_of_files"`
Truncated bool `json:"truncated"`
}
@@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
return WithResponseMetadata(
NewTextResponse(output),
- LSMetadata{
+ LSResponseMetadata{
NumberOfFiles: len(files),
Truncated: truncated,
},
@@ -0,0 +1,246 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/kujtimiihoxha/termai/internal/history"
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/kujtimiihoxha/termai/internal/pubsub"
+)
+
+// mockPermissionService is a test double for permission.Service that
+// answers every Request with a fixed allow/deny verdict and treats all
+// other interface methods as no-ops.
+type mockPermissionService struct {
+	*pubsub.Broker[permission.PermissionRequest]
+	allow bool // fixed verdict returned by Request
+}
+
+// GrantPersistant implements permission.Service; no-op for tests.
+// (Spelling follows the interface declared in the permission package.)
+func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
+	// Not needed for tests
+}
+
+// Grant implements permission.Service; no-op for tests.
+func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
+	// Not needed for tests
+}
+
+// Deny implements permission.Service; no-op for tests.
+func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
+	// Not needed for tests
+}
+
+// Request always returns the verdict configured at construction time.
+func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
+	return m.allow
+}
+
+// newMockPermissionService builds a mock permission.Service whose
+// Request method always returns allow.
+func newMockPermissionService(allow bool) permission.Service {
+	return &mockPermissionService{
+		Broker: pubsub.NewBroker[permission.PermissionRequest](),
+		allow: allow,
+	}
+}
+
+// mockFileHistoryService is an in-memory test double for
+// history.Service. Records live in a map keyed by generated ID; the
+// clock is injectable via timeNow so tests can control timestamps.
+type mockFileHistoryService struct {
+	*pubsub.Broker[history.File]
+	files map[string]history.File // ID -> File
+	timeNow func() int64 // clock hook; defaults to time.Now().Unix
+}
+
+// Create implements history.Service.
+// It stores the first recorded version of path for sessionID under
+// history.InitialVersion.
+func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) {
+	return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion)
+}
+
+// CreateVersion implements history.Service.
+// It derives the next version label from the newest existing version of
+// path: InitialVersion -> "v1", "vN" -> "v(N+1)", anything unparsable
+// falls back to "v<CreatedAt>". With no prior versions it delegates to
+// Create.
+// NOTE(review): the scan matches path across ALL sessions, not just
+// sessionID — confirm this mirrors the real service. The timestamp
+// fallback can collide for two versions created in the same second;
+// acceptable for a test mock.
+func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) {
+	var files []history.File
+	for _, file := range m.files {
+		if file.Path == path {
+			files = append(files, file)
+		}
+	}
+
+	if len(files) == 0 {
+		// No previous versions, create initial
+		return m.Create(ctx, sessionID, path, content)
+	}
+
+	// Sort files by CreatedAt in descending order
+	sort.Slice(files, func(i, j int) bool {
+		return files[i].CreatedAt > files[j].CreatedAt
+	})
+
+	// Get the latest version
+	latestFile := files[0]
+	latestVersion := latestFile.Version
+
+	// Generate the next version
+	var nextVersion string
+	if latestVersion == history.InitialVersion {
+		nextVersion = "v1"
+	} else if strings.HasPrefix(latestVersion, "v") {
+		versionNum, err := strconv.Atoi(latestVersion[1:])
+		if err != nil {
+			// If we can't parse the version, just use a timestamp-based version
+			nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
+		} else {
+			nextVersion = fmt.Sprintf("v%d", versionNum+1)
+		}
+	} else {
+		// If the version format is unexpected, use a timestamp-based version
+		nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
+	}
+
+	return m.createWithVersion(ctx, sessionID, path, content, nextVersion)
+}
+
+// createWithVersion stores a fresh history.File record carrying the
+// given version label, stamps it with the injected clock, and publishes
+// a CreatedEvent. The context is unused by the in-memory mock.
+func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) {
+	ts := m.timeNow()
+	rec := history.File{
+		ID:        uuid.New().String(),
+		SessionID: sessionID,
+		Path:      path,
+		Content:   content,
+		Version:   version,
+	}
+	rec.CreatedAt = ts
+	rec.UpdatedAt = ts
+
+	m.files[rec.ID] = rec
+	m.Publish(pubsub.CreatedEvent, rec)
+	return rec, nil
+}
+
+// Delete implements history.Service. It removes the record with the
+// given ID and publishes a DeletedEvent; an unknown ID is an error.
+func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error {
+	rec, exists := m.files[id]
+	if !exists {
+		return fmt.Errorf("file not found: %s", id)
+	}
+	delete(m.files, id)
+	m.Publish(pubsub.DeletedEvent, rec)
+	return nil
+}
+
+// DeleteSessionFiles implements history.Service. It removes every
+// version recorded for sessionID, stopping at the first failure.
+func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error {
+	sessionFiles, err := m.ListBySession(ctx, sessionID)
+	if err != nil {
+		return err
+	}
+	for _, f := range sessionFiles {
+		if err := m.Delete(ctx, f.ID); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Get implements history.Service. It looks up a single version record
+// by its generated ID.
+func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) {
+	file, ok := m.files[id]
+	if !ok {
+		return history.File{}, fmt.Errorf("file not found: %s", id)
+	}
+	return file, nil
+}
+
+// GetByPathAndSession implements history.Service. It returns the most
+// recently created version of path within sessionID, or an error when
+// no version exists.
+func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) {
+	var newest *history.File
+
+	for id := range m.files {
+		f := m.files[id]
+		if f.Path != path || f.SessionID != sessionID {
+			continue
+		}
+		if newest == nil || f.CreatedAt > newest.CreatedAt {
+			newest = &f
+		}
+	}
+
+	if newest == nil {
+		return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID)
+	}
+	return *newest, nil
+}
+
+// ListBySession implements history.Service. It returns every version
+// recorded for sessionID, newest first (sorted by CreatedAt descending).
+func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) {
+	var files []history.File
+	for _, file := range m.files {
+		if file.SessionID == sessionID {
+			files = append(files, file)
+		}
+	}
+
+	// Sort by CreatedAt in descending order
+	sort.Slice(files, func(i, j int) bool {
+		return files[i].CreatedAt > files[j].CreatedAt
+	})
+
+	return files, nil
+}
+
+// ListLatestSessionFiles implements history.Service. It returns only
+// the newest version of each distinct path within sessionID, ordered by
+// CreatedAt descending.
+func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) {
+	// Map to track the latest file for each path
+	latestFiles := make(map[string]history.File)
+
+	for _, file := range m.files {
+		if file.SessionID == sessionID {
+			existing, ok := latestFiles[file.Path]
+			if !ok || file.CreatedAt > existing.CreatedAt {
+				latestFiles[file.Path] = file
+			}
+		}
+	}
+
+	// Convert map to slice
+	var result []history.File
+	for _, file := range latestFiles {
+		result = append(result, file)
+	}
+
+	// Sort by CreatedAt in descending order
+	sort.Slice(result, func(i, j int) bool {
+		return result[i].CreatedAt > result[j].CreatedAt
+	})
+
+	return result, nil
+}
+
+// Subscribe implements history.Service by forwarding to the embedded
+// pubsub broker's subscription stream.
+func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] {
+	return m.Broker.Subscribe(ctx)
+}
+
+// Update implements history.Service. It replaces an existing record
+// (matched by file.ID), refreshes UpdatedAt from the injected clock,
+// and publishes an UpdatedEvent. An unknown ID is an error.
+func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) {
+	_, ok := m.files[file.ID]
+	if !ok {
+		return history.File{}, fmt.Errorf("file not found: %s", file.ID)
+	}
+
+	// CreatedAt is preserved from the caller's copy; only UpdatedAt is stamped.
+	file.UpdatedAt = m.timeNow()
+	m.files[file.ID] = file
+	m.Publish(pubsub.UpdatedEvent, file)
+	return file, nil
+}
+
+// newMockFileHistoryService builds an empty in-memory history.Service
+// backed by a map, using the real wall clock for timestamps.
+func newMockFileHistoryService() history.Service {
+	return &mockFileHistoryService{
+		Broker: pubsub.NewBroker[history.File](),
+		files: make(map[string]history.File),
+		timeNow: func() int64 { return time.Now().Unix() },
+	}
+}
@@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell {
commandQueue: make(chan *commandExecution, 10),
}
- go shell.processCommands()
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
+ shell.isAlive = false
+ close(shell.commandQueue)
+ }
+ }()
+ shell.processCommands()
+ }()
go func() {
err := cmd.Wait()
if err != nil {
+			// NOTE(review): cmd.Wait's exit error is currently swallowed; log it here.
}
shell.isAlive = false
close(shell.commandQueue)
@@ -18,7 +18,7 @@ type SourcegraphParams struct {
Timeout int `json:"timeout,omitempty"`
}
-type SourcegraphMetadata struct {
+type SourcegraphResponseMetadata struct {
NumberOfMatches int `json:"number_of_matches"`
Truncated bool `json:"truncated"`
}
@@ -14,12 +14,17 @@ type ToolInfo struct {
type toolResponseType string
+type (
+ sessionIDContextKey string
+ messageIDContextKey string
+)
+
const (
ToolResponseTypeText toolResponseType = "text"
ToolResponseTypeImage toolResponseType = "image"
- SessionIDContextKey = "session_id"
- MessageIDContextKey = "message_id"
+ SessionIDContextKey sessionIDContextKey = "session_id"
+ MessageIDContextKey messageIDContextKey = "message_id"
)
type ToolResponse struct {
@@ -10,6 +10,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/diff"
+ "github.com/kujtimiihoxha/termai/internal/history"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -27,6 +28,7 @@ type WritePermissionsParams struct {
type writeTool struct {
lspClients map[string]*lsp.Client
permissions permission.Service
+ files history.Service
}
type WriteResponseMetadata struct {
@@ -67,10 +69,11 @@ TIPS:
- Always include descriptive comments when making changes to existing code`
)
-func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool {
+func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
return &writeTool{
lspClients: lspClients,
permissions: permissions,
+ files: files,
}
}
@@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
}
+ // Check if file exists in history
+ file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
+ if err != nil {
+ _, err = w.files.Create(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ // Log error but don't fail the operation
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+ }
+ if file.Content != oldContent {
+		// The user manually changed the content on disk; store an intermediate version first.
+ _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+ }
+ // Store the new version
+ _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
waitForLspDiagnostics(ctx, filePath, w.lspClients)
@@ -14,7 +14,7 @@ import (
)
func TestWriteTool_Info(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
info := tool.Info()
assert.Equal(t, WriteToolName, info.Name)
@@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) {
defer os.RemoveAll(tempDir)
t.Run("creates a new file successfully", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
@@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("creates file with nested directories", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
@@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("updates existing file", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
@@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles invalid parameters", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
call := ToolCall{
Name: WriteToolName,
@@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles missing file_path", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
params := WriteParams{
FilePath: "",
@@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles missing content", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
params := WriteParams{
FilePath: filepath.Join(tempDir, "file.txt"),
@@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles writing to a directory path", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
@@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles permission denied", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "permission_denied.txt")
params := WriteParams{
@@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("detects file modified since last read", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
@@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("skips writing when content is identical", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "identical_content.txt")
@@ -1,6 +1,12 @@
package logging
-import "log/slog"
+import (
+ "fmt"
+ "log/slog"
+ "os"
+ "runtime/debug"
+ "time"
+)
func Info(msg string, args ...any) {
slog.Info(msg, args...)
@@ -37,3 +43,36 @@ func ErrorPersist(msg string, args ...any) {
args = append(args, persistKeyArg, true)
slog.Error(msg, args...)
}
+
+// RecoverPanic is a common function to handle panics gracefully.
+// It logs the error, creates a panic log file with stack trace,
+// and executes an optional cleanup function before returning.
+// It must be invoked via defer (recover only works inside a deferred
+// function). name identifies the component for log messages and the
+// panic-file name; cleanup may be nil. The panic log is written to the
+// process's current working directory.
+func RecoverPanic(name string, cleanup func()) {
+	if r := recover(); r != nil {
+		// Log the panic
+		ErrorPersist(fmt.Sprintf("Panic in %s: %v", name, r))
+
+		// Create a timestamped panic log file
+		timestamp := time.Now().Format("20060102-150405")
+		filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp)
+
+		file, err := os.Create(filename)
+		if err != nil {
+			ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
+		} else {
+			defer file.Close()
+
+			// Write panic information and stack trace
+			fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r)
+			fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339))
+			fmt.Fprintf(file, "Stack Trace:\n%s\n", debug.Stack())
+
+			InfoPersist(fmt.Sprintf("Panic details written to %s", filename))
+		}
+
+		// Execute cleanup function if provided
+		if cleanup != nil {
+			cleanup()
+		}
+	}
+}
@@ -97,7 +97,12 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er
}()
// Start message handling loop
- go client.handleMessages()
+ go func() {
+ defer logging.RecoverPanic("LSP-message-handler", func() {
+ logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired")
+ })
+ client.handleMessages()
+ }()
return client, nil
}
@@ -374,7 +379,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
},
}
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Closing file", "file", filepath)
}
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
@@ -413,12 +418,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
// Then close them all
for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath)
- if err != nil && cnf.Debug {
+ if err != nil && cnf.DebugLSP {
logging.Warn("Error closing file", "file", filePath, "error", err)
}
}
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Closed all files", "files", filesToClose)
}
}
@@ -88,7 +88,7 @@ func HandleServerMessage(params json.RawMessage) {
Message string `json:"message"`
}
if err := json.Unmarshal(params, &msg); err == nil {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
}
}
@@ -20,7 +20,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
}
cnf := config.Get()
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
}
@@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
line = strings.TrimSpace(line)
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Received header", "line", line)
}
@@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
}
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Content-Length", "length", contentLength)
}
@@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
return nil, fmt.Errorf("failed to read content: %w", err)
}
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Received content", "content", string(content))
}
@@ -95,7 +95,7 @@ func (c *Client) handleMessages() {
for {
msg, err := ReadMessage(c.stdout)
if err != nil {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Error("Error reading message", "error", err)
}
return
@@ -103,7 +103,7 @@ func (c *Client) handleMessages() {
// Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
}
@@ -157,11 +157,11 @@ func (c *Client) handleMessages() {
c.notificationMu.RUnlock()
if ok {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Handling notification", "method", msg.Method)
}
go handler(msg.Params)
- } else if cnf.Debug {
+ } else if cnf.DebugLSP {
logging.Debug("No handler for notification", "method", msg.Method)
}
continue
@@ -174,12 +174,12 @@ func (c *Client) handleMessages() {
c.handlersMu.RUnlock()
if ok {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Received response for request", "id", msg.ID)
}
ch <- msg
close(ch)
- } else if cnf.Debug {
+ } else if cnf.DebugLSP {
logging.Debug("No handler for response", "id", msg.ID)
}
}
@@ -191,7 +191,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
cnf := config.Get()
id := c.nextID.Add(1)
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Making call", "method", method, "id", id)
}
@@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
return fmt.Errorf("failed to send request: %w", err)
}
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Request sent", "method", method, "id", id)
}
// Wait for response
resp := <-ch
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Received response", "id", id)
}
@@ -250,7 +250,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
// Notify sends a notification (a request without an ID that doesn't expect a response)
func (c *Client) Notify(ctx context.Context, method string, params any) error {
cnf := config.Get()
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Sending notification", "method", method)
}
@@ -50,7 +50,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
w.registrations = append(w.registrations, watchers...)
// Print detailed registration information for debugging
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Adding file watcher registrations",
"id", id,
"watchers", len(watchers),
@@ -116,7 +116,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// Skip directories that should be excluded
if d.IsDir() {
if path != w.workspacePath && shouldExcludeDir(path) {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -136,7 +136,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
})
elapsedTime := time.Since(startTime)
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Workspace scan complete",
"filesOpened", filesOpened,
"elapsedTime", elapsedTime.Seconds(),
@@ -144,7 +144,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
)
}
- if err != nil && cnf.Debug {
+ if err != nil && cnf.DebugLSP {
logging.Debug("Error scanning workspace for files to open", "error", err)
}
}()
@@ -175,7 +175,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Skip excluded directories (except workspace root)
if d.IsDir() && path != workspacePath {
if shouldExcludeDir(path) {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -228,7 +228,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
}
// Debug logging
- if cnf.Debug {
+ if cnf.DebugLSP {
matched, kind := w.isPathWatched(event.Name)
logging.Debug("File event",
"path", event.Name,
@@ -491,7 +491,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
// notifyFileEvent sends a didChangeWatchedFiles notification for a file event
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
cnf := config.Get()
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Notifying file event",
"uri", uri,
"changeType", changeType,
@@ -615,7 +615,7 @@ func shouldExcludeFile(filePath string) bool {
// Skip large files
if info.Size() > maxFileSize {
- if cnf.Debug {
+ if cnf.DebugLSP {
logging.Debug("Skipping large file",
"path", filePath,
"size", info.Size(),
@@ -648,7 +648,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check if this path should be watched according to server registrations
if watched, _ := w.isPathWatched(path); watched {
// Don't need to check if it's already open - the client.OpenFile handles that
- if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug {
+ if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
logging.Error("Error opening file", "path", path, "error", err)
}
}
@@ -2,6 +2,7 @@ package message
import (
"encoding/base64"
+ "slices"
"time"
"github.com/kujtimiihoxha/termai/internal/llm/models"
@@ -16,6 +17,20 @@ const (
Tool MessageRole = "tool"
)
+type FinishReason string
+
+const (
+ FinishReasonEndTurn FinishReason = "end_turn"
+ FinishReasonMaxTokens FinishReason = "max_tokens"
+ FinishReasonToolUse FinishReason = "tool_use"
+ FinishReasonCanceled FinishReason = "canceled"
+ FinishReasonError FinishReason = "error"
+ FinishReasonPermissionDenied FinishReason = "permission_denied"
+
+ // Should never happen
+ FinishReasonUnknown FinishReason = "unknown"
+)
+
type ContentPart interface {
isPart()
}
@@ -83,8 +98,8 @@ type ToolResult struct {
func (ToolResult) isPart() {}
type Finish struct {
- Reason string `json:"reason"`
- Time int64 `json:"time"`
+ Reason FinishReason `json:"reason"`
+ Time int64 `json:"time"`
}
func (Finish) isPart() {}
@@ -176,7 +191,7 @@ func (m *Message) FinishPart() *Finish {
return nil
}
-func (m *Message) FinishReason() string {
+func (m *Message) FinishReason() FinishReason {
for _, part := range m.Parts {
if c, ok := part.(Finish); ok {
return c.Reason
@@ -246,7 +261,14 @@ func (m *Message) SetToolResults(tr []ToolResult) {
}
}
-func (m *Message) AddFinish(reason string) {
+func (m *Message) AddFinish(reason FinishReason) {
+ // remove any existing finish part
+ for i, part := range m.Parts {
+ if _, ok := part.(Finish); ok {
+ m.Parts = slices.Delete(m.Parts, i, i+1)
+ break
+ }
+ }
m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()})
}
@@ -5,7 +5,7 @@ import (
"sync"
)
-const bufferSize = 1024 * 1024
+const bufferSize = 1024
type Logger interface {
Debug(msg string, args ...any)
@@ -24,6 +24,7 @@ type Session struct {
type Service interface {
pubsub.Suscriber[Session]
Create(ctx context.Context, title string) (Session, error)
+ CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
Get(ctx context.Context, id string) (Session, error)
List(ctx context.Context) ([]Session, error)
@@ -63,6 +64,20 @@ func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessi
return session, nil
}
+// CreateTitleSession creates a child session used to generate a title
+// for parentSessionID, publishing a CreatedEvent on success.
+// NOTE(review): the ID is derived as "title-<parentSessionID>", so a
+// second call for the same parent will hit the primary-key constraint —
+// confirm callers only invoke this once per parent.
+func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
+	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
+		ID:              "title-" + parentSessionID,
+		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
+		Title:           "Generate a title",
+	})
+	if err != nil {
+		return Session{}, err
+	}
+	session := s.fromDBItem(dbSession)
+	s.Publish(pubsub.CreatedEvent, session)
+	return session, nil
+}
+
func (s *service) Delete(ctx context.Context, id string) error {
session, err := s.Get(ctx, id)
if err != nil {
@@ -19,8 +19,6 @@ type SessionSelectedMsg = session.Session
type SessionClearedMsg struct{}
-type AgentWorkingMsg bool
-
type EditorFocusMsg bool
func lspsConfigured(width int) string {
@@ -5,14 +5,17 @@ import (
"github.com/charmbracelet/bubbles/textarea"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
+ "github.com/kujtimiihoxha/termai/internal/app"
+ "github.com/kujtimiihoxha/termai/internal/session"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
)
type editorCmp struct {
- textarea textarea.Model
- agentWorking bool
+ app *app.App
+ session session.Session
+ textarea textarea.Model
}
type focusedEditorKeyMaps struct {
@@ -32,7 +35,7 @@ var focusedKeyMaps = focusedEditorKeyMaps{
),
Blur: key.NewBinding(
key.WithKeys("esc"),
- key.WithHelp("esc", "blur editor"),
+ key.WithHelp("esc", "focus messages"),
),
}
@@ -52,7 +55,7 @@ func (m *editorCmp) Init() tea.Cmd {
}
func (m *editorCmp) send() tea.Cmd {
- if m.agentWorking {
+ if m.app.CoderAgent.IsSessionBusy(m.session.ID) {
return util.ReportWarn("Agent is working, please wait...")
}
@@ -66,7 +69,6 @@ func (m *editorCmp) send() tea.Cmd {
util.CmdHandler(SendMsg{
Text: value,
}),
- util.CmdHandler(AgentWorkingMsg(true)),
util.CmdHandler(EditorFocusMsg(false)),
)
}
@@ -74,8 +76,11 @@ func (m *editorCmp) send() tea.Cmd {
func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
- case AgentWorkingMsg:
- m.agentWorking = bool(msg)
+ case SessionSelectedMsg:
+ if msg.ID != m.session.ID {
+ m.session = msg
+ }
+ return m, nil
case tea.KeyMsg:
// if the key does not match any binding, return
if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) {
@@ -122,7 +127,7 @@ func (m *editorCmp) BindingKeys() []key.Binding {
return bindings
}
-func NewEditorCmp() tea.Model {
+func NewEditorCmp(app *app.App) tea.Model {
ti := textarea.New()
ti.Prompt = " "
ti.ShowLineNumbers = false
@@ -138,6 +143,7 @@ func NewEditorCmp() tea.Model {
ti.CharLimit = -1
ti.Focus()
return &editorCmp{
+ app: app,
textarea: ti,
}
}
@@ -6,7 +6,9 @@ import (
"fmt"
"math"
"strings"
+ "time"
+ "github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
@@ -17,9 +19,11 @@ import (
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/session"
+ "github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
)
@@ -32,6 +36,9 @@ const (
toolMessageType
)
+// messagesTickMsg is a message sent by the timer to refresh messages
+type messagesTickMsg time.Time
+
type uiMessage struct {
ID string
messageType uiMessageType
@@ -52,24 +59,34 @@ type messagesCmp struct {
renderer *glamour.TermRenderer
focusRenderer *glamour.TermRenderer
cachedContent map[string]string
- agentWorking bool
spinner spinner.Model
needsRerender bool
- lastViewport string
}
func (m *messagesCmp) Init() tea.Cmd {
- return tea.Batch(m.viewport.Init())
+ return tea.Batch(m.viewport.Init(), m.spinner.Tick, m.tickMessages())
+}
+
+// tickMessages schedules a messagesTickMsg one second from now. The
+// Update handler re-arms it on each tick, yielding a once-per-second
+// refresh of the message list.
+func (m *messagesCmp) tickMessages() tea.Cmd {
+	return tea.Tick(time.Second, func(t time.Time) tea.Msg {
+		return messagesTickMsg(t)
+	})
+}
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
- case AgentWorkingMsg:
- m.agentWorking = bool(msg)
- if m.agentWorking {
- cmds = append(cmds, m.spinner.Tick)
+ case messagesTickMsg:
+ // Refresh messages if we have an active session
+ if m.session.ID != "" {
+ messages, err := m.app.Messages.List(context.Background(), m.session.ID)
+ if err == nil {
+ m.messages = messages
+ m.needsRerender = true
+ }
}
+ // Continue ticking
+ cmds = append(cmds, m.tickMessages())
case EditorFocusMsg:
m.writingMode = bool(msg)
case SessionSelectedMsg:
@@ -84,6 +101,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.messages = make([]message.Message, 0)
m.currentMsgID = ""
m.needsRerender = true
+ m.cachedContent = make(map[string]string)
return m, nil
case tea.KeyMsg:
@@ -104,6 +122,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if !messageExists {
+ // If we have messages, ensure the previous last message is not cached
+ if len(m.messages) > 0 {
+ lastMsgID := m.messages[len(m.messages)-1].ID
+ delete(m.cachedContent, lastMsgID)
+ }
+
m.messages = append(m.messages, msg.Payload)
delete(m.cachedContent, m.currentMsgID)
m.currentMsgID = msg.Payload.ID
@@ -112,36 +136,40 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
for _, v := range m.messages {
for _, c := range v.ToolCalls() {
- // the message is being added to the session of a tool called
if c.ID == msg.Payload.SessionID {
m.needsRerender = true
}
}
}
} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
+ logging.Debug("Message", "finish", msg.Payload.FinishReason())
for i, v := range m.messages {
if v.ID == msg.Payload.ID {
- if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" {
- cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false)))
- }
m.messages[i] = msg.Payload
delete(m.cachedContent, msg.Payload.ID)
+
+ // If this is the last message, ensure it's not cached
+ if i == len(m.messages)-1 {
+ delete(m.cachedContent, msg.Payload.ID)
+ }
+
m.needsRerender = true
break
}
}
}
}
- if m.agentWorking {
- u, cmd := m.spinner.Update(msg)
- m.spinner = u
- cmds = append(cmds, cmd)
- }
+
oldPos := m.viewport.YPosition
u, cmd := m.viewport.Update(msg)
m.viewport = u
m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos
cmds = append(cmds, cmd)
+
+ spinner, cmd := m.spinner.Update(msg)
+ m.spinner = spinner
+ cmds = append(cmds, cmd)
+
if m.needsRerender {
m.renderView()
if len(m.messages) > 0 {
@@ -157,10 +185,21 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(cmds...)
}
+func (m *messagesCmp) IsAgentWorking() bool {
+ return m.app.CoderAgent.IsSessionBusy(m.session.ID)
+}
+
func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string {
- if v, ok := m.cachedContent[msg.ID]; ok {
- return v
+ // Check if this is the last message in the list
+ isLastMessage := len(m.messages) > 0 && m.messages[len(m.messages)-1].ID == msg.ID
+
+ // Only use cache for non-last messages
+ if !isLastMessage {
+ if v, ok := m.cachedContent[msg.ID]; ok {
+ return v
+ }
}
+
style := styles.BaseStyle.
Width(m.width).
BorderLeft(true).
@@ -191,7 +230,12 @@ func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) s
parts...,
),
)
- m.cachedContent[msg.ID] = rendered
+
+ // Only cache if it's not the last message
+ if !isLastMessage {
+ m.cachedContent[msg.ID] = rendered
+ }
+
return rendered
}
@@ -207,32 +251,71 @@ func formatTimeDifference(unixTime1, unixTime2 int64) string {
return fmt.Sprintf("%dm%ds", minutes, seconds)
}
+func (m *messagesCmp) findToolResponse(callID string) *message.ToolResult {
+ for _, v := range m.messages {
+ for _, c := range v.ToolResults() {
+ if c.ToolCallID == callID {
+ return &c
+ }
+ }
+ }
+ return nil
+}
+
func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string {
key := ""
value := ""
+ result := styles.BaseStyle.Foreground(styles.PrimaryColor).Render(m.spinner.View() + " waiting for response...")
+
+ response := m.findToolResponse(toolCall.ID)
+ if response != nil && response.IsError {
+ // Clean up error message for display by removing newlines
+ // This ensures error messages display properly in the UI
+ errMsg := strings.ReplaceAll(response.Content, "\n", " ")
+ result = styles.BaseStyle.Foreground(styles.Error).Render(ansi.Truncate(errMsg, 40, "..."))
+ } else if response != nil {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render("Done")
+ }
switch toolCall.Name {
// TODO: add result data to the tools
case agent.AgentToolName:
key = "Task"
var params agent.AgentParams
		json.Unmarshal([]byte(toolCall.Input), &params)
- value = params.Prompt
- // TODO: handle nested calls
+ value = strings.ReplaceAll(params.Prompt, "\n", " ")
+ if response != nil && !response.IsError {
+ firstRow := strings.ReplaceAll(response.Content, "\n", " ")
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(ansi.Truncate(firstRow, 40, "..."))
+ }
case tools.BashToolName:
key = "Bash"
var params tools.BashParams
		json.Unmarshal([]byte(toolCall.Input), &params)
value = params.Command
+ if response != nil && !response.IsError {
+ metadata := tools.BashResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("Took %s", formatTimeDifference(metadata.StartTime, metadata.EndTime)))
+ }
+
case tools.EditToolName:
key = "Edit"
var params tools.EditParams
		json.Unmarshal([]byte(toolCall.Input), &params)
value = params.FilePath
+ if response != nil && !response.IsError {
+ metadata := tools.EditResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals))
+ }
case tools.FetchToolName:
key = "Fetch"
var params tools.FetchParams
		json.Unmarshal([]byte(toolCall.Input), &params)
value = params.URL
+ if response != nil && !response.IsError {
+ result = styles.BaseStyle.Foreground(styles.Error).Render(response.Content)
+ }
case tools.GlobToolName:
key = "Glob"
var params tools.GlobParams
@@ -241,6 +324,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s
params.Path = "."
}
value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
+ if response != nil && !response.IsError {
+ metadata := tools.GlobResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+ if metadata.Truncated {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles))
+ } else {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles))
+ }
+ }
case tools.GrepToolName:
key = "Grep"
var params tools.GrepParams
@@ -249,19 +341,46 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s
params.Path = "."
}
value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
+ if response != nil && !response.IsError {
+ metadata := tools.GrepResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+ if metadata.Truncated {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfMatches))
+ } else {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfMatches))
+ }
+ }
case tools.LSToolName:
- key = "Ls"
+ key = "ls"
var params tools.LSParams
		json.Unmarshal([]byte(toolCall.Input), &params)
if params.Path == "" {
params.Path = "."
}
value = params.Path
+ if response != nil && !response.IsError {
+ metadata := tools.LSResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+ if metadata.Truncated {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles))
+ } else {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles))
+ }
+ }
case tools.SourcegraphToolName:
key = "Sourcegraph"
var params tools.SourcegraphParams
		json.Unmarshal([]byte(toolCall.Input), &params)
value = params.Query
+ if response != nil && !response.IsError {
+ metadata := tools.SourcegraphResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+ if metadata.Truncated {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found (truncated)", metadata.NumberOfMatches))
+ } else {
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found", metadata.NumberOfMatches))
+ }
+ }
case tools.ViewToolName:
key = "View"
var params tools.ViewParams
@@ -272,6 +391,12 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s
var params tools.WriteParams
		json.Unmarshal([]byte(toolCall.Input), &params)
value = params.FilePath
+ if response != nil && !response.IsError {
+ metadata := tools.WriteResponseMetadata{}
+ json.Unmarshal([]byte(response.Metadata), &metadata)
+
+ result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals))
+ }
default:
key = toolCall.Name
var params map[string]any
@@ -300,14 +425,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s
)
if !isNested {
value = valyeStyle.
- Width(m.width - lipgloss.Width(keyValye) - 2).
Render(
ansi.Truncate(
- value,
- m.width-lipgloss.Width(keyValye)-2,
+ value+" ",
+ m.width-lipgloss.Width(keyValye)-2-lipgloss.Width(result),
"...",
),
)
+ value += result
+
} else {
keyValye = keyStyle.Render(
fmt.Sprintf(" └ %s: ", key),
@@ -409,6 +535,27 @@ func (m *messagesCmp) renderView() {
m.uiMessages = make([]uiMessage, 0)
pos := 0
+ // If we have messages, ensure the last message is not cached
+ // This ensures we always render the latest content for the most recent message
+ // which may be actively updating (e.g., during generation)
+ if len(m.messages) > 0 {
+ lastMsgID := m.messages[len(m.messages)-1].ID
+ delete(m.cachedContent, lastMsgID)
+ }
+
+	// Limit cache to 15 messages
+ if len(m.cachedContent) > 15 {
+	// Collect cache keys to evict (map iteration order is unspecified, so eviction is arbitrary)
+ keys := make([]string, 0, len(m.cachedContent))
+ for k := range m.cachedContent {
+ keys = append(keys, k)
+ }
+	// Delete entries until we have 15 or fewer
+ for i := 0; i < len(keys)-15; i++ {
+ delete(m.cachedContent, keys[i])
+ }
+ }
+
for _, v := range m.messages {
switch v.Role {
case message.User:
@@ -487,7 +634,7 @@ func (m *messagesCmp) View() string {
func (m *messagesCmp) help() string {
text := ""
- if m.agentWorking {
+ if m.IsAgentWorking() {
text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render(
fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."),
)
@@ -562,9 +709,15 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
m.messages = messages
m.currentMsgID = m.messages[len(m.messages)-1].ID
m.needsRerender = true
+ m.cachedContent = make(map[string]string)
return nil
}
+func (m *messagesCmp) BindingKeys() []key.Binding {
+ bindings := layout.KeyMapToSlice(m.viewport.KeyMap)
+ return bindings
+}
+
func NewMessagesCmp(app *app.App) tea.Model {
focusRenderer, _ := glamour.NewTermRenderer(
glamour.WithStyles(styles.MarkdownTheme(true)),
@@ -1,10 +1,15 @@
package chat
import (
+ "context"
"fmt"
+ "strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
+ "github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/diff"
+ "github.com/kujtimiihoxha/termai/internal/history"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/session"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
@@ -13,9 +18,33 @@ import (
type sidebarCmp struct {
width, height int
session session.Session
+ history history.Service
+ modFiles map[string]struct {
+ additions int
+ removals int
+ }
}
func (m *sidebarCmp) Init() tea.Cmd {
+ if m.history != nil {
+ ctx := context.Background()
+ // Subscribe to file events
+ filesCh := m.history.Subscribe(ctx)
+
+ // Initialize the modified files map
+ m.modFiles = make(map[string]struct {
+ additions int
+ removals int
+ })
+
+ // Load initial files and calculate diffs
+ m.loadModifiedFiles(ctx)
+
+ // Return a command that will send file events to the Update method
+ return func() tea.Msg {
+ return <-filesCh
+ }
+ }
return nil
}
@@ -27,6 +56,13 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.session = msg.Payload
}
}
+ case pubsub.Event[history.File]:
+ if msg.Payload.SessionID == m.session.ID {
+ // When a file changes, reload all modified files
+ // This ensures we have the complete and accurate list
+ ctx := context.Background()
+ m.loadModifiedFiles(ctx)
+ }
}
return m, nil
}
@@ -86,18 +122,28 @@ func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) stri
func (m *sidebarCmp) modifiedFiles() string {
modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render("Modified Files:")
- files := []struct {
- path string
- additions int
- removals int
- }{
- {"file1.txt", 10, 5},
- {"file2.txt", 20, 0},
- {"file3.txt", 0, 15},
+
+ // If no modified files, show a placeholder message
+ if m.modFiles == nil || len(m.modFiles) == 0 {
+ message := "No modified files"
+ remainingWidth := m.width - lipgloss.Width(modifiedFiles)
+ if remainingWidth > 0 {
+ message += strings.Repeat(" ", remainingWidth)
+ }
+ return styles.BaseStyle.
+ Width(m.width).
+ Render(
+ lipgloss.JoinVertical(
+ lipgloss.Top,
+ modifiedFiles,
+ styles.BaseStyle.Foreground(styles.ForgroundDim).Render(message),
+ ),
+ )
}
+
var fileViews []string
- for _, file := range files {
- fileViews = append(fileViews, m.modifiedFile(file.path, file.additions, file.removals))
+ for path, stats := range m.modFiles {
+ fileViews = append(fileViews, m.modifiedFile(path, stats.additions, stats.removals))
}
return styles.BaseStyle.
@@ -123,8 +169,116 @@ func (m *sidebarCmp) GetSize() (int, int) {
return m.width, m.height
}
-func NewSidebarCmp(session session.Session) tea.Model {
+func NewSidebarCmp(session session.Session, history history.Service) tea.Model {
return &sidebarCmp{
session: session,
+ history: history,
+ }
+}
+
+func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) {
+ if m.history == nil || m.session.ID == "" {
+ return
+ }
+
+ // Get all latest files for this session
+ latestFiles, err := m.history.ListLatestSessionFiles(ctx, m.session.ID)
+ if err != nil {
+ return
+ }
+
+ // Get all files for this session (to find initial versions)
+ allFiles, err := m.history.ListBySession(ctx, m.session.ID)
+ if err != nil {
+ return
+ }
+
+ // Process each latest file
+ for _, file := range latestFiles {
+ // Skip if this is the initial version (no changes to show)
+ if file.Version == history.InitialVersion {
+ continue
+ }
+
+ // Find the initial version for this specific file
+ var initialVersion history.File
+ for _, v := range allFiles {
+ if v.Path == file.Path && v.Version == history.InitialVersion {
+ initialVersion = v
+ break
+ }
+ }
+
+ // Skip if we can't find the initial version
+ if initialVersion.ID == "" {
+ continue
+ }
+
+ // Calculate diff between initial and latest version
+ _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path)
+
+ // Only add to modified files if there are changes
+ if additions > 0 || removals > 0 {
+ // Remove working directory prefix from file path
+ displayPath := file.Path
+ workingDir := config.WorkingDirectory()
+ displayPath = strings.TrimPrefix(displayPath, workingDir)
+ displayPath = strings.TrimPrefix(displayPath, "/")
+
+ m.modFiles[displayPath] = struct {
+ additions int
+ removals int
+ }{
+ additions: additions,
+ removals: removals,
+ }
+ }
+ }
+}
+
+func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) {
+ // Skip if not the latest version
+ if file.Version == history.InitialVersion {
+ return
+ }
+
+ // Get all versions of this file
+ fileVersions, err := m.history.ListBySession(ctx, m.session.ID)
+ if err != nil {
+ return
+ }
+
+ // Find the initial version
+ var initialVersion history.File
+ for _, v := range fileVersions {
+ if v.Path == file.Path && v.Version == history.InitialVersion {
+ initialVersion = v
+ break
+ }
+ }
+
+ // Skip if we can't find the initial version
+ if initialVersion.ID == "" {
+ return
+ }
+
+ // Calculate diff between initial and latest version
+ _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path)
+
+ // Only add to modified files if there are changes
+ if additions > 0 || removals > 0 {
+ // Remove working directory prefix from file path
+ displayPath := file.Path
+ workingDir := config.WorkingDirectory()
+ displayPath = strings.TrimPrefix(displayPath, workingDir)
+ displayPath = strings.TrimPrefix(displayPath, "/")
+
+ m.modFiles[displayPath] = struct {
+ additions int
+ removals int
+ }{
+ additions: additions,
+ removals: removals,
+ }
}
}
@@ -1,117 +0,0 @@
-package core
-
-import (
- "github.com/charmbracelet/bubbles/key"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/lipgloss"
- "github.com/kujtimiihoxha/termai/internal/tui/layout"
- "github.com/kujtimiihoxha/termai/internal/tui/util"
-)
-
-type SizeableModel interface {
- tea.Model
- layout.Sizeable
-}
-
-type DialogMsg struct {
- Content SizeableModel
- WidthRatio float64
- HeightRatio float64
-
- MinWidth int
- MinHeight int
-}
-
-type DialogCloseMsg struct{}
-
-type KeyBindings struct {
- Return key.Binding
-}
-
-var keys = KeyBindings{
- Return: key.NewBinding(
- key.WithKeys("esc"),
- key.WithHelp("esc", "close"),
- ),
-}
-
-type DialogCmp interface {
- tea.Model
- layout.Bindings
-}
-
-type dialogCmp struct {
- content SizeableModel
- screenWidth int
- screenHeight int
-
- widthRatio float64
- heightRatio float64
-
- minWidth int
- minHeight int
-
- width int
- height int
-}
-
-func (d *dialogCmp) Init() tea.Cmd {
- return nil
-}
-
-func (d *dialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case tea.WindowSizeMsg:
- d.screenWidth = msg.Width
- d.screenHeight = msg.Height
- d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth)
- d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight)
- if d.content != nil {
- d.content.SetSize(d.width, d.height)
- }
- return d, nil
- case DialogMsg:
- d.content = msg.Content
- d.widthRatio = msg.WidthRatio
- d.heightRatio = msg.HeightRatio
- d.minWidth = msg.MinWidth
- d.minHeight = msg.MinHeight
- d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth)
- d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight)
- if d.content != nil {
- d.content.SetSize(d.width, d.height)
- }
- case DialogCloseMsg:
- d.content = nil
- return d, nil
- case tea.KeyMsg:
- if key.Matches(msg, keys.Return) {
- return d, util.CmdHandler(DialogCloseMsg{})
- }
- }
- if d.content != nil {
- u, cmd := d.content.Update(msg)
- d.content = u.(SizeableModel)
- return d, cmd
- }
- return d, nil
-}
-
-func (d *dialogCmp) BindingKeys() []key.Binding {
- bindings := []key.Binding{keys.Return}
- if d.content == nil {
- return bindings
- }
- if c, ok := d.content.(layout.Bindings); ok {
- return append(bindings, c.BindingKeys()...)
- }
- return bindings
-}
-
-func (d *dialogCmp) View() string {
- return lipgloss.NewStyle().Width(d.width).Height(d.height).Render(d.content.View())
-}
-
-func NewDialogCmp() DialogCmp {
- return &dialogCmp{}
-}
@@ -1,119 +0,0 @@
-package core
-
-import (
- "strings"
-
- "github.com/charmbracelet/bubbles/key"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/lipgloss"
- "github.com/kujtimiihoxha/termai/internal/tui/styles"
-)
-
-type HelpCmp interface {
- tea.Model
- SetBindings(bindings []key.Binding)
- Height() int
-}
-
-const (
- helpWidgetHeight = 12
-)
-
-type helpCmp struct {
- width int
- bindings []key.Binding
-}
-
-func (h *helpCmp) Init() tea.Cmd {
- return nil
-}
-
-func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case tea.WindowSizeMsg:
- h.width = msg.Width
- }
- return h, nil
-}
-
-func (h *helpCmp) View() string {
- helpKeyStyle := styles.Bold.Foreground(styles.Rosewater).Margin(0, 1, 0, 0)
- helpDescStyle := styles.Regular.Foreground(styles.Flamingo)
- // Compile list of bindings to render
- bindings := removeDuplicateBindings(h.bindings)
- // Enumerate through each group of bindings, populating a series of
- // pairs of columns, one for keys, one for descriptions
- var (
- pairs []string
- width int
- rows = helpWidgetHeight - 2
- )
- for i := 0; i < len(bindings); i += rows {
- var (
- keys []string
- descs []string
- )
- for j := i; j < min(i+rows, len(bindings)); j++ {
- keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key))
- descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc))
- }
- // Render pair of columns; beyond the first pair, render a three space
- // left margin, in order to visually separate the pairs.
- var cols []string
- if len(pairs) > 0 {
- cols = []string{" "}
- }
- cols = append(cols,
- strings.Join(keys, "\n"),
- strings.Join(descs, "\n"),
- )
-
- pair := lipgloss.JoinHorizontal(lipgloss.Top, cols...)
- // check whether it exceeds the maximum width avail (the width of the
- // terminal, subtracting 2 for the borders).
- width += lipgloss.Width(pair)
- if width > h.width-2 {
- break
- }
- pairs = append(pairs, pair)
- }
-
- // Join pairs of columns and enclose in a border
- content := lipgloss.JoinHorizontal(lipgloss.Top, pairs...)
- return styles.DoubleBorder.Height(rows).PaddingLeft(1).Width(h.width - 2).Render(content)
-}
-
-func removeDuplicateBindings(bindings []key.Binding) []key.Binding {
- seen := make(map[string]struct{})
- result := make([]key.Binding, 0, len(bindings))
-
- // Process bindings in reverse order
- for i := len(bindings) - 1; i >= 0; i-- {
- b := bindings[i]
- k := strings.Join(b.Keys(), " ")
- if _, ok := seen[k]; ok {
- // duplicate, skip
- continue
- }
- seen[k] = struct{}{}
- // Add to the beginning of result to maintain original order
- result = append([]key.Binding{b}, result...)
- }
-
- return result
-}
-
-func (h *helpCmp) SetBindings(bindings []key.Binding) {
- h.bindings = bindings
-}
-
-func (h helpCmp) Height() int {
- return helpWidgetHeight
-}
-
-func NewHelpCmp() HelpCmp {
- return &helpCmp{
- width: 0,
- bindings: make([]key.Binding, 0),
- }
-}
@@ -1,21 +1,25 @@
package core
import (
+ "fmt"
+ "strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
- "github.com/kujtimiihoxha/termai/internal/version"
)
type statusCmp struct {
info util.InfoMsg
width int
messageTTL time.Duration
+ lspClients map[string]*lsp.Client
}
// clearMessageCmd is a command that clears status messages after a timeout
@@ -47,20 +51,18 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
-var (
- versionWidget = styles.Padded.Background(styles.DarkGrey).Foreground(styles.Text).Render(version.Version)
- helpWidget = styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help")
-)
+var helpWidget = styles.Padded.Background(styles.ForgroundMid).Foreground(styles.BackgroundDarker).Bold(true).Render("ctrl+? help")
func (m statusCmp) View() string {
- status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help")
+ status := helpWidget
+ diagnostics := styles.Padded.Background(styles.BackgroundDarker).Render(m.projectDiagnostics())
if m.info.Msg != "" {
infoStyle := styles.Padded.
Foreground(styles.Base).
- Width(m.availableFooterMsgWidth())
+ Width(m.availableFooterMsgWidth(diagnostics))
switch m.info.Type {
case util.InfoTypeInfo:
- infoStyle = infoStyle.Background(styles.Blue)
+ infoStyle = infoStyle.Background(styles.BorderColor)
case util.InfoTypeWarn:
infoStyle = infoStyle.Background(styles.Peach)
case util.InfoTypeError:
@@ -68,7 +70,7 @@ func (m statusCmp) View() string {
}
// Truncate message if it's longer than available width
msg := m.info.Msg
- availWidth := m.availableFooterMsgWidth() - 10
+ availWidth := m.availableFooterMsgWidth(diagnostics) - 10
if len(msg) > availWidth && availWidth > 0 {
msg = msg[:availWidth] + "..."
}
@@ -76,27 +78,81 @@ func (m statusCmp) View() string {
} else {
status += styles.Padded.
Foreground(styles.Base).
- Background(styles.LightGrey).
- Width(m.availableFooterMsgWidth()).
+ Background(styles.BackgroundDim).
+ Width(m.availableFooterMsgWidth(diagnostics)).
Render("")
}
+ status += diagnostics
status += m.model()
- status += versionWidget
return status
}
-func (m statusCmp) availableFooterMsgWidth() int {
- // -2 to accommodate padding
- return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget)-lipgloss.Width(m.model()))
+func (m *statusCmp) projectDiagnostics() string {
+ errorDiagnostics := []protocol.Diagnostic{}
+ warnDiagnostics := []protocol.Diagnostic{}
+ hintDiagnostics := []protocol.Diagnostic{}
+ infoDiagnostics := []protocol.Diagnostic{}
+ for _, client := range m.lspClients {
+ for _, d := range client.GetDiagnostics() {
+ for _, diag := range d {
+ switch diag.Severity {
+ case protocol.SeverityError:
+ errorDiagnostics = append(errorDiagnostics, diag)
+ case protocol.SeverityWarning:
+ warnDiagnostics = append(warnDiagnostics, diag)
+ case protocol.SeverityHint:
+ hintDiagnostics = append(hintDiagnostics, diag)
+ case protocol.SeverityInformation:
+ infoDiagnostics = append(infoDiagnostics, diag)
+ }
+ }
+ }
+ }
+
+ if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 {
+ return "No diagnostics"
+ }
+
+ diagnostics := []string{}
+
+ if len(errorDiagnostics) > 0 {
+ errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
+ diagnostics = append(diagnostics, errStr)
+ }
+ if len(warnDiagnostics) > 0 {
+ warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
+ diagnostics = append(diagnostics, warnStr)
+ }
+ if len(hintDiagnostics) > 0 {
+ hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
+ diagnostics = append(diagnostics, hintStr)
+ }
+ if len(infoDiagnostics) > 0 {
+ infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
+ diagnostics = append(diagnostics, infoStr)
+ }
+
+ return strings.Join(diagnostics, " ")
+}
+
+func (m statusCmp) availableFooterMsgWidth(diagnostics string) int {
+ return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics))
}
func (m statusCmp) model() string {
- model := models.SupportedModels[config.Get().Model.Coder]
+ cfg := config.Get()
+
+ coder, ok := cfg.Agents[config.AgentCoder]
+ if !ok {
+ return "Unknown"
+ }
+ model := models.SupportedModels[coder.Model]
return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name)
}
-func NewStatusCmp() tea.Model {
+func NewStatusCmp(lspClients map[string]*lsp.Client) tea.Model {
return &statusCmp{
messageTTL: 10 * time.Second,
+ lspClients: lspClients,
}
}
@@ -0,0 +1,182 @@
+package dialog
+
+import (
+ "strings"
+
+ "github.com/charmbracelet/bubbles/key"
+ tea "github.com/charmbracelet/bubbletea"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/kujtimiihoxha/termai/internal/tui/styles"
+)
+
+type helpCmp struct {
+ width int
+ height int
+ keys []key.Binding
+}
+
+func (h *helpCmp) Init() tea.Cmd {
+ return nil
+}
+
+func (h *helpCmp) SetBindings(k []key.Binding) {
+ h.keys = k
+}
+
+func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ h.width = 80
+ h.height = msg.Height
+ }
+ return h, nil
+}
+
+func removeDuplicateBindings(bindings []key.Binding) []key.Binding {
+ seen := make(map[string]struct{})
+ result := make([]key.Binding, 0, len(bindings))
+
+ // Process bindings in reverse order
+ for i := len(bindings) - 1; i >= 0; i-- {
+ b := bindings[i]
+ k := strings.Join(b.Keys(), " ")
+ if _, ok := seen[k]; ok {
+ // duplicate, skip
+ continue
+ }
+ seen[k] = struct{}{}
+ // Add to the beginning of result to maintain original order
+ result = append([]key.Binding{b}, result...)
+ }
+
+ return result
+}
+
+func (h *helpCmp) render() string {
+ helpKeyStyle := styles.Bold.Background(styles.Background).Foreground(styles.Forground).Padding(0, 1, 0, 0)
+ helpDescStyle := styles.Regular.Background(styles.Background).Foreground(styles.ForgroundMid)
+ // Compile list of bindings to render
+ bindings := removeDuplicateBindings(h.keys)
+ // Enumerate through each group of bindings, populating a series of
+ // pairs of columns, one for keys, one for descriptions
+ var (
+ pairs []string
+ width int
+ rows = 12 - 2
+ )
+ for i := 0; i < len(bindings); i += rows {
+ var (
+ keys []string
+ descs []string
+ )
+ for j := i; j < min(i+rows, len(bindings)); j++ {
+ keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key))
+ descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc))
+ }
+ // Render pair of columns; beyond the first pair, render a three space
+ // left margin, in order to visually separate the pairs.
+ var cols []string
+ if len(pairs) > 0 {
+ cols = []string{styles.BaseStyle.Render(" ")}
+ }
+
+ maxDescWidth := 0
+ for _, desc := range descs {
+ if maxDescWidth < lipgloss.Width(desc) {
+ maxDescWidth = lipgloss.Width(desc)
+ }
+ }
+ for i := range descs {
+ remainingWidth := maxDescWidth - lipgloss.Width(descs[i])
+ if remainingWidth > 0 {
+ descs[i] = descs[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth))
+ }
+ }
+ maxKeyWidth := 0
+ for _, key := range keys {
+ if maxKeyWidth < lipgloss.Width(key) {
+ maxKeyWidth = lipgloss.Width(key)
+ }
+ }
+ for i := range keys {
+ remainingWidth := maxKeyWidth - lipgloss.Width(keys[i])
+ if remainingWidth > 0 {
+ keys[i] = keys[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth))
+ }
+ }
+
+ cols = append(cols,
+ strings.Join(keys, "\n"),
+ strings.Join(descs, "\n"),
+ )
+
+ pair := styles.BaseStyle.Render(lipgloss.JoinHorizontal(lipgloss.Top, cols...))
+ // check whether it exceeds the maximum width avail (the width of the
+ // terminal, subtracting 2 for the borders).
+ width += lipgloss.Width(pair)
+ if width > h.width-2 {
+ break
+ }
+ pairs = append(pairs, pair)
+ }
+
+ // https://github.com/charmbracelet/lipgloss/issues/209
+ if len(pairs) > 1 {
+ prefix := pairs[:len(pairs)-1]
+ lastPair := pairs[len(pairs)-1]
+ prefix = append(prefix, lipgloss.Place(
+ lipgloss.Width(lastPair), // width
+ lipgloss.Height(prefix[0]), // height
+ lipgloss.Left, // x
+ lipgloss.Top, // y
+ lastPair, // content
+ lipgloss.WithWhitespaceBackground(styles.Background), // background
+ ))
+ content := styles.BaseStyle.Width(h.width).Render(
+ lipgloss.JoinHorizontal(
+ lipgloss.Top,
+ prefix...,
+ ),
+ )
+ return content
+ }
+ // Join pairs of columns and enclose in a border
+ content := styles.BaseStyle.Width(h.width).Render(
+ lipgloss.JoinHorizontal(
+ lipgloss.Top,
+ pairs...,
+ ),
+ )
+ return content
+}
+
+func (h *helpCmp) View() string {
+ content := h.render()
+ header := styles.BaseStyle.
+ Bold(true).
+ Width(lipgloss.Width(content)).
+ Foreground(styles.PrimaryColor).
+ Render("Keyboard Shortcuts")
+
+ return styles.BaseStyle.Padding(1).
+ Border(lipgloss.RoundedBorder()).
+ BorderForeground(styles.ForgroundDim).
+ Width(h.width).
+ BorderBackground(styles.Background).
+ Render(
+ lipgloss.JoinVertical(lipgloss.Center,
+ header,
+ styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(header))),
+ content,
+ ),
+ )
+}
+
+type HelpCmp interface {
+ tea.Model
+ SetBindings([]key.Binding)
+}
+
+func NewHelpCmp() HelpCmp {
+ return &helpCmp{}
+}
@@ -12,12 +12,9 @@ import (
"github.com/kujtimiihoxha/termai/internal/diff"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/kujtimiihoxha/termai/internal/tui/components/core"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
-
- "github.com/charmbracelet/huh"
)
type PermissionAction string
@@ -35,69 +32,64 @@ type PermissionResponseMsg struct {
Action PermissionAction
}
-// PermissionDialog interface for permission dialog component
-type PermissionDialog interface {
+// PermissionDialogCmp interface for permission dialog component
+type PermissionDialogCmp interface {
tea.Model
- layout.Sizeable
layout.Bindings
+ SetPermissions(permission permission.PermissionRequest)
}
-type keyMap struct {
- ChangeFocus key.Binding
+type permissionsMapping struct {
+ LeftRight key.Binding
+ EnterSpace key.Binding
+ Allow key.Binding
+ AllowSession key.Binding
+ Deny key.Binding
+ Tab key.Binding
}
-var keyMapValue = keyMap{
- ChangeFocus: key.NewBinding(
+var permissionsKeys = permissionsMapping{
+ LeftRight: key.NewBinding(
+ key.WithKeys("left", "right"),
+ key.WithHelp("←/→", "switch options"),
+ ),
+ EnterSpace: key.NewBinding(
+ key.WithKeys("enter", " "),
+ key.WithHelp("enter/space", "confirm"),
+ ),
+ Allow: key.NewBinding(
+ key.WithKeys("a"),
+ key.WithHelp("a", "allow"),
+ ),
+ AllowSession: key.NewBinding(
+ key.WithKeys("A"),
+ key.WithHelp("A", "allow for session"),
+ ),
+ Deny: key.NewBinding(
+ key.WithKeys("d"),
+ key.WithHelp("d", "deny"),
+ ),
+ Tab: key.NewBinding(
key.WithKeys("tab"),
- key.WithHelp("tab", "change focus"),
+ key.WithHelp("tab", "switch options"),
),
}
// permissionDialogCmp is the implementation of PermissionDialog
type permissionDialogCmp struct {
- form *huh.Form
width int
height int
permission permission.PermissionRequest
windowSize tea.WindowSizeMsg
- r *glamour.TermRenderer
contentViewPort viewport.Model
- isViewportFocus bool
- selectOption *huh.Select[string]
-}
+ selectedOption int // 0: Allow, 1: Allow for session, 2: Deny
-// formatDiff formats a diff string with colors for additions and deletions
-func formatDiff(diffText string) string {
- lines := strings.Split(diffText, "\n")
- var formattedLines []string
-
- // Define styles for different line types
- addStyle := lipgloss.NewStyle().Foreground(styles.Green)
- removeStyle := lipgloss.NewStyle().Foreground(styles.Red)
- headerStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Blue)
- contextStyle := lipgloss.NewStyle().Foreground(styles.SubText0)
-
- // Process each line
- for _, line := range lines {
- if strings.HasPrefix(line, "+") {
- formattedLines = append(formattedLines, addStyle.Render(line))
- } else if strings.HasPrefix(line, "-") {
- formattedLines = append(formattedLines, removeStyle.Render(line))
- } else if strings.HasPrefix(line, "Changes:") || strings.HasPrefix(line, " ...") {
- formattedLines = append(formattedLines, headerStyle.Render(line))
- } else if strings.HasPrefix(line, " ") {
- formattedLines = append(formattedLines, contextStyle.Render(line))
- } else {
- formattedLines = append(formattedLines, line)
- }
- }
-
- // Join all formatted lines
- return strings.Join(formattedLines, "\n")
+ diffCache map[string]string
+ markdownCache map[string]string
}
func (p *permissionDialogCmp) Init() tea.Cmd {
- return nil
+ return p.contentViewPort.Init()
}
func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -106,373 +98,363 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
p.windowSize = msg
+ p.SetSize()
+ p.markdownCache = make(map[string]string)
+ p.diffCache = make(map[string]string)
case tea.KeyMsg:
- if key.Matches(msg, keyMapValue.ChangeFocus) {
- p.isViewportFocus = !p.isViewportFocus
- if p.isViewportFocus {
- p.selectOption.Blur()
- // Add a visual indicator for focus change
- cmds = append(cmds, tea.Batch(
- util.ReportInfo("Viewing content - use arrow keys to scroll"),
- ))
- } else {
- p.selectOption.Focus()
- // Add a visual indicator for focus change
- cmds = append(cmds, tea.Batch(
- util.CmdHandler(util.ReportInfo("Select an action")),
- ))
- }
- return p, tea.Batch(cmds...)
- }
- }
-
- if p.isViewportFocus {
- viewPort, cmd := p.contentViewPort.Update(msg)
- p.contentViewPort = viewPort
- cmds = append(cmds, cmd)
- } else {
- form, cmd := p.form.Update(msg)
- if f, ok := form.(*huh.Form); ok {
- p.form = f
+ switch {
+ case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab):
+ // Change selected option
+ p.selectedOption = (p.selectedOption + 1) % 3
+ return p, nil
+ case key.Matches(msg, permissionsKeys.EnterSpace):
+ // Select current option
+ return p, p.selectCurrentOption()
+ case key.Matches(msg, permissionsKeys.Allow):
+ // Select Allow
+ return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission})
+ case key.Matches(msg, permissionsKeys.AllowSession):
+ // Select Allow for session
+ return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission})
+ case key.Matches(msg, permissionsKeys.Deny):
+ // Select Deny
+ return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission})
+ default:
+ // Pass other keys to viewport
+ viewPort, cmd := p.contentViewPort.Update(msg)
+ p.contentViewPort = viewPort
cmds = append(cmds, cmd)
}
-
- if p.form.State == huh.StateCompleted {
- // Get the selected action
- action := p.form.GetString("action")
-
- // Close the dialog and return the response
- return p, tea.Batch(
- util.CmdHandler(core.DialogCloseMsg{}),
- util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}),
- )
- }
}
+
return p, tea.Batch(cmds...)
}
-func (p *permissionDialogCmp) render() string {
- keyStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Rosewater)
- valueStyle := lipgloss.NewStyle().Foreground(styles.Peach)
+func (p *permissionDialogCmp) selectCurrentOption() tea.Cmd {
+ var action PermissionAction
- form := p.form.View()
-
- headerParts := []string{
- lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Tool:"), " ", valueStyle.Render(p.permission.ToolName)),
- " ",
- lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Path:"), " ", valueStyle.Render(p.permission.Path)),
- " ",
+ switch p.selectedOption {
+ case 0:
+ action = PermissionAllow
+ case 1:
+ action = PermissionAllowForSession
+ case 2:
+ action = PermissionDeny
}
- // Create the header content first so it can be used in all cases
- headerContent := lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
-
- r, _ := glamour.NewTermRenderer(
- glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
- glamour.WithWordWrap(p.width-10),
- glamour.WithEmoji(),
- )
-
- // Handle different tool types
- switch p.permission.ToolName {
- case tools.BashToolName:
- pr := p.permission.Params.(tools.BashPermissionsParams)
- headerParts = append(headerParts, keyStyle.Render("Command:"))
- content := fmt.Sprintf("```bash\n%s\n```", pr.Command)
-
- renderedContent, _ := r.Render(content)
- p.contentViewPort.Width = p.width - 2 - 2
-
- // Calculate content height dynamically based on content
- contentLines := len(strings.Split(renderedContent, "\n"))
- // Set a reasonable min/max for the viewport height
- minContentHeight := 3
- maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
-
- // Add some padding to the content lines
- contentHeight := contentLines + 2
- contentHeight = max(contentHeight, minContentHeight)
- contentHeight = min(contentHeight, maxContentHeight)
- p.contentViewPort.Height = contentHeight
-
- p.contentViewPort.SetContent(renderedContent)
+ return util.CmdHandler(PermissionResponseMsg{Action: action, Permission: p.permission})
+}
- // Style the viewport
- var contentBorder lipgloss.Border
- var borderColor lipgloss.TerminalColor
+func (p *permissionDialogCmp) renderButtons() string {
+ allowStyle := styles.BaseStyle
+ allowSessionStyle := styles.BaseStyle
+ denyStyle := styles.BaseStyle
+ spacerStyle := styles.BaseStyle.Background(styles.Background)
+
+ // Style the selected button
+ switch p.selectedOption {
+ case 0:
+ allowStyle = allowStyle.Background(styles.PrimaryColor).Foreground(styles.Background)
+ allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ case 1:
+ allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ allowSessionStyle = allowSessionStyle.Background(styles.PrimaryColor).Foreground(styles.Background)
+ denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ case 2:
+ allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ denyStyle = denyStyle.Background(styles.PrimaryColor).Foreground(styles.Background)
+ }
- if p.isViewportFocus {
- contentBorder = lipgloss.DoubleBorder()
- borderColor = styles.Blue
- } else {
- contentBorder = lipgloss.RoundedBorder()
- borderColor = styles.Flamingo
- }
+ allowButton := allowStyle.Padding(0, 1).Render("Allow (a)")
+ allowSessionButton := allowSessionStyle.Padding(0, 1).Render("Allow for session (A)")
+ denyButton := denyStyle.Padding(0, 1).Render("Deny (d)")
+
+ content := lipgloss.JoinHorizontal(
+ lipgloss.Left,
+ allowButton,
+ spacerStyle.Render(" "),
+ allowSessionButton,
+ spacerStyle.Render(" "),
+ denyButton,
+ spacerStyle.Render(" "),
+ )
- contentStyle := lipgloss.NewStyle().
- MarginTop(1).
- Padding(0, 1).
- Border(contentBorder).
- BorderForeground(borderColor)
+ remainingWidth := p.width - lipgloss.Width(content)
+ if remainingWidth > 0 {
+ content = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + content
+ }
+ return content
+}
- if p.isViewportFocus {
- contentStyle = contentStyle.BorderBackground(styles.Surface0)
- }
+func (p *permissionDialogCmp) renderHeader() string {
+ toolKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Tool")
+ toolValue := styles.BaseStyle.
+ Foreground(styles.Forground).
+ Width(p.width - lipgloss.Width(toolKey)).
+ Render(fmt.Sprintf(": %s", p.permission.ToolName))
- contentFinal := contentStyle.Render(p.contentViewPort.View())
+ pathKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Path")
+ pathValue := styles.BaseStyle.
+ Foreground(styles.Forground).
+ Width(p.width - lipgloss.Width(pathKey)).
+ Render(fmt.Sprintf(": %s", p.permission.Path))
- return lipgloss.JoinVertical(
- lipgloss.Top,
- headerContent,
- contentFinal,
- form,
- )
+ headerParts := []string{
+ lipgloss.JoinHorizontal(
+ lipgloss.Left,
+ toolKey,
+ toolValue,
+ ),
+ styles.BaseStyle.Render(strings.Repeat(" ", p.width)),
+ lipgloss.JoinHorizontal(
+ lipgloss.Left,
+ pathKey,
+ pathValue,
+ ),
+ styles.BaseStyle.Render(strings.Repeat(" ", p.width)),
+ }
+ // Add tool-specific header information
+ switch p.permission.ToolName {
+ case tools.BashToolName:
+ headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Command"))
case tools.EditToolName:
- pr := p.permission.Params.(tools.EditPermissionsParams)
- headerParts = append(headerParts, keyStyle.Render("Update"))
- // Recreate header content with the updated headerParts
- headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
-
- // Format the diff with colors
-
- // Set up viewport for the diff content
- p.contentViewPort.Width = p.width - 2 - 2
-
- // Calculate content height dynamically based on window size
- maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
- p.contentViewPort.Height = maxContentHeight
- diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
- if err != nil {
- diff = fmt.Sprintf("Error formatting diff: %v", err)
- }
- p.contentViewPort.SetContent(diff)
+ headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff"))
+ case tools.WriteToolName:
+ headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff"))
+ case tools.FetchToolName:
+ headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("URL"))
+ }
- // Style the viewport
- var contentBorder lipgloss.Border
- var borderColor lipgloss.TerminalColor
+ return lipgloss.NewStyle().Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
+}
- if p.isViewportFocus {
- contentBorder = lipgloss.DoubleBorder()
- borderColor = styles.Blue
- } else {
- contentBorder = lipgloss.RoundedBorder()
- borderColor = styles.Flamingo
- }
+func (p *permissionDialogCmp) renderBashContent() string {
+ if pr, ok := p.permission.Params.(tools.BashPermissionsParams); ok {
+ content := fmt.Sprintf("```bash\n%s\n```", pr.Command)
- contentStyle := lipgloss.NewStyle().
- MarginTop(1).
- Padding(0, 1).
- Border(contentBorder).
- BorderForeground(borderColor)
+ // Use the cache for markdown rendering
+ renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) {
+ r, _ := glamour.NewTermRenderer(
+ glamour.WithStyles(styles.MarkdownTheme(true)),
+ glamour.WithWordWrap(p.width-10),
+ )
+ s, err := r.Render(content)
+ return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err
+ })
+
+ finalContent := styles.BaseStyle.
+ Width(p.contentViewPort.Width).
+ Render(renderedContent)
+ p.contentViewPort.SetContent(finalContent)
+ return p.styleViewport()
+ }
+ return ""
+}
- if p.isViewportFocus {
- contentStyle = contentStyle.BorderBackground(styles.Surface0)
- }
+func (p *permissionDialogCmp) renderEditContent() string {
+ if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok {
+ diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) {
+ return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
+ })
- contentFinal := contentStyle.Render(p.contentViewPort.View())
+ p.contentViewPort.SetContent(diff)
+ return p.styleViewport()
+ }
+ return ""
+}
- return lipgloss.JoinVertical(
- lipgloss.Top,
- headerContent,
- contentFinal,
- form,
- )
+func (p *permissionDialogCmp) renderWriteContent() string {
+ if pr, ok := p.permission.Params.(tools.WritePermissionsParams); ok {
+ // Use the cache for diff rendering
+ diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) {
+ return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
+ })
- case tools.WriteToolName:
- pr := p.permission.Params.(tools.WritePermissionsParams)
- headerParts = append(headerParts, keyStyle.Render("Content"))
- // Recreate header content with the updated headerParts
- headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
-
- // Set up viewport for the content
- p.contentViewPort.Width = p.width - 2 - 2
-
- // Calculate content height dynamically based on window size
- maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
- p.contentViewPort.Height = maxContentHeight
- diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
- if err != nil {
- diff = fmt.Sprintf("Error formatting diff: %v", err)
- }
p.contentViewPort.SetContent(diff)
+ return p.styleViewport()
+ }
+ return ""
+}
- // Style the viewport
- var contentBorder lipgloss.Border
- var borderColor lipgloss.TerminalColor
+func (p *permissionDialogCmp) renderFetchContent() string {
+ if pr, ok := p.permission.Params.(tools.FetchPermissionsParams); ok {
+ content := fmt.Sprintf("```bash\n%s\n```", pr.URL)
- if p.isViewportFocus {
- contentBorder = lipgloss.DoubleBorder()
- borderColor = styles.Blue
- } else {
- contentBorder = lipgloss.RoundedBorder()
- borderColor = styles.Flamingo
- }
-
- contentStyle := lipgloss.NewStyle().
- MarginTop(1).
- Padding(0, 1).
- Border(contentBorder).
- BorderForeground(borderColor)
+ // Use the cache for markdown rendering
+ renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) {
+ r, _ := glamour.NewTermRenderer(
+ glamour.WithStyles(styles.MarkdownTheme(true)),
+ glamour.WithWordWrap(p.width-10),
+ )
+ s, err := r.Render(content)
+ return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err
+ })
- if p.isViewportFocus {
- contentStyle = contentStyle.BorderBackground(styles.Surface0)
- }
+ p.contentViewPort.SetContent(renderedContent)
+ return p.styleViewport()
+ }
+ return ""
+}
- contentFinal := contentStyle.Render(p.contentViewPort.View())
+func (p *permissionDialogCmp) renderDefaultContent() string {
+ content := p.permission.Description
- return lipgloss.JoinVertical(
- lipgloss.Top,
- headerContent,
- contentFinal,
- form,
+ // Use the cache for markdown rendering
+ renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) {
+ r, _ := glamour.NewTermRenderer(
+			glamour.WithStyles(styles.MarkdownTheme(true)),
+ glamour.WithWordWrap(p.width-10),
)
+ s, err := r.Render(content)
+ return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err
+ })
- case tools.FetchToolName:
- pr := p.permission.Params.(tools.FetchPermissionsParams)
- headerParts = append(headerParts, keyStyle.Render("URL: "+pr.URL))
- content := p.permission.Description
+ p.contentViewPort.SetContent(renderedContent)
- renderedContent, _ := r.Render(content)
- p.contentViewPort.Width = p.width - 2 - 2
- p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
- p.contentViewPort.SetContent(renderedContent)
+ if renderedContent == "" {
+ return ""
+ }
- // Style the viewport
- contentStyle := lipgloss.NewStyle().
- MarginTop(1).
- Padding(0, 1).
- Border(lipgloss.RoundedBorder()).
- BorderForeground(styles.Flamingo)
+ return p.styleViewport()
+}
- contentFinal := contentStyle.Render(p.contentViewPort.View())
- if renderedContent == "" {
- contentFinal = ""
- }
+func (p *permissionDialogCmp) styleViewport() string {
+ contentStyle := lipgloss.NewStyle().
+ Background(styles.Background)
- return lipgloss.JoinVertical(
- lipgloss.Top,
- headerContent,
- contentFinal,
- form,
- )
+ return contentStyle.Render(p.contentViewPort.View())
+}
+func (p *permissionDialogCmp) render() string {
+ title := styles.BaseStyle.
+ Bold(true).
+ Width(p.width - 4).
+ Foreground(styles.PrimaryColor).
+ Render("Permission Required")
+ // Render header
+ headerContent := p.renderHeader()
+ // Render buttons
+ buttons := p.renderButtons()
+
+ // Calculate content height dynamically based on window size
+ p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(buttons) - 2 - lipgloss.Height(title)
+ p.contentViewPort.Width = p.width - 4
+
+ // Render content based on tool type
+ var contentFinal string
+ switch p.permission.ToolName {
+ case tools.BashToolName:
+ contentFinal = p.renderBashContent()
+ case tools.EditToolName:
+ contentFinal = p.renderEditContent()
+ case tools.WriteToolName:
+ contentFinal = p.renderWriteContent()
+ case tools.FetchToolName:
+ contentFinal = p.renderFetchContent()
default:
- content := p.permission.Description
-
- renderedContent, _ := r.Render(content)
- p.contentViewPort.Width = p.width - 2 - 2
- p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
- p.contentViewPort.SetContent(renderedContent)
-
- // Style the viewport
- contentStyle := lipgloss.NewStyle().
- MarginTop(1).
- Padding(0, 1).
- Border(lipgloss.RoundedBorder()).
- BorderForeground(styles.Flamingo)
+ contentFinal = p.renderDefaultContent()
+ }
- contentFinal := contentStyle.Render(p.contentViewPort.View())
- if renderedContent == "" {
- contentFinal = ""
- }
+ content := lipgloss.JoinVertical(
+ lipgloss.Top,
+ title,
+ styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(title))),
+ headerContent,
+ contentFinal,
+ buttons,
+ )
- return lipgloss.JoinVertical(
- lipgloss.Top,
- headerContent,
- contentFinal,
- form,
+ return styles.BaseStyle.
+ Padding(1, 0, 0, 1).
+ Border(lipgloss.RoundedBorder()).
+ BorderBackground(styles.Background).
+ BorderForeground(styles.ForgroundDim).
+ Width(p.width).
+ Height(p.height).
+ Render(
+ content,
)
- }
}
func (p *permissionDialogCmp) View() string {
return p.render()
}
-func (p *permissionDialogCmp) GetSize() (int, int) {
- return p.width, p.height
+func (p *permissionDialogCmp) BindingKeys() []key.Binding {
+	return layout.KeyMapToSlice(permissionsKeys)
}
-func (p *permissionDialogCmp) SetSize(width int, height int) {
- p.width = width
- p.height = height
- p.form = p.form.WithWidth(width)
+func (p *permissionDialogCmp) SetSize() {
+ if p.permission.ID == "" {
+ return
+ }
+ switch p.permission.ToolName {
+ case tools.BashToolName:
+ p.width = int(float64(p.windowSize.Width) * 0.4)
+ p.height = int(float64(p.windowSize.Height) * 0.3)
+ case tools.EditToolName:
+ p.width = int(float64(p.windowSize.Width) * 0.8)
+ p.height = int(float64(p.windowSize.Height) * 0.8)
+ case tools.WriteToolName:
+ p.width = int(float64(p.windowSize.Width) * 0.8)
+ p.height = int(float64(p.windowSize.Height) * 0.8)
+ case tools.FetchToolName:
+ p.width = int(float64(p.windowSize.Width) * 0.4)
+ p.height = int(float64(p.windowSize.Height) * 0.3)
+ default:
+ p.width = int(float64(p.windowSize.Width) * 0.7)
+ p.height = int(float64(p.windowSize.Height) * 0.5)
+ }
}
-func (p *permissionDialogCmp) BindingKeys() []key.Binding {
- return p.form.KeyBinds()
+func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) {
+ p.permission = permission
+ p.SetSize()
}
-func newPermissionDialogCmp(permission permission.PermissionRequest) PermissionDialog {
- // Create a note field for displaying the content
+// Helper to get or set cached diff content
+func (c *permissionDialogCmp) GetOrSetDiff(key string, generator func() (string, error)) string {
+ if cached, ok := c.diffCache[key]; ok {
+ return cached
+ }
- // Create select field for the permission options
- selectOption := huh.NewSelect[string]().
- Key("action").
- Options(
- huh.NewOption("Allow", string(PermissionAllow)),
- huh.NewOption("Allow for this session", string(PermissionAllowForSession)),
- huh.NewOption("Deny", string(PermissionDeny)),
- ).
- Title("Select an action")
+ content, err := generator()
+ if err != nil {
+ return fmt.Sprintf("Error formatting diff: %v", err)
+ }
- // Apply theme
- theme := styles.HuhTheme()
+ c.diffCache[key] = content
- // Setup form width and height
- form := huh.NewForm(huh.NewGroup(selectOption)).
- WithShowHelp(false).
- WithTheme(theme).
- WithShowErrors(false)
+ return content
+}
- // Focus the form for immediate interaction
- selectOption.Focus()
+// Helper to get or set cached markdown content
+func (c *permissionDialogCmp) GetOrSetMarkdown(key string, generator func() (string, error)) string {
+ if cached, ok := c.markdownCache[key]; ok {
+ return cached
+ }
- return &permissionDialogCmp{
- permission: permission,
- form: form,
- selectOption: selectOption,
+ content, err := generator()
+ if err != nil {
+ return fmt.Sprintf("Error rendering markdown: %v", err)
}
-}
-// NewPermissionDialogCmd creates a new permission dialog command
-func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd {
- permDialog := newPermissionDialogCmp(permission)
-
- // Create the dialog layout
- dialogPane := layout.NewSinglePane(
- permDialog.(*permissionDialogCmp),
- layout.WithSinglePaneBordered(true),
- layout.WithSinglePaneFocusable(true),
- layout.WithSinglePaneActiveColor(styles.Warning),
- layout.WithSinglePaneBorderText(map[layout.BorderPosition]string{
- layout.TopMiddleBorder: " Permission Required ",
- }),
- )
+ c.markdownCache[key] = content
- // Focus the dialog
- dialogPane.Focus()
- widthRatio := 0.7
- heightRatio := 0.6
- minWidth := 100
- minHeight := 30
+ return content
+}
- // Make the dialog size more appropriate for different tools
- switch permission.ToolName {
- case tools.BashToolName:
- // For bash commands, use a more compact dialog
- widthRatio = 0.7
- heightRatio = 0.4 // Reduced from 0.5
- minWidth = 100
- minHeight = 20 // Reduced from 30
+func NewPermissionDialogCmp() PermissionDialogCmp {
+ // Create viewport for content
+ contentViewport := viewport.New(0, 0)
+
+ return &permissionDialogCmp{
+ contentViewPort: contentViewport,
+ selectedOption: 0, // Default to "Allow"
+ diffCache: make(map[string]string),
+ markdownCache: make(map[string]string),
}
- // Return the dialog command
- return util.CmdHandler(core.DialogMsg{
- Content: dialogPane,
- WidthRatio: widthRatio,
- HeightRatio: heightRatio,
- MinWidth: minWidth,
- MinHeight: minHeight,
- })
}
@@ -1,28 +1,58 @@
package dialog
import (
+ "strings"
+
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
- "github.com/kujtimiihoxha/termai/internal/tui/components/core"
+ "github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
-
- "github.com/charmbracelet/huh"
)
const question = "Are you sure you want to quit?"
+type CloseQuitMsg struct{}
+
type QuitDialog interface {
tea.Model
- layout.Sizeable
layout.Bindings
}
type quitDialogCmp struct {
- form *huh.Form
- width int
- height int
+ selectedNo bool
+}
+
+type helpMapping struct {
+ LeftRight key.Binding
+ EnterSpace key.Binding
+ Yes key.Binding
+ No key.Binding
+ Tab key.Binding
+}
+
+var helpKeys = helpMapping{
+ LeftRight: key.NewBinding(
+ key.WithKeys("left", "right"),
+ key.WithHelp("←/→", "switch options"),
+ ),
+ EnterSpace: key.NewBinding(
+ key.WithKeys("enter", " "),
+ key.WithHelp("enter/space", "confirm"),
+ ),
+ Yes: key.NewBinding(
+ key.WithKeys("y", "Y"),
+ key.WithHelp("y/Y", "yes"),
+ ),
+ No: key.NewBinding(
+ key.WithKeys("n", "N"),
+ key.WithHelp("n/N", "no"),
+ ),
+ Tab: key.NewBinding(
+ key.WithKeys("tab"),
+ key.WithHelp("tab", "switch options"),
+ ),
}
func (q *quitDialogCmp) Init() tea.Cmd {
@@ -30,77 +60,73 @@ func (q *quitDialogCmp) Init() tea.Cmd {
}
func (q *quitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- var cmds []tea.Cmd
- form, cmd := q.form.Update(msg)
- if f, ok := form.(*huh.Form); ok {
- q.form = f
- cmds = append(cmds, cmd)
- }
-
- if q.form.State == huh.StateCompleted {
- v := q.form.GetBool("quit")
- if v {
+ switch msg := msg.(type) {
+ case tea.KeyMsg:
+ switch {
+ case key.Matches(msg, helpKeys.LeftRight) || key.Matches(msg, helpKeys.Tab):
+ q.selectedNo = !q.selectedNo
+ return q, nil
+ case key.Matches(msg, helpKeys.EnterSpace):
+ if !q.selectedNo {
+ return q, tea.Quit
+ }
+ return q, util.CmdHandler(CloseQuitMsg{})
+ case key.Matches(msg, helpKeys.Yes):
return q, tea.Quit
+ case key.Matches(msg, helpKeys.No):
+ return q, util.CmdHandler(CloseQuitMsg{})
}
- cmds = append(cmds, util.CmdHandler(core.DialogCloseMsg{}))
}
-
- return q, tea.Batch(cmds...)
+ return q, nil
}
func (q *quitDialogCmp) View() string {
- return q.form.View()
-}
+ yesStyle := styles.BaseStyle
+ noStyle := styles.BaseStyle
+ spacerStyle := styles.BaseStyle.Background(styles.Background)
+
+ if q.selectedNo {
+ noStyle = noStyle.Background(styles.PrimaryColor).Foreground(styles.Background)
+ yesStyle = yesStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ } else {
+ yesStyle = yesStyle.Background(styles.PrimaryColor).Foreground(styles.Background)
+ noStyle = noStyle.Background(styles.Background).Foreground(styles.PrimaryColor)
+ }
-func (q *quitDialogCmp) GetSize() (int, int) {
- return q.width, q.height
-}
+ yesButton := yesStyle.Padding(0, 1).Render("Yes")
+ noButton := noStyle.Padding(0, 1).Render("No")
+
+ buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, spacerStyle.Render(" "), noButton)
+
+ width := lipgloss.Width(question)
+ remainingWidth := width - lipgloss.Width(buttons)
+ if remainingWidth > 0 {
+ buttons = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + buttons
+ }
-func (q *quitDialogCmp) SetSize(width int, height int) {
- q.width = width
- q.height = height
- q.form = q.form.WithWidth(width).WithHeight(height)
+ content := styles.BaseStyle.Render(
+ lipgloss.JoinVertical(
+ lipgloss.Center,
+ question,
+ "",
+ buttons,
+ ),
+ )
+
+ return styles.BaseStyle.Padding(1, 2).
+ Border(lipgloss.RoundedBorder()).
+ BorderBackground(styles.Background).
+ BorderForeground(styles.ForgroundDim).
+ Width(lipgloss.Width(content) + 4).
+ Render(content)
}
func (q *quitDialogCmp) BindingKeys() []key.Binding {
- return q.form.KeyBinds()
+ return layout.KeyMapToSlice(helpKeys)
}
-func newQuitDialogCmp() QuitDialog {
- confirm := huh.NewConfirm().
- Title(question).
- Affirmative("Yes!").
- Key("quit").
- Negative("No.")
-
- theme := styles.HuhTheme()
- theme.Focused.FocusedButton = theme.Focused.FocusedButton.Background(styles.Warning)
- theme.Blurred.FocusedButton = theme.Blurred.FocusedButton.Background(styles.Warning)
- form := huh.NewForm(huh.NewGroup(confirm)).
- WithShowHelp(false).
- WithWidth(0).
- WithHeight(0).
- WithTheme(theme).
- WithShowErrors(false)
- confirm.Focus()
+func NewQuitCmp() QuitDialog {
return &quitDialogCmp{
- form: form,
+ selectedNo: true,
}
}
-
-func NewQuitDialogCmd() tea.Cmd {
- content := layout.NewSinglePane(
- newQuitDialogCmp().(*quitDialogCmp),
- layout.WithSinglePaneBordered(true),
- layout.WithSinglePaneFocusable(true),
- layout.WithSinglePaneActiveColor(styles.Warning),
- )
- content.Focus()
- return util.CmdHandler(core.DialogMsg{
- Content: content,
- WidthRatio: 0.2,
- HeightRatio: 0.1,
- MinWidth: 40,
- MinHeight: 5,
- })
-}
@@ -16,10 +16,8 @@ import (
type DetailComponent interface {
tea.Model
- layout.Focusable
layout.Sizeable
layout.Bindings
- layout.Bordered
}
type detailCmp struct {
@@ -16,22 +16,14 @@ import (
type TableComponent interface {
tea.Model
- layout.Focusable
layout.Sizeable
layout.Bindings
- layout.Bordered
}
type tableCmp struct {
table table.Model
}
-func (i *tableCmp) BorderText() map[layout.BorderPosition]string {
- return map[layout.BorderPosition]string{
- layout.TopLeftBorder: "Logs",
- }
-}
-
type selectedLogMsg logging.LogMessage
func (i *tableCmp) Init() tea.Cmd {
@@ -74,20 +66,6 @@ func (i *tableCmp) View() string {
return i.table.View()
}
-func (i *tableCmp) Blur() tea.Cmd {
- i.table.Blur()
- return nil
-}
-
-func (i *tableCmp) Focus() tea.Cmd {
- i.table.Focus()
- return nil
-}
-
-func (i *tableCmp) IsFocused() bool {
- return i.table.Focused()
-}
-
func (i *tableCmp) GetSize() (int, int) {
return i.table.Width(), i.table.Height()
}
@@ -1,201 +0,0 @@
-package repl
-
-import (
- "strings"
-
- "github.com/charmbracelet/bubbles/key"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/lipgloss"
- "github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/tui/layout"
- "github.com/kujtimiihoxha/termai/internal/tui/styles"
- "github.com/kujtimiihoxha/termai/internal/tui/util"
- "github.com/kujtimiihoxha/vimtea"
- "golang.org/x/net/context"
-)
-
-type EditorCmp interface {
- tea.Model
- layout.Focusable
- layout.Sizeable
- layout.Bordered
- layout.Bindings
-}
-
-type editorCmp struct {
- app *app.App
- editor vimtea.Editor
- editorMode vimtea.EditorMode
- sessionID string
- focused bool
- width int
- height int
- cancelMessage context.CancelFunc
-}
-
-type editorKeyMap struct {
- SendMessage key.Binding
- SendMessageI key.Binding
- CancelMessage key.Binding
- InsertMode key.Binding
- NormaMode key.Binding
- VisualMode key.Binding
- VisualLineMode key.Binding
-}
-
-var editorKeyMapValue = editorKeyMap{
- SendMessage: key.NewBinding(
- key.WithKeys("enter"),
- key.WithHelp("enter", "send message normal mode"),
- ),
- SendMessageI: key.NewBinding(
- key.WithKeys("ctrl+s"),
- key.WithHelp("ctrl+s", "send message insert mode"),
- ),
- CancelMessage: key.NewBinding(
- key.WithKeys("ctrl+x"),
- key.WithHelp("ctrl+x", "cancel current message"),
- ),
- InsertMode: key.NewBinding(
- key.WithKeys("i"),
- key.WithHelp("i", "insert mode"),
- ),
- NormaMode: key.NewBinding(
- key.WithKeys("esc"),
- key.WithHelp("esc", "normal mode"),
- ),
- VisualMode: key.NewBinding(
- key.WithKeys("v"),
- key.WithHelp("v", "visual mode"),
- ),
- VisualLineMode: key.NewBinding(
- key.WithKeys("V"),
- key.WithHelp("V", "visual line mode"),
- ),
-}
-
-func (m *editorCmp) Init() tea.Cmd {
- return m.editor.Init()
-}
-
-func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case vimtea.EditorModeMsg:
- m.editorMode = msg.Mode
- case SelectedSessionMsg:
- if msg.SessionID != m.sessionID {
- m.sessionID = msg.SessionID
- }
- }
- if m.IsFocused() {
- switch msg := msg.(type) {
- case tea.KeyMsg:
- switch {
- case key.Matches(msg, editorKeyMapValue.SendMessage):
- if m.editorMode == vimtea.ModeNormal {
- return m, m.Send()
- }
- case key.Matches(msg, editorKeyMapValue.SendMessageI):
- if m.editorMode == vimtea.ModeInsert {
- return m, m.Send()
- }
- case key.Matches(msg, editorKeyMapValue.CancelMessage):
- return m, m.Cancel()
- }
- }
- u, cmd := m.editor.Update(msg)
- m.editor = u.(vimtea.Editor)
- return m, cmd
- }
- return m, nil
-}
-
-func (m *editorCmp) Blur() tea.Cmd {
- m.focused = false
- return nil
-}
-
-func (m *editorCmp) BorderText() map[layout.BorderPosition]string {
- title := "New Message"
- if m.focused {
- title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
- }
- return map[layout.BorderPosition]string{
- layout.BottomLeftBorder: title,
- }
-}
-
-func (m *editorCmp) Focus() tea.Cmd {
- m.focused = true
- return m.editor.Tick()
-}
-
-func (m *editorCmp) GetSize() (int, int) {
- return m.width, m.height
-}
-
-func (m *editorCmp) IsFocused() bool {
- return m.focused
-}
-
-func (m *editorCmp) SetSize(width int, height int) {
- m.width = width
- m.height = height
- m.editor.SetSize(width, height)
-}
-
-func (m *editorCmp) Cancel() tea.Cmd {
- if m.cancelMessage == nil {
- return util.ReportWarn("No message to cancel")
- }
-
- m.cancelMessage()
- m.cancelMessage = nil
- return util.ReportWarn("Message cancelled")
-}
-
-func (m *editorCmp) Send() tea.Cmd {
- if m.cancelMessage != nil {
- return util.ReportWarn("Assistant is still working on the previous message")
- }
-
- messages, err := m.app.Messages.List(context.Background(), m.sessionID)
- if err != nil {
- return util.ReportError(err)
- }
- if hasUnfinishedMessages(messages) {
- return util.ReportWarn("Assistant is still working on the previous message")
- }
-
- content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
- if len(content) == 0 {
- return util.ReportWarn("Message is empty")
- }
- ctx, cancel := context.WithCancel(context.Background())
- m.cancelMessage = cancel
- go func() {
- defer cancel()
- m.app.CoderAgent.Generate(ctx, m.sessionID, content)
- m.cancelMessage = nil
- }()
-
- return m.editor.Reset()
-}
-
-func (m *editorCmp) View() string {
- return m.editor.View()
-}
-
-func (m *editorCmp) BindingKeys() []key.Binding {
- return layout.KeyMapToSlice(editorKeyMapValue)
-}
-
-func NewEditorCmp(app *app.App) EditorCmp {
- editor := vimtea.NewEditor(
- vimtea.WithFileName("message.md"),
- )
- return &editorCmp{
- app: app,
- editor: editor,
- }
-}
@@ -1,513 +0,0 @@
-package repl
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "sort"
- "strings"
-
- "github.com/charmbracelet/bubbles/key"
- "github.com/charmbracelet/bubbles/viewport"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/glamour"
- "github.com/charmbracelet/lipgloss"
- "github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/llm/agent"
- "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
- "github.com/kujtimiihoxha/termai/internal/message"
- "github.com/kujtimiihoxha/termai/internal/pubsub"
- "github.com/kujtimiihoxha/termai/internal/session"
- "github.com/kujtimiihoxha/termai/internal/tui/layout"
- "github.com/kujtimiihoxha/termai/internal/tui/styles"
-)
-
-type MessagesCmp interface {
- tea.Model
- layout.Focusable
- layout.Bordered
- layout.Sizeable
- layout.Bindings
-}
-
-type messagesCmp struct {
- app *app.App
- messages []message.Message
- selectedMsgIdx int // Index of the selected message
- session session.Session
- viewport viewport.Model
- mdRenderer *glamour.TermRenderer
- width int
- height int
- focused bool
- cachedView string
-}
-
-func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case pubsub.Event[message.Message]:
- if msg.Type == pubsub.CreatedEvent {
- if msg.Payload.SessionID == m.session.ID {
- m.messages = append(m.messages, msg.Payload)
- m.renderView()
- m.viewport.GotoBottom()
- }
- for _, v := range m.messages {
- for _, c := range v.ToolCalls() {
- // the message is being added to the session of a tool called
- if c.ID == msg.Payload.SessionID {
- m.renderView()
- m.viewport.GotoBottom()
- }
- }
- }
- } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
- for i, v := range m.messages {
- if v.ID == msg.Payload.ID {
- m.messages[i] = msg.Payload
- m.renderView()
- if i == len(m.messages)-1 {
- m.viewport.GotoBottom()
- }
- break
- }
- }
- }
- case pubsub.Event[session.Session]:
- if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
- m.session = msg.Payload
- }
- case SelectedSessionMsg:
- m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID)
- m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID)
- m.renderView()
- m.viewport.GotoBottom()
- }
- if m.focused {
- u, cmd := m.viewport.Update(msg)
- m.viewport = u
- return m, cmd
- }
- return m, nil
-}
-
-func borderColor(role message.MessageRole) lipgloss.TerminalColor {
- switch role {
- case message.Assistant:
- return styles.Mauve
- case message.User:
- return styles.Rosewater
- }
- return styles.Blue
-}
-
-func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
- role := ""
- icon := ""
- switch msgRole {
- case message.Assistant:
- role = "Assistant"
- icon = styles.BotIcon
- case message.User:
- role = "User"
- icon = styles.UserIcon
- }
- return map[layout.BorderPosition]string{
- layout.TopLeftBorder: lipgloss.NewStyle().
- Padding(0, 1).
- Bold(true).
- Foreground(styles.Crust).
- Background(borderColor(msgRole)).
- Render(fmt.Sprintf("%s %s ", role, icon)),
- layout.TopRightBorder: lipgloss.NewStyle().
- Padding(0, 1).
- Bold(true).
- Foreground(styles.Crust).
- Background(borderColor(msgRole)).
- Render(fmt.Sprintf("#%d ", currentMessage)),
- }
-}
-
-func hasUnfinishedMessages(messages []message.Message) bool {
- if len(messages) == 0 {
- return false
- }
- for _, msg := range messages {
- if !msg.IsFinished() {
- return true
- }
- }
- return false
-}
-
-func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
- allParts := []string{content}
-
- leftPaddingValue := 4
- connectorStyle := lipgloss.NewStyle().
- Foreground(styles.Peach).
- Bold(true)
-
- toolCallStyle := lipgloss.NewStyle().
- Border(lipgloss.RoundedBorder()).
- BorderForeground(styles.Peach).
- Width(m.width-leftPaddingValue-5).
- Padding(0, 1)
-
- toolResultStyle := lipgloss.NewStyle().
- Border(lipgloss.RoundedBorder()).
- BorderForeground(styles.Green).
- Width(m.width-leftPaddingValue-5).
- Padding(0, 1)
-
- leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
-
- runningStyle := lipgloss.NewStyle().
- Foreground(styles.Peach).
- Bold(true)
-
- renderTool := func(toolCall message.ToolCall) string {
- toolHeader := lipgloss.NewStyle().
- Bold(true).
- Foreground(styles.Blue).
- Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
-
- var paramLines []string
- var args map[string]interface{}
- var paramOrder []string
-
- json.Unmarshal([]byte(toolCall.Input), &args)
-
- for key := range args {
- paramOrder = append(paramOrder, key)
- }
- sort.Strings(paramOrder)
-
- for _, name := range paramOrder {
- value := args[name]
- paramName := lipgloss.NewStyle().
- Foreground(styles.Peach).
- Bold(true).
- Render(name)
-
- truncate := m.width - leftPaddingValue*2 - 10
- if len(fmt.Sprintf("%v", value)) > truncate {
- value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
- }
- paramValue := fmt.Sprintf("%v", value)
- paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue))
- }
-
- paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
-
- toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
- return toolCallStyle.Render(toolContent)
- }
-
- findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
- for _, msg := range messages {
- if msg.Role == message.Tool {
- for _, result := range msg.ToolResults() {
- if result.ToolCallID == toolCallID {
- return &result
- }
- }
- }
- }
- return nil
- }
-
- renderToolResult := func(result message.ToolResult) string {
- resultHeader := lipgloss.NewStyle().
- Bold(true).
- Foreground(styles.Green).
- Render(fmt.Sprintf("%s Result", styles.CheckIcon))
-
- // Use the same style for both header and border if it's an error
- borderColor := styles.Green
- if result.IsError {
- resultHeader = lipgloss.NewStyle().
- Bold(true).
- Foreground(styles.Red).
- Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
- borderColor = styles.Red
- }
-
- truncate := 200
- content := result.Content
- if len(content) > truncate {
- content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
- }
-
- resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content)
- return toolResultStyle.BorderForeground(borderColor).Render(resultContent)
- }
-
- connector := connectorStyle.Render("└─> Tool Calls:")
- allParts = append(allParts, connector)
-
- for _, toolCall := range tools {
- toolOutput := renderTool(toolCall)
- allParts = append(allParts, leftPadding.Render(toolOutput))
-
- result := findToolResult(toolCall.ID, futureMessages)
- if result != nil {
-
- resultOutput := renderToolResult(*result)
- allParts = append(allParts, leftPadding.Render(resultOutput))
-
- } else if toolCall.Name == agent.AgentToolName {
-
- runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
- allParts = append(allParts, leftPadding.Render(runningIndicator))
- taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID)
- for _, msg := range taskSessionMessages {
- if msg.Role == message.Assistant {
- for _, toolCall := range msg.ToolCalls() {
- toolHeader := lipgloss.NewStyle().
- Bold(true).
- Foreground(styles.Blue).
- Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
-
- var paramLines []string
- var args map[string]interface{}
- var paramOrder []string
-
- json.Unmarshal([]byte(toolCall.Input), &args)
-
- for key := range args {
- paramOrder = append(paramOrder, key)
- }
- sort.Strings(paramOrder)
-
- for _, name := range paramOrder {
- value := args[name]
- paramName := lipgloss.NewStyle().
- Foreground(styles.Peach).
- Bold(true).
- Render(name)
-
- truncate := 50
- if len(fmt.Sprintf("%v", value)) > truncate {
- value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
- }
- paramValue := fmt.Sprintf("%v", value)
- paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue))
- }
-
- paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
- toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
- toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent)
- allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput))
- }
- }
- }
-
- } else {
- runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
- allParts = append(allParts, " "+runningIndicator)
- }
- }
-
- for _, msg := range futureMessages {
- if msg.Content().String() != "" || msg.FinishReason() == "canceled" {
- break
- }
-
- for _, toolCall := range msg.ToolCalls() {
- toolOutput := renderTool(toolCall)
- allParts = append(allParts, " "+strings.ReplaceAll(toolOutput, "\n", "\n "))
-
- result := findToolResult(toolCall.ID, futureMessages)
- if result != nil {
- resultOutput := renderToolResult(*result)
- allParts = append(allParts, " "+strings.ReplaceAll(resultOutput, "\n", "\n "))
- } else {
- runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
- allParts = append(allParts, " "+runningIndicator)
- }
- }
- }
-
- return lipgloss.JoinVertical(lipgloss.Left, allParts...)
-}
-
-func (m *messagesCmp) renderView() {
- stringMessages := make([]string, 0)
- r, _ := glamour.NewTermRenderer(
- glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
- glamour.WithWordWrap(m.width-20),
- glamour.WithEmoji(),
- )
- textStyle := lipgloss.NewStyle().Width(m.width - 4)
- currentMessage := 1
- displayedMsgCount := 0 // Track the actual displayed messages count
-
- prevMessageWasUser := false
- for inx, msg := range m.messages {
- content := msg.Content().String()
- if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" {
- if msg.ReasoningContent().String() != "" && content == "" {
- content = msg.ReasoningContent().String()
- } else if content == "" {
- content = "..."
- }
- if msg.FinishReason() == "canceled" {
- content, _ = r.Render(content)
- content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled")
- } else {
- content, _ = r.Render(content)
- }
-
- isSelected := inx == m.selectedMsgIdx
-
- border := lipgloss.DoubleBorder()
- activeColor := borderColor(msg.Role)
-
- if isSelected {
- activeColor = styles.Primary // Use primary color for selected message
- }
-
- content = layout.Borderize(
- textStyle.Render(content),
- layout.BorderOptions{
- InactiveBorder: border,
- ActiveBorder: border,
- ActiveColor: activeColor,
- InactiveColor: borderColor(msg.Role),
- EmbeddedText: borderText(msg.Role, currentMessage),
- },
- )
- if len(msg.ToolCalls()) > 0 {
- content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:])
- }
- stringMessages = append(stringMessages, content)
- currentMessage++
- displayedMsgCount++
- }
- if msg.Role == message.User && msg.Content().String() != "" {
- prevMessageWasUser = true
- } else {
- prevMessageWasUser = false
- }
- }
- m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
-}
-
-func (m *messagesCmp) View() string {
- return lipgloss.NewStyle().Padding(1).Render(m.viewport.View())
-}
-
-func (m *messagesCmp) BindingKeys() []key.Binding {
- keys := layout.KeyMapToSlice(m.viewport.KeyMap)
-
- return keys
-}
-
-func (m *messagesCmp) Blur() tea.Cmd {
- m.focused = false
- return nil
-}
-
-func (m *messagesCmp) projectDiagnostics() string {
- errorDiagnostics := []protocol.Diagnostic{}
- warnDiagnostics := []protocol.Diagnostic{}
- hintDiagnostics := []protocol.Diagnostic{}
- infoDiagnostics := []protocol.Diagnostic{}
- for _, client := range m.app.LSPClients {
- for _, d := range client.GetDiagnostics() {
- for _, diag := range d {
- switch diag.Severity {
- case protocol.SeverityError:
- errorDiagnostics = append(errorDiagnostics, diag)
- case protocol.SeverityWarning:
- warnDiagnostics = append(warnDiagnostics, diag)
- case protocol.SeverityHint:
- hintDiagnostics = append(hintDiagnostics, diag)
- case protocol.SeverityInformation:
- infoDiagnostics = append(infoDiagnostics, diag)
- }
- }
- }
- }
-
- if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 {
- return "No diagnostics"
- }
-
- diagnostics := []string{}
-
- if len(errorDiagnostics) > 0 {
- errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
- diagnostics = append(diagnostics, errStr)
- }
- if len(warnDiagnostics) > 0 {
- warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
- diagnostics = append(diagnostics, warnStr)
- }
- if len(hintDiagnostics) > 0 {
- hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
- diagnostics = append(diagnostics, hintStr)
- }
- if len(infoDiagnostics) > 0 {
- infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
- diagnostics = append(diagnostics, infoStr)
- }
-
- return strings.Join(diagnostics, " ")
-}
-
-func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
- title := m.session.Title
- titleWidth := m.width / 2
- if len(title) > titleWidth {
- title = title[:titleWidth] + "..."
- }
- if m.focused {
- title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
- }
- borderTest := map[layout.BorderPosition]string{
- layout.TopLeftBorder: title,
- layout.BottomRightBorder: m.projectDiagnostics(),
- }
- if hasUnfinishedMessages(m.messages) {
- borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
- } else {
- borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
- }
-
- return borderTest
-}
-
-func (m *messagesCmp) Focus() tea.Cmd {
- m.focused = true
- return nil
-}
-
-func (m *messagesCmp) GetSize() (int, int) {
- return m.width, m.height
-}
-
-func (m *messagesCmp) IsFocused() bool {
- return m.focused
-}
-
-func (m *messagesCmp) SetSize(width int, height int) {
- m.width = width
- m.height = height
- m.viewport.Width = width - 2 // padding
- m.viewport.Height = height - 2 // padding
- m.renderView()
-}
-
-func (m *messagesCmp) Init() tea.Cmd {
- return nil
-}
-
-func NewMessagesCmp(app *app.App) MessagesCmp {
- return &messagesCmp{
- app: app,
- messages: []message.Message{},
- viewport: viewport.New(0, 0),
- }
-}
@@ -1,249 +0,0 @@
-package repl
-
-import (
- "context"
- "fmt"
- "strings"
-
- "github.com/charmbracelet/bubbles/key"
- "github.com/charmbracelet/bubbles/list"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/lipgloss"
- "github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/pubsub"
- "github.com/kujtimiihoxha/termai/internal/session"
- "github.com/kujtimiihoxha/termai/internal/tui/layout"
- "github.com/kujtimiihoxha/termai/internal/tui/styles"
- "github.com/kujtimiihoxha/termai/internal/tui/util"
-)
-
-type SessionsCmp interface {
- tea.Model
- layout.Sizeable
- layout.Focusable
- layout.Bordered
- layout.Bindings
-}
-type sessionsCmp struct {
- app *app.App
- list list.Model
- focused bool
-}
-
-type listItem struct {
- id, title, desc string
-}
-
-func (i listItem) Title() string { return i.title }
-func (i listItem) Description() string { return i.desc }
-func (i listItem) FilterValue() string { return i.title }
-
-type InsertSessionsMsg struct {
- sessions []session.Session
-}
-
-type SelectedSessionMsg struct {
- SessionID string
-}
-
-type sessionsKeyMap struct {
- Select key.Binding
-}
-
-var sessionKeyMapValue = sessionsKeyMap{
- Select: key.NewBinding(
- key.WithKeys("enter", " "),
- key.WithHelp("enter/space", "select session"),
- ),
-}
-
-func (i *sessionsCmp) Init() tea.Cmd {
- existing, err := i.app.Sessions.List(context.Background())
- if err != nil {
- return util.ReportError(err)
- }
- if len(existing) == 0 || existing[0].MessageCount > 0 {
- newSession, err := i.app.Sessions.Create(
- context.Background(),
- "New Session",
- )
- if err != nil {
- return util.ReportError(err)
- }
- existing = append([]session.Session{newSession}, existing...)
- }
- return tea.Batch(
- util.CmdHandler(InsertSessionsMsg{existing}),
- util.CmdHandler(SelectedSessionMsg{existing[0].ID}),
- )
-}
-
-func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case InsertSessionsMsg:
- items := make([]list.Item, len(msg.sessions))
- for i, s := range msg.sessions {
- items[i] = listItem{
- id: s.ID,
- title: s.Title,
- desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost),
- }
- }
- return i, i.list.SetItems(items)
- case pubsub.Event[session.Session]:
- if msg.Type == pubsub.CreatedEvent && msg.Payload.ParentSessionID == "" {
- // Check if the session is already in the list
- items := i.list.Items()
- for _, item := range items {
- s := item.(listItem)
- if s.id == msg.Payload.ID {
- return i, nil
- }
- }
- // insert the new session at the top of the list
- items = append([]list.Item{listItem{
- id: msg.Payload.ID,
- title: msg.Payload.Title,
- desc: formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost),
- }}, items...)
- return i, i.list.SetItems(items)
- } else if msg.Type == pubsub.UpdatedEvent {
- // update the session in the list
- items := i.list.Items()
- for idx, item := range items {
- s := item.(listItem)
- if s.id == msg.Payload.ID {
- s.title = msg.Payload.Title
- s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
- items[idx] = s
- break
- }
- }
- return i, i.list.SetItems(items)
- }
-
- case tea.KeyMsg:
- switch {
- case key.Matches(msg, sessionKeyMapValue.Select):
- selected := i.list.SelectedItem()
- if selected == nil {
- return i, nil
- }
- return i, util.CmdHandler(SelectedSessionMsg{selected.(listItem).id})
- }
- }
- if i.focused {
- u, cmd := i.list.Update(msg)
- i.list = u
- return i, cmd
- }
- return i, nil
-}
-
-func (i *sessionsCmp) View() string {
- return i.list.View()
-}
-
-func (i *sessionsCmp) Blur() tea.Cmd {
- i.focused = false
- return nil
-}
-
-func (i *sessionsCmp) Focus() tea.Cmd {
- i.focused = true
- return nil
-}
-
-func (i *sessionsCmp) GetSize() (int, int) {
- return i.list.Width(), i.list.Height()
-}
-
-func (i *sessionsCmp) IsFocused() bool {
- return i.focused
-}
-
-func (i *sessionsCmp) SetSize(width int, height int) {
- i.list.SetSize(width, height)
-}
-
-func (i *sessionsCmp) BorderText() map[layout.BorderPosition]string {
- totalCount := len(i.list.Items())
- itemsPerPage := i.list.Paginator.PerPage
- currentPage := i.list.Paginator.Page
-
- current := min(currentPage*itemsPerPage+itemsPerPage, totalCount)
-
- pageInfo := fmt.Sprintf(
- "%d-%d of %d",
- currentPage*itemsPerPage+1,
- current,
- totalCount,
- )
-
- title := "Sessions"
- if i.focused {
- title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
- }
- return map[layout.BorderPosition]string{
- layout.TopMiddleBorder: title,
- layout.BottomMiddleBorder: pageInfo,
- }
-}
-
-func (i *sessionsCmp) BindingKeys() []key.Binding {
- return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select)
-}
-
-func formatTokensAndCost(tokens int64, cost float64) string {
- // Format tokens in human-readable format (e.g., 110K, 1.2M)
- var formattedTokens string
- switch {
- case tokens >= 1_000_000:
- formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000)
- case tokens >= 1_000:
- formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000)
- default:
- formattedTokens = fmt.Sprintf("%d", tokens)
- }
-
- // Remove .0 suffix if present
- if strings.HasSuffix(formattedTokens, ".0K") {
- formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1)
- }
- if strings.HasSuffix(formattedTokens, ".0M") {
- formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1)
- }
-
- // Format cost with $ symbol and 2 decimal places
- formattedCost := fmt.Sprintf("$%.2f", cost)
-
- return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost)
-}
-
-func NewSessionsCmp(app *app.App) SessionsCmp {
- listDelegate := list.NewDefaultDelegate()
- defaultItemStyle := list.NewDefaultItemStyles()
- defaultItemStyle.SelectedTitle = defaultItemStyle.SelectedTitle.BorderForeground(styles.Secondary).Foreground(styles.Primary)
- defaultItemStyle.SelectedDesc = defaultItemStyle.SelectedDesc.BorderForeground(styles.Secondary).Foreground(styles.Primary)
-
- defaultStyle := list.DefaultStyles()
- defaultStyle.FilterPrompt = defaultStyle.FilterPrompt.Foreground(styles.Secondary)
- defaultStyle.FilterCursor = defaultStyle.FilterCursor.Foreground(styles.Flamingo)
-
- listDelegate.Styles = defaultItemStyle
-
- listComponent := list.New([]list.Item{}, listDelegate, 0, 0)
- listComponent.FilterInput.PromptStyle = defaultStyle.FilterPrompt
- listComponent.FilterInput.Cursor.Style = defaultStyle.FilterCursor
- listComponent.SetShowTitle(false)
- listComponent.SetShowPagination(false)
- listComponent.SetShowHelp(false)
- listComponent.SetShowStatusBar(false)
- listComponent.DisableQuitKeybindings()
-
- return &sessionsCmp{
- app: app,
- list: listComponent,
- focused: false,
- }
-}
@@ -5,6 +5,7 @@ import (
"strings"
"github.com/charmbracelet/lipgloss"
+ "github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
"github.com/mattn/go-runewidth"
"github.com/muesli/ansi"
@@ -45,13 +46,15 @@ func PlaceOverlay(
if shadow {
var shadowbg string = ""
shadowchar := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#333333")).
+ Background(styles.BackgroundDarker).
+ Foreground(styles.Background).
Render("░")
+ bgchar := styles.BaseStyle.Render(" ")
for i := 0; i <= fgHeight; i++ {
if i == 0 {
- shadowbg += " " + strings.Repeat(" ", fgWidth) + "\n"
+ shadowbg += bgchar + strings.Repeat(bgchar, fgWidth) + "\n"
} else {
- shadowbg += " " + strings.Repeat(shadowchar, fgWidth) + "\n"
+ shadowbg += bgchar + strings.Repeat(shadowchar, fgWidth) + "\n"
}
}
@@ -159,8 +162,6 @@ func max(a, b int) int {
return b
}
-
-
type whitespace struct {
style termenv.Style
chars string
@@ -10,6 +10,7 @@ import (
type SplitPaneLayout interface {
tea.Model
Sizeable
+ Bindings
SetLeftPanel(panel Container)
SetRightPanel(panel Container)
SetBottomPanel(panel Container)
@@ -37,7 +37,6 @@ var keyMap = ChatKeyMap{
}
func (p *chatPage) Init() tea.Cmd {
- // TODO: remove
cmds := []tea.Cmd{
p.layout.Init(),
}
@@ -48,9 +47,7 @@ func (p *chatPage) Init() tea.Cmd {
cmd := p.setSidebar()
cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd)
}
- return tea.Batch(
- cmds...,
- )
+ return tea.Batch(cmds...)
}
func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -68,6 +65,13 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
p.session = session.Session{}
p.clearSidebar()
return p, util.CmdHandler(chat.SessionClearedMsg{})
+ case key.Matches(msg, keyMap.Cancel):
+ if p.session.ID != "" {
+ // Cancel the current session's generation process
+ // This allows users to interrupt long-running operations
+ p.app.CoderAgent.Cancel(p.session.ID)
+ return p, nil
+ }
}
}
u, cmd := p.layout.Update(msg)
@@ -80,7 +84,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (p *chatPage) setSidebar() tea.Cmd {
sidebarContainer := layout.NewContainer(
- chat.NewSidebarCmp(p.session),
+ chat.NewSidebarCmp(p.session, p.app.History),
layout.WithPadding(1, 1, 1, 1),
)
p.layout.SetRightPanel(sidebarContainer)
@@ -111,14 +115,28 @@ func (p *chatPage) sendMessage(text string) tea.Cmd {
cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session)))
}
- p.app.CoderAgent.Generate(context.Background(), p.session.ID, text)
+ p.app.CoderAgent.Run(context.Background(), p.session.ID, text)
return tea.Batch(cmds...)
}
+func (p *chatPage) SetSize(width, height int) {
+ p.layout.SetSize(width, height)
+}
+
+func (p *chatPage) GetSize() (int, int) {
+ return p.layout.GetSize()
+}
+
func (p *chatPage) View() string {
return p.layout.View()
}
+func (p *chatPage) BindingKeys() []key.Binding {
+ bindings := layout.KeyMapToSlice(keyMap)
+ bindings = append(bindings, p.layout.BindingKeys()...)
+ return bindings
+}
+
func NewChatPage(app *app.App) tea.Model {
messagesContainer := layout.NewContainer(
chat.NewMessagesCmp(app),
@@ -126,7 +144,7 @@ func NewChatPage(app *app.App) tea.Model {
)
editorContainer := layout.NewContainer(
- chat.NewEditorCmp(),
+ chat.NewEditorCmp(app),
layout.WithBorder(true, false, false, false),
)
return &chatPage{
@@ -1,308 +0,0 @@
-package page
-
-import (
- "fmt"
- "os"
- "path/filepath"
- "strconv"
-
- "github.com/charmbracelet/bubbles/key"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/huh"
- "github.com/charmbracelet/lipgloss"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
- "github.com/kujtimiihoxha/termai/internal/tui/layout"
- "github.com/kujtimiihoxha/termai/internal/tui/styles"
- "github.com/kujtimiihoxha/termai/internal/tui/util"
- "github.com/spf13/viper"
-)
-
-var InitPage PageID = "init"
-
-type configSaved struct{}
-
-type initPage struct {
- form *huh.Form
- width int
- height int
- saved bool
- errorMsg string
- statusMsg string
- modelOpts []huh.Option[string]
- bigModel string
- smallModel string
- openAIKey string
- anthropicKey string
- groqKey string
- maxTokens string
- dataDir string
- agent string
-}
-
-func (i *initPage) Init() tea.Cmd {
- return i.form.Init()
-}
-
-func (i *initPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- var cmds []tea.Cmd
-
- switch msg := msg.(type) {
- case tea.WindowSizeMsg:
- i.width = msg.Width - 4 // Account for border
- i.height = msg.Height - 4
- i.form = i.form.WithWidth(i.width).WithHeight(i.height)
- return i, nil
-
- case configSaved:
- i.saved = true
- i.statusMsg = "Configuration saved successfully. Press any key to continue."
- return i, nil
- }
-
- if i.saved {
- switch msg.(type) {
- case tea.KeyMsg:
- return i, util.CmdHandler(PageChangeMsg{ID: ReplPage})
- }
- return i, nil
- }
-
- // Process the form
- form, cmd := i.form.Update(msg)
- if f, ok := form.(*huh.Form); ok {
- i.form = f
- cmds = append(cmds, cmd)
- }
-
- if i.form.State == huh.StateCompleted {
- // Save configuration to file
- configPath := filepath.Join(os.Getenv("HOME"), ".termai.yaml")
- maxTokens, _ := strconv.Atoi(i.maxTokens)
- config := map[string]any{
- "models": map[string]string{
- "big": i.bigModel,
- "small": i.smallModel,
- },
- "providers": map[string]any{
- "openai": map[string]string{
- "key": i.openAIKey,
- },
- "anthropic": map[string]string{
- "key": i.anthropicKey,
- },
- "groq": map[string]string{
- "key": i.groqKey,
- },
- "common": map[string]int{
- "max_tokens": maxTokens,
- },
- },
- "data": map[string]string{
- "dir": i.dataDir,
- },
- "agents": map[string]string{
- "default": i.agent,
- },
- "log": map[string]string{
- "level": "info",
- },
- }
-
- // Write config to viper
- for k, v := range config {
- viper.Set(k, v)
- }
-
- // Save configuration
- err := viper.WriteConfigAs(configPath)
- if err != nil {
- i.errorMsg = fmt.Sprintf("Failed to save configuration: %s", err)
- return i, nil
- }
-
- // Return to main page
- return i, util.CmdHandler(configSaved{})
- }
-
- return i, tea.Batch(cmds...)
-}
-
-func (i *initPage) View() string {
- if i.saved {
- return lipgloss.NewStyle().
- Width(i.width).
- Height(i.height).
- Align(lipgloss.Center, lipgloss.Center).
- Render(lipgloss.JoinVertical(
- lipgloss.Center,
- lipgloss.NewStyle().Foreground(styles.Green).Render("✓ Configuration Saved"),
- "",
- lipgloss.NewStyle().Foreground(styles.Blue).Render(i.statusMsg),
- ))
- }
-
- view := i.form.View()
- if i.errorMsg != "" {
- errorBox := lipgloss.NewStyle().
- Padding(1).
- Border(lipgloss.RoundedBorder()).
- BorderForeground(styles.Red).
- Width(i.width - 4).
- Render(i.errorMsg)
- view = lipgloss.JoinVertical(lipgloss.Left, errorBox, view)
- }
- return view
-}
-
-func (i *initPage) GetSize() (int, int) {
- return i.width, i.height
-}
-
-func (i *initPage) SetSize(width int, height int) {
- i.width = width
- i.height = height
- i.form = i.form.WithWidth(width).WithHeight(height)
-}
-
-func (i *initPage) BindingKeys() []key.Binding {
- if i.saved {
- return []key.Binding{
- key.NewBinding(
- key.WithKeys("enter", "space", "esc"),
- key.WithHelp("any key", "continue"),
- ),
- }
- }
- return i.form.KeyBinds()
-}
-
-func NewInitPage() tea.Model {
- // Create model options
- var modelOpts []huh.Option[string]
- for id, model := range models.SupportedModels {
- modelOpts = append(modelOpts, huh.NewOption(model.Name, string(id)))
- }
-
- // Create agent options
- agentOpts := []huh.Option[string]{
- huh.NewOption("Coder", "coder"),
- huh.NewOption("Assistant", "assistant"),
- }
-
- // Init page with form
- initModel := &initPage{
- modelOpts: modelOpts,
- bigModel: string(models.Claude37Sonnet),
- smallModel: string(models.Claude37Sonnet),
- maxTokens: "4000",
- dataDir: ".termai",
- agent: "coder",
- }
-
- // API Keys group
- apiKeysGroup := huh.NewGroup(
- huh.NewNote().
- Title("API Keys").
- Description("You need to provide at least one API key to use termai"),
-
- huh.NewInput().
- Title("OpenAI API Key").
- Placeholder("sk-...").
- Key("openai_key").
- Value(&initModel.openAIKey),
-
- huh.NewInput().
- Title("Anthropic API Key").
- Placeholder("sk-ant-...").
- Key("anthropic_key").
- Value(&initModel.anthropicKey),
-
- huh.NewInput().
- Title("Groq API Key").
- Placeholder("gsk_...").
- Key("groq_key").
- Value(&initModel.groqKey),
- )
-
- // Model configuration group
- modelsGroup := huh.NewGroup(
- huh.NewNote().
- Title("Model Configuration").
- Description("Select which models to use"),
-
- huh.NewSelect[string]().
- Title("Big Model").
- Options(modelOpts...).
- Key("big_model").
- Value(&initModel.bigModel),
-
- huh.NewSelect[string]().
- Title("Small Model").
- Options(modelOpts...).
- Key("small_model").
- Value(&initModel.smallModel),
-
- huh.NewInput().
- Title("Max Tokens").
- Placeholder("4000").
- Key("max_tokens").
- CharLimit(5).
- Validate(func(s string) error {
- var n int
- _, err := fmt.Sscanf(s, "%d", &n)
- if err != nil || n <= 0 {
- return fmt.Errorf("must be a positive number")
- }
- initModel.maxTokens = s
- return nil
- }).
- Value(&initModel.maxTokens),
- )
-
- // General settings group
- generalGroup := huh.NewGroup(
- huh.NewNote().
- Title("General Settings").
- Description("Configure general termai settings"),
-
- huh.NewInput().
- Title("Data Directory").
- Placeholder(".termai").
- Key("data_dir").
- Value(&initModel.dataDir),
-
- huh.NewSelect[string]().
- Title("Default Agent").
- Options(agentOpts...).
- Key("agent").
- Value(&initModel.agent),
-
- huh.NewConfirm().
- Title("Save Configuration").
- Affirmative("Save").
- Negative("Cancel"),
- )
-
- // Create form with theme
- form := huh.NewForm(
- apiKeysGroup,
- modelsGroup,
- generalGroup,
- ).WithTheme(styles.HuhTheme()).
- WithShowHelp(true).
- WithShowErrors(true)
-
- // Set the form in the model
- initModel.form = form
-
- return layout.NewSinglePane(
- initModel,
- layout.WithSinglePaneFocusable(true),
- layout.WithSinglePaneBordered(true),
- layout.WithSinglePaneBorderText(
- map[layout.BorderPosition]string{
- layout.TopMiddleBorder: "Welcome to termai - Initial Setup",
- },
- ),
- )
-}
@@ -8,6 +8,23 @@ import (
var LogsPage PageID = "logs"
+type logsPage struct {
+ table logs.TableComponent
+ details logs.DetailComponent
+}
+
+func (p *logsPage) Init() tea.Cmd {
+ return nil
+}
+
+func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ return p, nil
+}
+
+func (p *logsPage) View() string {
+ return p.table.View() + "\n" + p.details.View()
+}
+
func NewLogsPage() tea.Model {
return layout.NewBentoLayout(
layout.BentoPanes{
@@ -1,21 +0,0 @@
-package page
-
-import (
- tea "github.com/charmbracelet/bubbletea"
- "github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/tui/components/repl"
- "github.com/kujtimiihoxha/termai/internal/tui/layout"
-)
-
-var ReplPage PageID = "repl"
-
-func NewReplPage(app *app.App) tea.Model {
- return layout.NewBentoLayout(
- layout.BentoPanes{
- layout.BentoLeftPane: repl.NewSessionsCmp(app),
- layout.BentoRightTopPane: repl.NewMessagesCmp(app),
- layout.BentoRightBottomPane: repl.NewEditorCmp(app),
- },
- layout.WithBentoLayoutCurrentPane(layout.BentoRightBottomPane),
- )
-}
@@ -1,8 +1,6 @@
package tui
import (
- "context"
-
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
@@ -12,47 +10,41 @@ import (
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
"github.com/kujtimiihoxha/termai/internal/tui/components/dialog"
- "github.com/kujtimiihoxha/termai/internal/tui/components/repl"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/page"
"github.com/kujtimiihoxha/termai/internal/tui/util"
- "github.com/kujtimiihoxha/vimtea"
)
type keyMap struct {
- Logs key.Binding
- Return key.Binding
- Back key.Binding
- Quit key.Binding
- Help key.Binding
+ Logs key.Binding
+ Quit key.Binding
+ Help key.Binding
}
var keys = keyMap{
Logs: key.NewBinding(
- key.WithKeys("L"),
- key.WithHelp("L", "logs"),
- ),
- Return: key.NewBinding(
- key.WithKeys("esc"),
- key.WithHelp("esc", "close"),
- ),
- Back: key.NewBinding(
- key.WithKeys("backspace"),
- key.WithHelp("backspace", "back"),
+ key.WithKeys("ctrl+l"),
+ key.WithHelp("ctrl+L", "logs"),
),
+
Quit: key.NewBinding(
- key.WithKeys("ctrl+c", "q"),
- key.WithHelp("ctrl+c/q", "quit"),
+ key.WithKeys("ctrl+c"),
+ key.WithHelp("ctrl+c", "quit"),
),
Help: key.NewBinding(
- key.WithKeys("?"),
- key.WithHelp("?", "toggle help"),
+ key.WithKeys("ctrl+_"),
+ key.WithHelp("ctrl+?", "toggle help"),
),
}
-var replKeyMap = key.NewBinding(
- key.WithKeys("N"),
- key.WithHelp("N", "new session"),
+var returnKey = key.NewBinding(
+ key.WithKeys("esc"),
+ key.WithHelp("esc", "close"),
+)
+
+var logsKeyReturnKey = key.NewBinding(
+ key.WithKeys("backspace"),
+ key.WithHelp("backspace", "go back"),
)
type appModel struct {
@@ -62,18 +54,30 @@ type appModel struct {
pages map[page.PageID]tea.Model
loadedPages map[page.PageID]bool
status tea.Model
- help core.HelpCmp
- dialog core.DialogCmp
app *app.App
- dialogVisible bool
- editorMode vimtea.EditorMode
- showHelp bool
+
+ showPermissions bool
+ permissions dialog.PermissionDialogCmp
+
+ showHelp bool
+ help dialog.HelpCmp
+
+ showQuit bool
+ quit dialog.QuitDialog
}
func (a appModel) Init() tea.Cmd {
+ var cmds []tea.Cmd
cmd := a.pages[a.currentPage].Init()
a.loadedPages[a.currentPage] = true
- return cmd
+ cmds = append(cmds, cmd)
+ cmd = a.status.Init()
+ cmds = append(cmds, cmd)
+ cmd = a.quit.Init()
+ cmds = append(cmds, cmd)
+ cmd = a.help.Init()
+ cmds = append(cmds, cmd)
+ return tea.Batch(cmds...)
}
func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -81,22 +85,20 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.WindowSizeMsg:
- var cmds []tea.Cmd
msg.Height -= 1 // Make space for the status bar
a.width, a.height = msg.Width, msg.Height
a.status, _ = a.status.Update(msg)
-
- uh, _ := a.help.Update(msg)
- a.help = uh.(core.HelpCmp)
-
- p, cmd := a.pages[a.currentPage].Update(msg)
+ a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg)
cmds = append(cmds, cmd)
- a.pages[a.currentPage] = p
- d, cmd := a.dialog.Update(msg)
- cmds = append(cmds, cmd)
- a.dialog = d.(core.DialogCmp)
+ prm, permCmd := a.permissions.Update(msg)
+ a.permissions = prm.(dialog.PermissionDialogCmp)
+ cmds = append(cmds, permCmd)
+
+ help, helpCmd := a.help.Update(msg)
+ a.help = help.(dialog.HelpCmp)
+ cmds = append(cmds, helpCmd)
return a, tea.Batch(cmds...)
@@ -141,7 +143,9 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Permission
case pubsub.Event[permission.PermissionRequest]:
- return a, dialog.NewPermissionDialogCmd(msg.Payload)
+ a.showPermissions = true
+ a.permissions.SetPermissions(msg.Payload)
+ return a, nil
case dialog.PermissionResponseMsg:
switch msg.Action {
case dialog.PermissionAllow:
@@ -151,91 +155,71 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case dialog.PermissionDeny:
a.app.Permissions.Deny(msg.Permission)
}
-
- // Dialog
- case core.DialogMsg:
- d, cmd := a.dialog.Update(msg)
- a.dialog = d.(core.DialogCmp)
- a.dialogVisible = true
- return a, cmd
- case core.DialogCloseMsg:
- d, cmd := a.dialog.Update(msg)
- a.dialog = d.(core.DialogCmp)
- a.dialogVisible = false
- return a, cmd
-
- // Editor
- case vimtea.EditorModeMsg:
- a.editorMode = msg.Mode
+ a.showPermissions = false
+ return a, nil
case page.PageChangeMsg:
return a, a.moveToPage(msg.ID)
+
+ case dialog.CloseQuitMsg:
+ a.showQuit = false
+ return a, nil
+
case tea.KeyMsg:
- if a.editorMode == vimtea.ModeNormal {
- switch {
- case key.Matches(msg, keys.Quit):
- return a, dialog.NewQuitDialogCmd()
- case key.Matches(msg, keys.Back):
- if a.previousPage != "" {
- return a, a.moveToPage(a.previousPage)
- }
- case key.Matches(msg, keys.Return):
- if a.showHelp {
- a.ToggleHelp()
- return a, nil
- }
- case key.Matches(msg, replKeyMap):
- if a.currentPage == page.ReplPage {
- sessions, err := a.app.Sessions.List(context.Background())
- if err != nil {
- return a, util.CmdHandler(util.ReportError(err))
- }
- lastSession := sessions[0]
- if lastSession.MessageCount == 0 {
- return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID})
- }
- s, err := a.app.Sessions.Create(context.Background(), "New Session")
- if err != nil {
- return a, util.CmdHandler(util.ReportError(err))
- }
- return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID})
- }
- // case key.Matches(msg, keys.Logs):
- // return a, a.moveToPage(page.LogsPage)
- case msg.String() == "O":
- return a, a.moveToPage(page.ReplPage)
- case key.Matches(msg, keys.Help):
- a.ToggleHelp()
+ switch {
+ case key.Matches(msg, keys.Quit):
+ a.showQuit = !a.showQuit
+ if a.showHelp {
+ a.showHelp = false
+ }
+ return a, nil
+ case key.Matches(msg, logsKeyReturnKey):
+ if a.currentPage == page.LogsPage {
+ return a, a.moveToPage(page.ChatPage)
+ }
+ case key.Matches(msg, returnKey):
+ if a.showQuit {
+ a.showQuit = !a.showQuit
+ return a, nil
+ }
+ if a.showHelp {
+ a.showHelp = !a.showHelp
+ return a, nil
+ }
+ case key.Matches(msg, keys.Logs):
+ return a, a.moveToPage(page.LogsPage)
+ case key.Matches(msg, keys.Help):
+ if a.showQuit {
return a, nil
}
+ a.showHelp = !a.showHelp
+ return a, nil
}
}
- if a.dialogVisible {
- d, cmd := a.dialog.Update(msg)
- a.dialog = d.(core.DialogCmp)
- cmds = append(cmds, cmd)
- return a, tea.Batch(cmds...)
+ if a.showQuit {
+ q, quitCmd := a.quit.Update(msg)
+ a.quit = q.(dialog.QuitDialog)
+ cmds = append(cmds, quitCmd)
+ // Only block key messages; send all other messages down
+ if _, ok := msg.(tea.KeyMsg); ok {
+ return a, tea.Batch(cmds...)
+ }
+ }
+ if a.showPermissions {
+ d, permissionsCmd := a.permissions.Update(msg)
+ a.permissions = d.(dialog.PermissionDialogCmp)
+ cmds = append(cmds, permissionsCmd)
+ // Only block key messages; send all other messages down
+ if _, ok := msg.(tea.KeyMsg); ok {
+ return a, tea.Batch(cmds...)
+ }
}
a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg)
cmds = append(cmds, cmd)
return a, tea.Batch(cmds...)
}
-func (a *appModel) ToggleHelp() {
- if a.showHelp {
- a.showHelp = false
- a.height += a.help.Height()
- } else {
- a.showHelp = true
- a.height -= a.help.Height()
- }
-
- if sizable, ok := a.pages[a.currentPage].(layout.Sizeable); ok {
- sizable.SetSize(a.width, a.height)
- }
-}
-
func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd {
var cmd tea.Cmd
if _, ok := a.loadedPages[pageID]; !ok {
@@ -256,27 +240,55 @@ func (a appModel) View() string {
a.pages[a.currentPage].View(),
}
+ components = append(components, a.status.View())
+
+ appView := lipgloss.JoinVertical(lipgloss.Top, components...)
+
+ if a.showPermissions {
+ overlay := a.permissions.View()
+ row := lipgloss.Height(appView) / 2
+ row -= lipgloss.Height(overlay) / 2
+ col := lipgloss.Width(appView) / 2
+ col -= lipgloss.Width(overlay) / 2
+ appView = layout.PlaceOverlay(
+ col,
+ row,
+ overlay,
+ appView,
+ true,
+ )
+ }
+
if a.showHelp {
bindings := layout.KeyMapToSlice(keys)
if p, ok := a.pages[a.currentPage].(layout.Bindings); ok {
bindings = append(bindings, p.BindingKeys()...)
}
- if a.dialogVisible {
- bindings = append(bindings, a.dialog.BindingKeys()...)
+ if a.showPermissions {
+ bindings = append(bindings, a.permissions.BindingKeys()...)
}
- if a.currentPage == page.ReplPage {
- bindings = append(bindings, replKeyMap)
+ if a.currentPage == page.LogsPage {
+ bindings = append(bindings, logsKeyReturnKey)
}
- a.help.SetBindings(bindings)
- components = append(components, a.help.View())
- }
- components = append(components, a.status.View())
+ a.help.SetBindings(bindings)
- appView := lipgloss.JoinVertical(lipgloss.Top, components...)
+ overlay := a.help.View()
+ row := lipgloss.Height(appView) / 2
+ row -= lipgloss.Height(overlay) / 2
+ col := lipgloss.Width(appView) / 2
+ col -= lipgloss.Width(overlay) / 2
+ appView = layout.PlaceOverlay(
+ col,
+ row,
+ overlay,
+ appView,
+ true,
+ )
+ }
- if a.dialogVisible {
- overlay := a.dialog.View()
+ if a.showQuit {
+ overlay := a.quit.View()
row := lipgloss.Height(appView) / 2
row -= lipgloss.Height(overlay) / 2
col := lipgloss.Width(appView) / 2
@@ -289,30 +301,23 @@ func (a appModel) View() string {
true,
)
}
+
return appView
}
func New(app *app.App) tea.Model {
- // homedir, _ := os.UserHomeDir()
- // configPath := filepath.Join(homedir, ".termai.yaml")
- //
startPage := page.ChatPage
- // if _, err := os.Stat(configPath); os.IsNotExist(err) {
- // startPage = page.InitPage
- // }
-
return &appModel{
currentPage: startPage,
loadedPages: make(map[page.PageID]bool),
- status: core.NewStatusCmp(),
- help: core.NewHelpCmp(),
- dialog: core.NewDialogCmp(),
+ status: core.NewStatusCmp(app.LSPClients),
+ help: dialog.NewHelpCmp(),
+ quit: dialog.NewQuitCmp(),
+ permissions: dialog.NewPermissionDialogCmp(),
app: app,
pages: map[page.PageID]tea.Model{
page.ChatPage: page.NewChatPage(app),
page.LogsPage: page.NewLogsPage(),
- page.InitPage: page.NewInitPage(),
- page.ReplPage: page.NewReplPage(app),
},
}
}
@@ -2,8 +2,15 @@ package main
import (
"github.com/kujtimiihoxha/termai/cmd"
+ "github.com/kujtimiihoxha/termai/internal/logging"
)
func main() {
+ // Set up panic recovery for the main function
+ defer logging.RecoverPanic("main", func() {
+ // Perform any necessary cleanup before exit
+ logging.ErrorPersist("Application terminated due to unhandled panic")
+ })
+
cmd.Execute()
}