Commit 052f8a3a63bf8c52b72b5e20ffd0a5d22372e5ab
1 parent
61efd58cdb
Exists in
master
and in
1 other branch
token validation extended
Showing
1 changed file
with
10 additions
and
29 deletions
Show diff stats
auth_utility.go
1 | package webutility | 1 | package webutility |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/rand" | 4 | "crypto/rand" |
5 | "crypto/sha256" | 5 | "crypto/sha256" |
6 | "encoding/hex" | 6 | "encoding/hex" |
7 | "errors" | 7 | "errors" |
8 | "net/http" | 8 | "net/http" |
9 | "strings" | 9 | "strings" |
10 | "time" | 10 | "time" |
11 | 11 | ||
12 | "github.com/dgrijalva/jwt-go" | 12 | "github.com/dgrijalva/jwt-go" |
13 | ) | 13 | ) |
14 | 14 | ||
15 | const OneDay = time.Hour * 24 | 15 | const OneDay = time.Hour * 24 |
16 | const OneWeek = OneDay * 7 | 16 | const OneWeek = OneDay * 7 |
17 | const saltSize = 32 | 17 | const saltSize = 32 |
18 | const appName = "korisnicki-centar" | 18 | const appName = "korisnicki-centar" |
19 | const secret = "korisnicki-centar-api" | 19 | const secret = "korisnicki-centar-api" |
20 | 20 | ||
21 | type Role struct { | 21 | type Role struct { |
22 | Name string `json:"name"` | 22 | Name string `json:"name"` |
23 | ID uint32 `json:"id"` | 23 | ID uint32 `json:"id"` |
24 | } | 24 | } |
25 | 25 | ||
26 | // TokenClaims are JWT token claims. | 26 | // TokenClaims are JWT token claims. |
27 | type TokenClaims struct { | 27 | type TokenClaims struct { |
28 | Token string `json:"access_token"` | 28 | Token string `json:"access_token"` |
29 | TokenType string `json:"token_type"` | 29 | TokenType string `json:"token_type"` |
30 | Username string `json:"username"` | 30 | Username string `json:"username"` |
31 | Role string `json:"role"` | 31 | Role string `json:"role"` |
32 | RoleID uint32 `json:"role_id"` | 32 | RoleID uint32 `json:"role_id"` |
33 | ExpiresIn int64 `json:"expires_in"` | 33 | ExpiresIn int64 `json:"expires_in"` |
34 | 34 | ||
35 | // extending a struct | 35 | // extending a struct |
36 | jwt.StandardClaims | 36 | jwt.StandardClaims |
37 | } | 37 | } |
38 | 38 | ||
39 | // CredentialsStruct is an instace of username/password values. | 39 | // CredentialsStruct is an instace of username/password values. |
40 | type CredentialsStruct struct { | 40 | type CredentialsStruct struct { |
41 | Username string `json:"username"` | 41 | Username string `json:"username"` |
42 | Password string `json:"password"` | 42 | Password string `json:"password"` |
43 | RoleID uint32 `json:"roleID"` | 43 | RoleID uint32 `json:"roleID"` |
44 | } | 44 | } |
45 | 45 | ||
46 | // ValidateCredentials hashes pass and salt and returns comparison result with resultHash | 46 | // ValidateCredentials hashes pass and salt and returns comparison result with resultHash |
47 | func ValidateCredentials(pass, salt, resultHash string) bool { | 47 | func ValidateCredentials(pass, salt, resultHash string) bool { |
48 | hash, _, err := CreateHash(pass, salt) | 48 | hash, _, err := CreateHash(pass, salt) |
49 | if err != nil { | 49 | if err != nil { |
50 | return false | 50 | return false |
51 | } | 51 | } |
52 | return hash == resultHash | 52 | return hash == resultHash |
53 | } | 53 | } |
54 | 54 | ||
55 | // CreateHash hashes str using SHA256. | 55 | // CreateHash hashes str using SHA256. |
56 | // If the presalt parameter is not provided CreateHash will generate new salt string. | 56 | // If the presalt parameter is not provided CreateHash will generate new salt string. |
57 | // Returns hash and salt strings or an error if it fails. | 57 | // Returns hash and salt strings or an error if it fails. |
58 | func CreateHash(str, presalt string) (hash, salt string, err error) { | 58 | func CreateHash(str, presalt string) (hash, salt string, err error) { |
59 | // chech if message is presalted | 59 | // chech if message is presalted |
60 | if presalt == "" { | 60 | if presalt == "" { |
61 | salt, err = randomSalt() | 61 | salt, err = randomSalt() |
62 | if err != nil { | 62 | if err != nil { |
63 | return "", "", err | 63 | return "", "", err |
64 | } | 64 | } |
65 | } else { | 65 | } else { |
66 | salt = presalt | 66 | salt = presalt |
67 | } | 67 | } |
68 | 68 | ||
69 | // convert strings to raw byte slices | 69 | // convert strings to raw byte slices |
70 | rawstr := []byte(str) | 70 | rawstr := []byte(str) |
71 | rawsalt, err := hex.DecodeString(salt) | 71 | rawsalt, err := hex.DecodeString(salt) |
72 | if err != nil { | 72 | if err != nil { |
73 | return "", "", err | 73 | return "", "", err |
74 | } | 74 | } |
75 | 75 | ||
76 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) | 76 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) |
77 | rawdata = append(rawdata, rawstr...) | 77 | rawdata = append(rawdata, rawstr...) |
78 | rawdata = append(rawdata, rawsalt...) | 78 | rawdata = append(rawdata, rawsalt...) |
79 | 79 | ||
80 | // hash message + salt | 80 | // hash message + salt |
81 | hasher := sha256.New() | 81 | hasher := sha256.New() |
82 | hasher.Write(rawdata) | 82 | hasher.Write(rawdata) |
83 | rawhash := hasher.Sum(nil) | 83 | rawhash := hasher.Sum(nil) |
84 | 84 | ||
85 | hash = hex.EncodeToString(rawhash) | 85 | hash = hex.EncodeToString(rawhash) |
86 | return hash, salt, nil | 86 | return hash, salt, nil |
87 | } | 87 | } |
88 | 88 | ||
89 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. | 89 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. |
90 | // It returns an error if it fails. | 90 | // It returns an error if it fails. |
91 | func CreateAuthToken(username string, role Role) (TokenClaims, error) { | 91 | func CreateAuthToken(username string, role Role) (TokenClaims, error) { |
92 | t0 := (time.Now()).Unix() | 92 | t0 := (time.Now()).Unix() |
93 | t1 := (time.Now().Add(OneWeek)).Unix() | 93 | t1 := (time.Now().Add(OneWeek)).Unix() |
94 | claims := TokenClaims{ | 94 | claims := TokenClaims{ |
95 | TokenType: "Bearer", | 95 | TokenType: "Bearer", |
96 | Username: username, | 96 | Username: username, |
97 | Role: role.Name, | 97 | Role: role.Name, |
98 | RoleID: role.ID, | 98 | RoleID: role.ID, |
99 | ExpiresIn: t1 - t0, | 99 | ExpiresIn: t1 - t0, |
100 | } | 100 | } |
101 | // initialize jwt.StandardClaims fields (anonymous struct) | 101 | // initialize jwt.StandardClaims fields (anonymous struct) |
102 | claims.IssuedAt = t0 | 102 | claims.IssuedAt = t0 |
103 | claims.ExpiresAt = t1 | 103 | claims.ExpiresAt = t1 |
104 | claims.Issuer = appName | 104 | claims.Issuer = appName |
105 | 105 | ||
106 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 106 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) |
107 | token, err := jwtToken.SignedString([]byte(secret)) | 107 | token, err := jwtToken.SignedString([]byte(secret)) |
108 | if err != nil { | 108 | if err != nil { |
109 | return TokenClaims{}, err | 109 | return TokenClaims{}, err |
110 | } | 110 | } |
111 | claims.Token = token | 111 | claims.Token = token |
112 | return claims, nil | 112 | return claims, nil |
113 | } | 113 | } |
114 | 114 | ||
115 | // RefreshAuthToken prolongs JWT token's expiration date for one week. | 115 | // RefreshAuthToken prolongs JWT token's expiration date for one week. |
116 | // It returns new JWT token or an error if it fails. | 116 | // It returns new JWT token or an error if it fails. |
117 | func RefreshAuthToken(req *http.Request) (TokenClaims, error) { | 117 | func RefreshAuthToken(req *http.Request) (TokenClaims, error) { |
118 | authHead := req.Header.Get("Authorization") | 118 | authHead := req.Header.Get("Authorization") |
119 | tokenstr := strings.TrimPrefix(authHead, "Bearer ") | 119 | tokenstr := strings.TrimPrefix(authHead, "Bearer ") |
120 | token, err := jwt.ParseWithClaims(tokenstr, &TokenClaims{}, secretFunc) | 120 | token, err := jwt.ParseWithClaims(tokenstr, &TokenClaims{}, secretFunc) |
121 | if err != nil { | 121 | if err != nil { |
122 | return TokenClaims{}, err | 122 | if validation, ok := err.(*jwt.ValidationError); ok { |
123 | // don't return error if token is expired | ||
124 | // just extend it | ||
125 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { | ||
126 | return TokenClaims{}, err | ||
127 | } | ||
128 | } else { | ||
129 | return TokenClaims{}, err | ||
130 | } | ||
123 | } | 131 | } |
124 | 132 | ||
125 | // type assertion | 133 | // type assertion |
126 | claims, ok := token.Claims.(*TokenClaims) | 134 | claims, ok := token.Claims.(*TokenClaims) |
127 | if !ok || !token.Valid { | 135 | if !ok { |
128 | return TokenClaims{}, errors.New("token is not valid") | 136 | return TokenClaims{}, errors.New("token is not valid") |
129 | } | 137 | } |
130 | 138 | ||
131 | // extend token expiration date | 139 | // extend token expiration date |
132 | return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) | 140 | return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) |
133 | } | 141 | } |
134 | 142 | ||
135 | // RbacCheck returns true if role that made HTTP request is authorized to | 143 | // RbacCheck returns true if role that made HTTP request is authorized to |
136 | // access the resource it is targeting. | 144 | // access the resource it is targeting. |
137 | // It exctracts user's role from the JWT token located in Authorization header of | 145 | // It exctracts user's role from the JWT token located in Authorization header of |
138 | // http.Request and then compares it with the list of supplied roles and returns | 146 | // http.Request and then compares it with the list of supplied roles and returns |
139 | // true if there's a match, if "*" is provided or if the authRoles is nil. | 147 | // true if there's a match, if "*" is provided or if the authRoles is nil. |
140 | // Otherwise it returns false. | 148 | // Otherwise it returns false. |
141 | func RbacCheck(req *http.Request, authRoles []string) bool { | 149 | func RbacCheck(req *http.Request, authRoles []string) bool { |
142 | if authRoles == nil { | 150 | if authRoles == nil { |
143 | return true | 151 | return true |
144 | } | 152 | } |
145 | 153 | ||
146 | // validate token and check expiration date | 154 | // validate token and check expiration date |
147 | claims, err := GetTokenClaims(req) | 155 | claims, err := GetTokenClaims(req) |
148 | if err != nil { | 156 | if err != nil { |
149 | return false | 157 | return false |
150 | } | 158 | } |
151 | // check if token has expired | 159 | // check if token has expired |
152 | if claims.ExpiresAt < (time.Now()).Unix() { | 160 | if claims.ExpiresAt < (time.Now()).Unix() { |
153 | return false | 161 | return false |
154 | } | 162 | } |
155 | 163 | ||
156 | // check if role extracted from token matches | 164 | // check if role extracted from token matches |
157 | // any of the provided (allowed) ones | 165 | // any of the provided (allowed) ones |
158 | for _, r := range authRoles { | 166 | for _, r := range authRoles { |
159 | if claims.Role == r || r == "*" { | 167 | if claims.Role == r || r == "*" { |
160 | return true | 168 | return true |
161 | } | 169 | } |
162 | } | 170 | } |
163 | 171 | ||
164 | return false | 172 | return false |
165 | } | 173 | } |
166 | 174 | ||
167 | // TODO | ||
168 | func AuthCheck(req *http.Request, authRoles []string) (*TokenClaims, error) { | ||
169 | if authRoles == nil { | ||
170 | return &TokenClaims{}, nil | ||
171 | } | ||
172 | |||
173 | // validate token and check expiration date | ||
174 | claims, err := GetTokenClaims(req) | ||
175 | if err != nil { | ||
176 | return &TokenClaims{}, err | ||
177 | } | ||
178 | // check if token has expired | ||
179 | if claims.ExpiresAt < (time.Now()).Unix() { | ||
180 | return &TokenClaims{}, errors.New("token has expired") | ||
181 | } | ||
182 | |||
183 | // check if role extracted from token matches | ||
184 | // any of the provided (allowed) ones | ||
185 | for _, r := range authRoles { | ||
186 | if claims.Role == r || r == "*" { | ||
187 | return claims, nil | ||
188 | } | ||
189 | } | ||
190 | |||
191 | return claims, errors.New("role is not authorized") | ||
192 | } | ||
193 | |||
194 | // GetTokenClaims extracts JWT claims from Authorization header of the request. | 175 | // GetTokenClaims extracts JWT claims from Authorization header of the request. |
195 | // Returns token claims or an error. | 176 | // Returns token claims or an error. |
196 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { | 177 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { |
197 | // check for and strip 'Bearer' prefix | 178 | // check for and strip 'Bearer' prefix |
198 | var tokstr string | 179 | var tokstr string |
199 | authHead := req.Header.Get("Authorization") | 180 | authHead := req.Header.Get("Authorization") |
200 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { | 181 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { |
201 | tokstr = strings.TrimPrefix(authHead, "Bearer ") | 182 | tokstr = strings.TrimPrefix(authHead, "Bearer ") |
202 | } else { | 183 | } else { |
203 | return &TokenClaims{}, errors.New("authorization header in incomplete") | 184 | return &TokenClaims{}, errors.New("authorization header in incomplete") |
204 | } | 185 | } |
205 | 186 | ||
206 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) | 187 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) |
207 | if err != nil { | 188 | if err != nil { |
208 | return &TokenClaims{}, err | 189 | return &TokenClaims{}, err |
209 | } | 190 | } |
210 | 191 | ||
211 | // type assertion | 192 | // type assertion |
212 | claims, ok := token.Claims.(*TokenClaims) | 193 | claims, ok := token.Claims.(*TokenClaims) |
213 | if !ok || !token.Valid { | 194 | if !ok || !token.Valid { |
214 | return &TokenClaims{}, errors.New("token is not valid") | 195 | return &TokenClaims{}, errors.New("token is not valid") |
215 | } | 196 | } |
216 | 197 | ||
217 | return claims, nil | 198 | return claims, nil |
218 | } | 199 | } |
219 | 200 | ||
220 | // randomSalt returns a string of random characters of 'saltSize' length. | 201 | // randomSalt returns a string of random characters of 'saltSize' length. |
221 | func randomSalt() (s string, err error) { | 202 | func randomSalt() (s string, err error) { |
222 | rawsalt := make([]byte, saltSize) | 203 | rawsalt := make([]byte, saltSize) |
223 | 204 | ||
224 | _, err = rand.Read(rawsalt) | 205 | _, err = rand.Read(rawsalt) |
225 | if err != nil { | 206 | if err != nil { |
226 | return "", err | 207 | return "", err |
227 | } | 208 | } |
228 | 209 | ||
229 | s = hex.EncodeToString(rawsalt) | 210 | s = hex.EncodeToString(rawsalt) |