bedrock.go

 1package anthropic
 2
 3import (
 4	"bytes"
 5	"fmt"
 6	"io"
 7	"net/http"
 8	"net/url"
 9
10	"github.com/tidwall/gjson"
11	"github.com/tidwall/sjson"
12
13	"github.com/anthropics/anthropic-sdk-go/bedrock"
14	"github.com/anthropics/anthropic-sdk-go/option"
15)
16
17func bedrockMiddleware(bearerToken string) option.Middleware {
18	return func(r *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
19		var body []byte
20		if r.Body != nil {
21			body, err = io.ReadAll(r.Body)
22			if err != nil {
23				return nil, err
24			}
25			_ = r.Body.Close()
26
27			if !gjson.GetBytes(body, "anthropic_version").Exists() {
28				body, _ = sjson.SetBytes(body, "anthropic_version", bedrock.DefaultVersion)
29			}
30
31			if r.Method == http.MethodPost && bedrock.DefaultEndpoints[r.URL.Path] {
32				model := gjson.GetBytes(body, "model").String()
33				stream := gjson.GetBytes(body, "stream").Bool()
34
35				body, _ = sjson.DeleteBytes(body, "model")
36				body, _ = sjson.DeleteBytes(body, "stream")
37
38				var method string
39				if stream {
40					method = "invoke-with-response-stream"
41				} else {
42					method = "invoke"
43				}
44
45				r.URL.Path = fmt.Sprintf("/model/%s/%s", model, method)
46				r.URL.RawPath = fmt.Sprintf("/model/%s/%s", url.QueryEscape(model), method)
47			}
48
49			reader := bytes.NewReader(body)
50			r.Body = io.NopCloser(reader)
51			r.GetBody = func() (io.ReadCloser, error) {
52				_, err := reader.Seek(0, 0)
53				return io.NopCloser(reader), err
54			}
55			r.ContentLength = int64(len(body))
56		}
57
58		r.Header.Set("Authorization", "Bearer "+bearerToken)
59
60		return next(r)
61	}
62}