1package anthropic
2
3import (
4 "bytes"
5 "fmt"
6 "io"
7 "net/http"
8 "net/url"
9
10 "github.com/tidwall/gjson"
11 "github.com/tidwall/sjson"
12
13 "github.com/anthropics/anthropic-sdk-go/bedrock"
14 "github.com/anthropics/anthropic-sdk-go/option"
15)
16
17func bedrockMiddleware(bearerToken string) option.Middleware {
18 return func(r *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
19 var body []byte
20 if r.Body != nil {
21 body, err = io.ReadAll(r.Body)
22 if err != nil {
23 return nil, err
24 }
25 _ = r.Body.Close()
26
27 if !gjson.GetBytes(body, "anthropic_version").Exists() {
28 body, _ = sjson.SetBytes(body, "anthropic_version", bedrock.DefaultVersion)
29 }
30
31 if r.Method == http.MethodPost && bedrock.DefaultEndpoints[r.URL.Path] {
32 model := gjson.GetBytes(body, "model").String()
33 stream := gjson.GetBytes(body, "stream").Bool()
34
35 body, _ = sjson.DeleteBytes(body, "model")
36 body, _ = sjson.DeleteBytes(body, "stream")
37
38 var method string
39 if stream {
40 method = "invoke-with-response-stream"
41 } else {
42 method = "invoke"
43 }
44
45 r.URL.Path = fmt.Sprintf("/model/%s/%s", model, method)
46 r.URL.RawPath = fmt.Sprintf("/model/%s/%s", url.QueryEscape(model), method)
47 }
48
49 reader := bytes.NewReader(body)
50 r.Body = io.NopCloser(reader)
51 r.GetBody = func() (io.ReadCloser, error) {
52 _, err := reader.Seek(0, 0)
53 return io.NopCloser(reader), err
54 }
55 r.ContentLength = int64(len(body))
56 }
57
58 r.Header.Set("Authorization", "Bearer "+bearerToken)
59
60 return next(r)
61 }
62}