root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"sync"
  8	"time"
  9
 10	tea "github.com/charmbracelet/bubbletea"
 11	"github.com/kujtimiihoxha/termai/internal/app"
 12	"github.com/kujtimiihoxha/termai/internal/config"
 13	"github.com/kujtimiihoxha/termai/internal/db"
 14	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 15	"github.com/kujtimiihoxha/termai/internal/logging"
 16	"github.com/kujtimiihoxha/termai/internal/pubsub"
 17	"github.com/kujtimiihoxha/termai/internal/tui"
 18	zone "github.com/lrstanley/bubblezone"
 19	"github.com/spf13/cobra"
 20)
 21
 22var rootCmd = &cobra.Command{
 23	Use:   "termai",
 24	Short: "A terminal ai assistant",
 25	Long:  `A terminal ai assistant`,
 26	RunE: func(cmd *cobra.Command, args []string) error {
 27		// If the help flag is set, show the help message
 28		if cmd.Flag("help").Changed {
 29			cmd.Help()
 30			return nil
 31		}
 32
 33		// Load the config
 34		debug, _ := cmd.Flags().GetBool("debug")
 35		cwd, _ := cmd.Flags().GetString("cwd")
 36		if cwd != "" {
 37			err := os.Chdir(cwd)
 38			if err != nil {
 39				return fmt.Errorf("failed to change directory: %v", err)
 40			}
 41		}
 42		if cwd == "" {
 43			c, err := os.Getwd()
 44			if err != nil {
 45				return fmt.Errorf("failed to get current working directory: %v", err)
 46			}
 47			cwd = c
 48		}
 49		_, err := config.Load(cwd, debug)
 50		if err != nil {
 51			return err
 52		}
 53
 54		// Connect DB, this will also run migrations
 55		conn, err := db.Connect()
 56		if err != nil {
 57			return err
 58		}
 59
 60		// Create main context for the application
 61		ctx, cancel := context.WithCancel(context.Background())
 62		defer cancel()
 63
 64		app, err := app.New(ctx, conn)
 65		if err != nil {
 66			logging.Error("Failed to create app: %v", err)
 67			return err
 68		}
 69
 70		// Set up the TUI
 71		zone.NewGlobal()
 72		program := tea.NewProgram(
 73			tui.New(app),
 74			tea.WithAltScreen(),
 75			tea.WithMouseCellMotion(),
 76		)
 77
 78		// Initialize MCP tools in the background
 79		initMCPTools(ctx, app)
 80
 81		// Setup the subscriptions, this will send services events to the TUI
 82		ch, cancelSubs := setupSubscriptions(app)
 83
 84		// Create a context for the TUI message handler
 85		tuiCtx, tuiCancel := context.WithCancel(ctx)
 86		var tuiWg sync.WaitGroup
 87		tuiWg.Add(1)
 88
 89		// Set up message handling for the TUI
 90		go func() {
 91			defer tuiWg.Done()
 92			defer func() {
 93				if r := recover(); r != nil {
 94					logging.Error("Panic in TUI message handling: %v", r)
 95					attemptTUIRecovery(program)
 96				}
 97			}()
 98
 99			for {
100				select {
101				case <-tuiCtx.Done():
102					logging.Info("TUI message handler shutting down")
103					return
104				case msg, ok := <-ch:
105					if !ok {
106						logging.Info("TUI message channel closed")
107						return
108					}
109					program.Send(msg)
110				}
111			}
112		}()
113
114		// Cleanup function for when the program exits
115		cleanup := func() {
116			// Shutdown the app
117			app.Shutdown()
118
119			// Cancel subscriptions first
120			cancelSubs()
121
122			// Then cancel TUI message handler
123			tuiCancel()
124
125			// Wait for TUI message handler to finish
126			tuiWg.Wait()
127
128			logging.Info("All goroutines cleaned up")
129		}
130
131		// Run the TUI
132		result, err := program.Run()
133		cleanup()
134
135		if err != nil {
136			logging.Error("TUI error: %v", err)
137			return fmt.Errorf("TUI error: %v", err)
138		}
139
140		logging.Info("TUI exited with result: %v", result)
141		return nil
142	},
143}
144
145// attemptTUIRecovery tries to recover the TUI after a panic
146func attemptTUIRecovery(program *tea.Program) {
147	logging.Info("Attempting to recover TUI after panic")
148
149	// We could try to restart the TUI or gracefully exit
150	// For now, we'll just quit the program to avoid further issues
151	program.Quit()
152}
153
154func initMCPTools(ctx context.Context, app *app.App) {
155	go func() {
156		defer func() {
157			if r := recover(); r != nil {
158				logging.Error("Panic in MCP goroutine: %v", r)
159			}
160		}()
161
162		// Create a context with timeout for the initial MCP tools fetch
163		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
164		defer cancel()
165
166		// Set this up once with proper error handling
167		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
168		logging.Info("MCP message handling goroutine exiting")
169	}()
170}
171
172func setupSubscriber[T any](
173	ctx context.Context,
174	wg *sync.WaitGroup,
175	name string,
176	subscriber func(context.Context) <-chan pubsub.Event[T],
177	outputCh chan<- tea.Msg,
178) {
179	wg.Add(1)
180	go func() {
181		defer wg.Done()
182		defer func() {
183			if r := recover(); r != nil {
184				logging.Error("Panic in %s subscription goroutine: %v", name, r)
185			}
186		}()
187
188		for {
189			select {
190			case event, ok := <-subscriber(ctx):
191				if !ok {
192					logging.Info("%s subscription channel closed", name)
193					return
194				}
195
196				// Convert generic event to tea.Msg if needed
197				var msg tea.Msg = event
198
199				// Non-blocking send with timeout to prevent deadlocks
200				select {
201				case outputCh <- msg:
202				case <-time.After(500 * time.Millisecond):
203					logging.Warn("%s message dropped due to slow consumer", name)
204				case <-ctx.Done():
205					logging.Info("%s subscription cancelled", name)
206					return
207				}
208			case <-ctx.Done():
209				logging.Info("%s subscription cancelled", name)
210				return
211			}
212		}
213	}()
214}
215
216func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
217	ch := make(chan tea.Msg, 100)
218	// Add a buffer to prevent blocking
219	wg := sync.WaitGroup{}
220	ctx, cancel := context.WithCancel(context.Background())
221	// Setup each subscription using the helper
222	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
223	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
224	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
225	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
226
227	// Return channel and a cleanup function
228	cleanupFunc := func() {
229		logging.Info("Cancelling all subscriptions")
230		cancel() // Signal all goroutines to stop
231
232		// Wait with a timeout for all goroutines to complete
233		waitCh := make(chan struct{})
234		go func() {
235			wg.Wait()
236			close(waitCh)
237		}()
238
239		select {
240		case <-waitCh:
241			logging.Info("All subscription goroutines completed successfully")
242		case <-time.After(5 * time.Second):
243			logging.Warn("Timed out waiting for some subscription goroutines to complete")
244		}
245
246		close(ch) // Safe to close after all writers are done or timed out
247	}
248	return ch, cleanupFunc
249}
250
251func Execute() {
252	err := rootCmd.Execute()
253	if err != nil {
254		os.Exit(1)
255	}
256}
257
258func init() {
259	rootCmd.Flags().BoolP("help", "h", false, "Help")
260	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
261	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
262}