1
21
22 package internal
23
import (
	"errors"
	"io"
	"net/url"
	"strconv"
	"strings"
	"time"

	"golang.org/x/net/html"
	goauth "golang.org/x/oauth2"
)
35
36 func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, customParameters url.Values, rFC6749Error map[string]string, err error) {
37 token = goauth.Token{}
38 rFC6749Error = map[string]string{}
39 customParameters = url.Values{}
40
41 doc, err := html.Parse(resp)
42 if err != nil {
43 return "", "", "", token, customParameters, rFC6749Error, err
44 }
45
46
47 body := findBody(doc.FirstChild.FirstChild)
48 if body.Data != "body" {
49 return "", "", "", token, customParameters, rFC6749Error, errors.New("Malformed html")
50 }
51
52 htmlEvent := body.Attr[0].Key
53 if htmlEvent != "onload" {
54 return "", "", "", token, customParameters, rFC6749Error, errors.New("onload event is missing")
55 }
56
57 onLoadFunc := body.Attr[0].Val
58 if onLoadFunc != "javascript:document.forms[0].submit()" {
59 return "", "", "", token, customParameters, rFC6749Error, errors.New("onload function is missing")
60 }
61
62 form := getNextNoneTextNode(body.FirstChild)
63 if form.Data != "form" {
64 return "", "", "", token, customParameters, rFC6749Error, errors.New("html form is missing")
65 }
66
67 for _, attr := range form.Attr {
68 if attr.Key == "method" {
69 if attr.Val != "post" {
70 return "", "", "", token, customParameters, rFC6749Error, errors.New("html form post method is missing")
71 }
72 } else {
73 if attr.Val != redirectURL {
74 return "", "", "", token, customParameters, rFC6749Error, errors.New("html form post url is wrong")
75 }
76 }
77 }
78
79 for node := getNextNoneTextNode(form.FirstChild); node != nil; node = getNextNoneTextNode(node.NextSibling) {
80 var k, v string
81 for _, attr := range node.Attr {
82 if attr.Key == "name" {
83 k = attr.Val
84 } else if attr.Key == "value" {
85 v = attr.Val
86 }
87
88 }
89
90 switch k {
91 case "state":
92 stateFromServer = v
93 case "code":
94 authorizationCode = v
95 case "expires_in":
96 expires, err := strconv.Atoi(v)
97 if err != nil {
98 return "", "", "", token, customParameters, rFC6749Error, err
99 }
100 token.Expiry = time.Now().UTC().Add(time.Duration(expires) * time.Second)
101 case "access_token":
102 token.AccessToken = v
103 case "token_type":
104 token.TokenType = v
105 case "refresh_token":
106 token.RefreshToken = v
107 case "error":
108 rFC6749Error["ErrorField"] = v
109 case "error_hint":
110 rFC6749Error["HintField"] = v
111 case "error_description":
112 rFC6749Error["DescriptionField"] = v
113 case "id_token":
114 iDToken = v
115 default:
116 customParameters.Add(k, v)
117 }
118 }
119
120 return
121 }
122
123 func getNextNoneTextNode(node *html.Node) *html.Node {
124 nextNode := node.NextSibling
125 if nextNode != nil && nextNode.Type == html.TextNode {
126 nextNode = getNextNoneTextNode(node.NextSibling)
127 }
128
129 return nextNode
130 }
131
132 func findBody(node *html.Node) *html.Node {
133 if node != nil {
134 if node.Data == "body" {
135 return node
136 }
137 return findBody(node.NextSibling)
138 }
139
140 return nil
141 }
142
View as plain text