diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 041cff490a59f1de51505e833cc7ee7866aa7644..d2ff6454e9a7135ee9404ef665495772b90ba86c 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -387,6 +387,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso session, err := client.Connect(mcpCtx, transport, nil) if err != nil { + err = maybeStdioErr(err, transport) updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error starting mcp client", "error", err, "name", name) cancel() @@ -398,6 +399,27 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso return session, nil } +// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail +// to parse, and the cli will then close it, causing the EOF error. +// so, if we got an EOF err, and the transport is STDIO, we try to exec it +// again with a timeout and collect the output so we can add details to the +// error. +// this happens particularly when starting things with npx, e.g. if node can't +// be found or some other error like that. +func maybeStdioErr(err error, transport mcp.Transport) error { + if !errors.Is(err, io.EOF) { + return err + } + ct, ok := transport.(*mcp.CommandTransport) + if !ok { + return err + } + if err2 := stdioMCPCheck(ct.Command); err2 != nil { + err = errors.Join(err, err2) + } + return err +} + func maybeTimeoutErr(err error, timeout time.Duration) error { if errors.Is(err, context.Canceled) { return fmt.Errorf("timed out after %s", timeout) @@ -465,3 +487,15 @@ func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error func mcpTimeout(m config.MCPConfig) time.Duration { return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second } + +func stdioMCPCheck(old *exec.Cmd) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + cmd := exec.CommandContext(ctx, old.Path, old.Args...) + cmd.Env = old.Env + out, err := cmd.CombinedOutput() + if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("%w: %s", err, string(out)) +}