Commit 0feac505905b9435723c3a058453e61a59cd7dc7
1 parent
31a4e13027
Exists in
master
renamed ValidateCredentials to ValidateHash
Showing
1 changed file
with
2 additions
and
2 deletions
Show diff stats
auth.go
1 | package webutility | 1 | package webutility |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/rand" | 4 | "crypto/rand" |
5 | "crypto/sha256" | 5 | "crypto/sha256" |
6 | "encoding/hex" | 6 | "encoding/hex" |
7 | "errors" | 7 | "errors" |
8 | "net/http" | 8 | "net/http" |
9 | "strings" | 9 | "strings" |
10 | "time" | 10 | "time" |
11 | 11 | ||
12 | "github.com/dgrijalva/jwt-go" | 12 | "github.com/dgrijalva/jwt-go" |
13 | ) | 13 | ) |
14 | 14 | ||
15 | var appName = "webutility" | 15 | var appName = "webutility" |
16 | var secret = "webutility" | 16 | var secret = "webutility" |
17 | 17 | ||
18 | type Role struct { | 18 | type Role struct { |
19 | Name string `json:"name"` | 19 | Name string `json:"name"` |
20 | ID int64 `json:"id"` | 20 | ID int64 `json:"id"` |
21 | } | 21 | } |
22 | 22 | ||
23 | // TokenClaims are JWT token claims. | 23 | // TokenClaims are JWT token claims. |
24 | type TokenClaims struct { | 24 | type TokenClaims struct { |
25 | // extending a struct | 25 | // extending a struct |
26 | jwt.StandardClaims | 26 | jwt.StandardClaims |
27 | 27 | ||
28 | // custom claims | 28 | // custom claims |
29 | Token string `json:"access_token"` | 29 | Token string `json:"access_token"` |
30 | TokenType string `json:"token_type"` | 30 | TokenType string `json:"token_type"` |
31 | Username string `json:"username"` | 31 | Username string `json:"username"` |
32 | Role string `json:"role"` | 32 | Role string `json:"role"` |
33 | RoleID int64 `json:"role_id"` | 33 | RoleID int64 `json:"role_id"` |
34 | ExpiresIn int64 `json:"expires_in"` | 34 | ExpiresIn int64 `json:"expires_in"` |
35 | } | 35 | } |
36 | 36 | ||
37 | func InitJWT(appName, secret string) { | 37 | func InitJWT(appName, secret string) { |
38 | appName = appName | 38 | appName = appName |
39 | secret = secret | 39 | secret = secret |
40 | } | 40 | } |
41 | 41 | ||
42 | // ValidateCredentials hashes pass and salt and returns comparison result with resultHash | 42 | // ValidateHash hashes pass and salt and returns comparison result with resultHash |
43 | func ValidateCredentials(pass, salt, resultHash string) (bool, error) { | 43 | func ValidateHash(pass, salt, resultHash string) (bool, error) { |
44 | hash, _, err := CreateHash(pass, salt) | 44 | hash, _, err := CreateHash(pass, salt) |
45 | if err != nil { | 45 | if err != nil { |
46 | return false, err | 46 | return false, err |
47 | } | 47 | } |
48 | res := hash == resultHash | 48 | res := hash == resultHash |
49 | return res, nil | 49 | return res, nil |
50 | } | 50 | } |
51 | 51 | ||
52 | // CreateHash hashes str using SHA256. | 52 | // CreateHash hashes str using SHA256. |
53 | // If the presalt parameter is not provided CreateHash will generate new salt string. | 53 | // If the presalt parameter is not provided CreateHash will generate new salt string. |
54 | // Returns hash and salt strings or an error if it fails. | 54 | // Returns hash and salt strings or an error if it fails. |
55 | func CreateHash(str, presalt string) (hash, salt string, err error) { | 55 | func CreateHash(str, presalt string) (hash, salt string, err error) { |
56 | // chech if message is presalted | 56 | // chech if message is presalted |
57 | if presalt == "" { | 57 | if presalt == "" { |
58 | salt, err = randomSalt() | 58 | salt, err = randomSalt() |
59 | if err != nil { | 59 | if err != nil { |
60 | return "", "", err | 60 | return "", "", err |
61 | } | 61 | } |
62 | } else { | 62 | } else { |
63 | salt = presalt | 63 | salt = presalt |
64 | } | 64 | } |
65 | 65 | ||
66 | // convert strings to raw byte slices | 66 | // convert strings to raw byte slices |
67 | rawstr := []byte(str) | 67 | rawstr := []byte(str) |
68 | rawsalt, err := hex.DecodeString(salt) | 68 | rawsalt, err := hex.DecodeString(salt) |
69 | if err != nil { | 69 | if err != nil { |
70 | return "", "", err | 70 | return "", "", err |
71 | } | 71 | } |
72 | 72 | ||
73 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) | 73 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) |
74 | rawdata = append(rawdata, rawstr...) | 74 | rawdata = append(rawdata, rawstr...) |
75 | rawdata = append(rawdata, rawsalt...) | 75 | rawdata = append(rawdata, rawsalt...) |
76 | 76 | ||
77 | // hash message + salt | 77 | // hash message + salt |
78 | hasher := sha256.New() | 78 | hasher := sha256.New() |
79 | hasher.Write(rawdata) | 79 | hasher.Write(rawdata) |
80 | rawhash := hasher.Sum(nil) | 80 | rawhash := hasher.Sum(nil) |
81 | 81 | ||
82 | hash = hex.EncodeToString(rawhash) | 82 | hash = hex.EncodeToString(rawhash) |
83 | return hash, salt, nil | 83 | return hash, salt, nil |
84 | } | 84 | } |
85 | 85 | ||
86 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. | 86 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. |
87 | // It returns an error if it fails. | 87 | // It returns an error if it fails. |
88 | func CreateAuthToken(username string, role Role) (TokenClaims, error) { | 88 | func CreateAuthToken(username string, role Role) (TokenClaims, error) { |
89 | t0 := (time.Now()).Unix() | 89 | t0 := (time.Now()).Unix() |
90 | t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() | 90 | t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() |
91 | claims := TokenClaims{ | 91 | claims := TokenClaims{ |
92 | TokenType: "Bearer", | 92 | TokenType: "Bearer", |
93 | Username: username, | 93 | Username: username, |
94 | Role: role.Name, | 94 | Role: role.Name, |
95 | RoleID: role.ID, | 95 | RoleID: role.ID, |
96 | ExpiresIn: t1 - t0, | 96 | ExpiresIn: t1 - t0, |
97 | } | 97 | } |
98 | // initialize jwt.StandardClaims fields (anonymous struct) | 98 | // initialize jwt.StandardClaims fields (anonymous struct) |
99 | claims.IssuedAt = t0 | 99 | claims.IssuedAt = t0 |
100 | claims.ExpiresAt = t1 | 100 | claims.ExpiresAt = t1 |
101 | claims.Issuer = appName | 101 | claims.Issuer = appName |
102 | 102 | ||
103 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 103 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) |
104 | token, err := jwtToken.SignedString([]byte(secret)) | 104 | token, err := jwtToken.SignedString([]byte(secret)) |
105 | if err != nil { | 105 | if err != nil { |
106 | return TokenClaims{}, err | 106 | return TokenClaims{}, err |
107 | } | 107 | } |
108 | claims.Token = token | 108 | claims.Token = token |
109 | return claims, nil | 109 | return claims, nil |
110 | } | 110 | } |
111 | 111 | ||
112 | // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. | 112 | // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. |
113 | // It returns an error if it fails. | 113 | // It returns an error if it fails. |
114 | func RefreshAuthToken(tok string) (TokenClaims, error) { | 114 | func RefreshAuthToken(tok string) (TokenClaims, error) { |
115 | token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) | 115 | token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) |
116 | if err != nil { | 116 | if err != nil { |
117 | if validation, ok := err.(*jwt.ValidationError); ok { | 117 | if validation, ok := err.(*jwt.ValidationError); ok { |
118 | // don't return error if token is expired | 118 | // don't return error if token is expired |
119 | // just extend it | 119 | // just extend it |
120 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { | 120 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { |
121 | return TokenClaims{}, err | 121 | return TokenClaims{}, err |
122 | } | 122 | } |
123 | } else { | 123 | } else { |
124 | return TokenClaims{}, err | 124 | return TokenClaims{}, err |
125 | } | 125 | } |
126 | } | 126 | } |
127 | 127 | ||
128 | // type assertion | 128 | // type assertion |
129 | claims, ok := token.Claims.(*TokenClaims) | 129 | claims, ok := token.Claims.(*TokenClaims) |
130 | if !ok { | 130 | if !ok { |
131 | return TokenClaims{}, errors.New("token is not valid") | 131 | return TokenClaims{}, errors.New("token is not valid") |
132 | } | 132 | } |
133 | 133 | ||
134 | // extend token expiration date | 134 | // extend token expiration date |
135 | return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) | 135 | return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) |
136 | } | 136 | } |
137 | 137 | ||
138 | func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { | 138 | func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { |
139 | // validate token and check expiration date | 139 | // validate token and check expiration date |
140 | claims, err := GetTokenClaims(req) | 140 | claims, err := GetTokenClaims(req) |
141 | if err != nil { | 141 | if err != nil { |
142 | return claims, err | 142 | return claims, err |
143 | } | 143 | } |
144 | 144 | ||
145 | if roles == "" { | 145 | if roles == "" { |
146 | return claims, nil | 146 | return claims, nil |
147 | } | 147 | } |
148 | 148 | ||
149 | // check if token has expired | 149 | // check if token has expired |
150 | if claims.ExpiresAt < (time.Now()).Unix() { | 150 | if claims.ExpiresAt < (time.Now()).Unix() { |
151 | return claims, errors.New("token has expired") | 151 | return claims, errors.New("token has expired") |
152 | } | 152 | } |
153 | 153 | ||
154 | if roles == "*" { | 154 | if roles == "*" { |
155 | return claims, nil | 155 | return claims, nil |
156 | } | 156 | } |
157 | 157 | ||
158 | parts := strings.Split(roles, ",") | 158 | parts := strings.Split(roles, ",") |
159 | for i, _ := range parts { | 159 | for i, _ := range parts { |
160 | r := strings.Trim(parts[i], " ") | 160 | r := strings.Trim(parts[i], " ") |
161 | if claims.Role == r { | 161 | if claims.Role == r { |
162 | return claims, nil | 162 | return claims, nil |
163 | } | 163 | } |
164 | } | 164 | } |
165 | 165 | ||
166 | return claims, nil | 166 | return claims, nil |
167 | } | 167 | } |
168 | 168 | ||
169 | // GetTokenClaims extracts JWT claims from Authorization header of req. | 169 | // GetTokenClaims extracts JWT claims from Authorization header of req. |
170 | // Returns token claims or an error. | 170 | // Returns token claims or an error. |
171 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { | 171 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { |
172 | // check for and strip 'Bearer' prefix | 172 | // check for and strip 'Bearer' prefix |
173 | var tokstr string | 173 | var tokstr string |
174 | authHead := req.Header.Get("Authorization") | 174 | authHead := req.Header.Get("Authorization") |
175 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { | 175 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { |
176 | tokstr = strings.TrimPrefix(authHead, "Bearer ") | 176 | tokstr = strings.TrimPrefix(authHead, "Bearer ") |
177 | } else { | 177 | } else { |
178 | return &TokenClaims{}, errors.New("authorization header in incomplete") | 178 | return &TokenClaims{}, errors.New("authorization header in incomplete") |
179 | } | 179 | } |
180 | 180 | ||
181 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) | 181 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) |
182 | if err != nil { | 182 | if err != nil { |
183 | return &TokenClaims{}, err | 183 | return &TokenClaims{}, err |
184 | } | 184 | } |
185 | 185 | ||
186 | // type assertion | 186 | // type assertion |
187 | claims, ok := token.Claims.(*TokenClaims) | 187 | claims, ok := token.Claims.(*TokenClaims) |
188 | if !ok || !token.Valid { | 188 | if !ok || !token.Valid { |
189 | return &TokenClaims{}, errors.New("token is not valid") | 189 | return &TokenClaims{}, errors.New("token is not valid") |
190 | } | 190 | } |
191 | 191 | ||
192 | return claims, nil | 192 | return claims, nil |
193 | } | 193 | } |
194 | 194 | ||
195 | // randomSalt returns a string of 32 random characters. | 195 | // randomSalt returns a string of 32 random characters. |
196 | const saltSize = 32 | 196 | const saltSize = 32 |
197 | 197 | ||
198 | func randomSalt() (s string, err error) { | 198 | func randomSalt() (s string, err error) { |
199 | rawsalt := make([]byte, saltSize) | 199 | rawsalt := make([]byte, saltSize) |
200 | 200 | ||
201 | _, err = rand.Read(rawsalt) | 201 | _, err = rand.Read(rawsalt) |
202 | if err != nil { | 202 | if err != nil { |
203 | return "", err | 203 | return "", err |
204 | } | 204 | } |
205 | 205 | ||
206 | s = hex.EncodeToString(rawsalt) | 206 | s = hex.EncodeToString(rawsalt) |
207 | return s, nil | 207 | return s, nil |
208 | } | 208 | } |
209 | 209 | ||
210 | // secretFunc returns byte slice of API secret keyword. | 210 | // secretFunc returns byte slice of API secret keyword. |
211 | func secretFunc(token *jwt.Token) (interface{}, error) { | 211 | func secretFunc(token *jwt.Token) (interface{}, error) { |
212 | return []byte(secret), nil | 212 | return []byte(secret), nil |
213 | } | 213 | } |
214 | 214 |