fix(mcp): improve STDIO error handling (#1244)

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/llm/agent/mcp-tools.go | 34 ++++++++++++++++++++++++++++++++++
1 file changed, 34 insertions(+)

Detailed changes

internal/llm/agent/mcp-tools.go 🔗

@@ -387,6 +387,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso
 
 	session, err := client.Connect(mcpCtx, transport, nil)
 	if err != nil {
+		err = maybeStdioErr(err, transport)
 		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
 		slog.Error("error starting mcp client", "error", err, "name", name)
 		cancel()
@@ -398,6 +399,27 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso
 	return session, nil
 }
 
+// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
+// to parse, and the cli will then close it, causing the EOF error.
+// so, if we got an EOF err, and the transport is STDIO, we try to exec it
+// again with a timeout and collect the output so we can add details to the
+// error.
+// this happens particularly when starting things with npx, e.g. if node can't
+// be found or some other error like that.
+func maybeStdioErr(err error, transport mcp.Transport) error {
+	if !errors.Is(err, io.EOF) {
+		return err
+	}
+	ct, ok := transport.(*mcp.CommandTransport)
+	if !ok {
+		return err
+	}
+	if err2 := stdioMCPCheck(ct.Command); err2 != nil {
+		err = errors.Join(err, err2)
+	}
+	return err
+}
+
 func maybeTimeoutErr(err error, timeout time.Duration) error {
 	if errors.Is(err, context.Canceled) {
 		return fmt.Errorf("timed out after %s", timeout)
@@ -465,3 +487,15 @@ func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
 func mcpTimeout(m config.MCPConfig) time.Duration {
 	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
 }
+
+func stdioMCPCheck(old *exec.Cmd) error {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+	cmd := exec.CommandContext(ctx, old.Path, old.Args...)
+	cmd.Env = old.Env
+	out, err := cmd.CombinedOutput()
+	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
+		return nil
+	}
+	return fmt.Errorf("%w: %s", err, string(out))
+}