1
2
3
4
5
6
7
8
9
10
11
12
13
14 package sns
15
16 import (
17 "context"
18 "fmt"
19 "net/http"
20 "strings"
21 "unicode/utf8"
22
23 "github.com/aws/aws-sdk-go/aws"
24 "github.com/aws/aws-sdk-go/aws/awserr"
25 "github.com/aws/aws-sdk-go/aws/credentials"
26 "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
27 "github.com/aws/aws-sdk-go/aws/session"
28 "github.com/aws/aws-sdk-go/service/sns"
29 "github.com/go-kit/log"
30 "github.com/go-kit/log/level"
31 commoncfg "github.com/prometheus/common/config"
32
33 "github.com/prometheus/alertmanager/config"
34 "github.com/prometheus/alertmanager/notify"
35 "github.com/prometheus/alertmanager/template"
36 "github.com/prometheus/alertmanager/types"
37 )
38
39
40 type Notifier struct {
41 conf *config.SNSConfig
42 tmpl *template.Template
43 logger log.Logger
44 client *http.Client
45 retrier *notify.Retrier
46 }
47
48
49 func New(c *config.SNSConfig, t *template.Template, l log.Logger, httpOpts ...commoncfg.HTTPClientOption) (*Notifier, error) {
50 client, err := commoncfg.NewClientFromConfig(*c.HTTPConfig, "sns", httpOpts...)
51 if err != nil {
52 return nil, err
53 }
54 return &Notifier{
55 conf: c,
56 tmpl: t,
57 logger: l,
58 client: client,
59 retrier: ¬ify.Retrier{},
60 }, nil
61 }
62
63 func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, error) {
64 var (
65 err error
66 data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
67 tmpl = notify.TmplText(n.tmpl, data, &err)
68 )
69
70 client, err := n.createSNSClient(tmpl)
71 if err != nil {
72 if e, ok := err.(awserr.RequestFailure); ok {
73 return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
74 }
75 return true, err
76 }
77
78 publishInput, err := n.createPublishInput(ctx, tmpl)
79 if err != nil {
80 return true, err
81 }
82
83 publishOutput, err := client.Publish(publishInput)
84 if err != nil {
85 if e, ok := err.(awserr.RequestFailure); ok {
86 return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
87 }
88 return true, err
89 }
90
91 level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)
92
93 return false, nil
94 }
95
96 func (n *Notifier) createSNSClient(tmpl func(string) string) (*sns.SNS, error) {
97 var creds *credentials.Credentials
98
99 if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" {
100 creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "")
101 }
102 sess, err := session.NewSessionWithOptions(session.Options{
103 Config: aws.Config{
104 Region: aws.String(n.conf.Sigv4.Region),
105 Endpoint: aws.String(tmpl(n.conf.APIUrl)),
106 },
107 Profile: n.conf.Sigv4.Profile,
108 })
109 if err != nil {
110 return nil, err
111 }
112
113 if n.conf.Sigv4.RoleARN != "" {
114 var stsSess *session.Session
115 if n.conf.APIUrl == "" {
116 stsSess = sess
117 } else {
118
119 stsSess, err = session.NewSessionWithOptions(session.Options{
120 Config: aws.Config{
121 Region: aws.String(n.conf.Sigv4.Region),
122 Credentials: creds,
123 },
124 Profile: n.conf.Sigv4.Profile,
125 })
126 if err != nil {
127 return nil, err
128 }
129 }
130 creds = stscreds.NewCredentials(stsSess, n.conf.Sigv4.RoleARN)
131 }
132
133 client := sns.New(sess, &aws.Config{Credentials: creds, HTTPClient: n.client})
134
135 if aws.StringValue(sess.Config.Region) == "" {
136 return nil, fmt.Errorf("region not configured in sns.sigv4.region or in default credentials chain")
137 }
138 return client, nil
139 }
140
141 func (n *Notifier) createPublishInput(ctx context.Context, tmpl func(string) string) (*sns.PublishInput, error) {
142 publishInput := &sns.PublishInput{}
143 messageAttributes := n.createMessageAttributes(tmpl)
144
145 messageSizeLimit := 256 * 1024
146 if n.conf.TopicARN != "" {
147 topicARN := tmpl(n.conf.TopicARN)
148 publishInput.SetTopicArn(topicARN)
149
150 if strings.HasSuffix(topicARN, ".fifo") {
151
152 key, err := notify.ExtractGroupKey(ctx)
153 if err != nil {
154 return nil, err
155 }
156 publishInput.SetMessageDeduplicationId(key.Hash())
157 publishInput.SetMessageGroupId(key.Hash())
158 }
159 }
160 if n.conf.PhoneNumber != "" {
161 publishInput.SetPhoneNumber(tmpl(n.conf.PhoneNumber))
162
163 messageSizeLimit = 1600
164 }
165 if n.conf.TargetARN != "" {
166 publishInput.SetTargetArn(tmpl(n.conf.TargetARN))
167 }
168
169 messageToSend, isTrunc, err := validateAndTruncateMessage(tmpl(n.conf.Message), messageSizeLimit)
170 if err != nil {
171 return nil, err
172 }
173 if isTrunc {
174
175 messageAttributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
176 }
177
178 publishInput.SetMessage(messageToSend)
179 publishInput.SetMessageAttributes(messageAttributes)
180
181 if n.conf.Subject != "" {
182 publishInput.SetSubject(tmpl(n.conf.Subject))
183 }
184
185 return publishInput, nil
186 }
187
// validateAndTruncateMessage checks that message is valid UTF-8 and cuts it
// down to at most maxMessageSizeInBytes bytes.
//
// Truncation backs off to the nearest preceding rune boundary, so the result
// is always valid UTF-8 and may be slightly shorter than the limit. A negative
// limit is treated as 0.
//
// It returns the possibly truncated message and whether truncation happened.
// An error is returned only when the input is not valid UTF-8.
func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (string, bool, error) {
	if !utf8.ValidString(message) {
		return "", false, fmt.Errorf("non utf8 encoded message string")
	}
	if maxMessageSizeInBytes < 0 {
		maxMessageSizeInBytes = 0
	}
	if len(message) <= maxMessageSizeInBytes {
		return message, false, nil
	}

	// len(message) > maxMessageSizeInBytes, so message[cut] is in range.
	// Step back over continuation bytes so a multi-byte rune is never split.
	cut := maxMessageSizeInBytes
	for cut > 0 && !utf8.RuneStart(message[cut]) {
		cut--
	}
	return message[:cut], true, nil
}
200
201 func (n *Notifier) createMessageAttributes(tmpl func(string) string) map[string]*sns.MessageAttributeValue {
202
203 attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
204 for k, v := range n.conf.Attributes {
205 attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
206 }
207 return attributes
208 }
209
View as plain text