Commit fbf92700f4c7403d1086f3e97373a686081cfe34
1 parent
2b62f61cc8
Exists in
master
renamed package variables
Showing
1 changed file
with
8 additions
and
8 deletions
Show diff stats
auth.go
1 | package webutility | 1 | package webutility |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/rand" | 4 | "crypto/rand" |
5 | "crypto/sha256" | 5 | "crypto/sha256" |
6 | "encoding/hex" | 6 | "encoding/hex" |
7 | "errors" | 7 | "errors" |
8 | "net/http" | 8 | "net/http" |
9 | "strings" | 9 | "strings" |
10 | "time" | 10 | "time" |
11 | 11 | ||
12 | "github.com/dgrijalva/jwt-go" | 12 | "github.com/dgrijalva/jwt-go" |
13 | ) | 13 | ) |
14 | 14 | ||
15 | var appName = "webutility" | 15 | var _issuer = "webutility" |
16 | var secret = "webutility" | 16 | var _secret = "webutility" |
17 | 17 | ||
18 | // TokenClaims are JWT token claims. | 18 | // TokenClaims are JWT token claims. |
19 | type TokenClaims struct { | 19 | type TokenClaims struct { |
20 | // extending a struct | 20 | // extending a struct |
21 | jwt.StandardClaims | 21 | jwt.StandardClaims |
22 | 22 | ||
23 | // custom claims | 23 | // custom claims |
24 | Token string `json:"access_token"` | 24 | Token string `json:"access_token"` |
25 | TokenType string `json:"token_type"` | 25 | TokenType string `json:"token_type"` |
26 | Username string `json:"username"` | 26 | Username string `json:"username"` |
27 | RoleName string `json:"role"` | 27 | RoleName string `json:"role"` |
28 | RoleID int64 `json:"role_id"` | 28 | RoleID int64 `json:"role_id"` |
29 | ExpiresIn int64 `json:"expires_in"` | 29 | ExpiresIn int64 `json:"expires_in"` |
30 | } | 30 | } |
31 | 31 | ||
32 | func InitJWT(appName, secret string) { | 32 | func InitJWT(issuer, secret string) { |
33 | appName = appName | 33 | _issuer = issuer |
34 | secret = secret | 34 | _secret = secret |
35 | } | 35 | } |
36 | 36 | ||
37 | // ValidateHash hashes pass and salt and returns comparison result with resultHash | 37 | // ValidateHash hashes pass and salt and returns comparison result with resultHash |
38 | func ValidateHash(pass, salt, resultHash string) (bool, error) { | 38 | func ValidateHash(pass, salt, resultHash string) (bool, error) { |
39 | hash, _, err := CreateHash(pass, salt) | 39 | hash, _, err := CreateHash(pass, salt) |
40 | if err != nil { | 40 | if err != nil { |
41 | return false, err | 41 | return false, err |
42 | } | 42 | } |
43 | res := hash == resultHash | 43 | res := hash == resultHash |
44 | return res, nil | 44 | return res, nil |
45 | } | 45 | } |
46 | 46 | ||
47 | // CreateHash hashes str using SHA256. | 47 | // CreateHash hashes str using SHA256. |
48 | // If the presalt parameter is not provided CreateHash will generate new salt string. | 48 | // If the presalt parameter is not provided CreateHash will generate new salt string. |
49 | // Returns hash and salt strings or an error if it fails. | 49 | // Returns hash and salt strings or an error if it fails. |
50 | func CreateHash(str, presalt string) (hash, salt string, err error) { | 50 | func CreateHash(str, presalt string) (hash, salt string, err error) { |
51 | // chech if message is presalted | 51 | // chech if message is presalted |
52 | if presalt == "" { | 52 | if presalt == "" { |
53 | salt, err = randomSalt() | 53 | salt, err = randomSalt() |
54 | if err != nil { | 54 | if err != nil { |
55 | return "", "", err | 55 | return "", "", err |
56 | } | 56 | } |
57 | } else { | 57 | } else { |
58 | salt = presalt | 58 | salt = presalt |
59 | } | 59 | } |
60 | 60 | ||
61 | // convert strings to raw byte slices | 61 | // convert strings to raw byte slices |
62 | rawstr := []byte(str) | 62 | rawstr := []byte(str) |
63 | rawsalt, err := hex.DecodeString(salt) | 63 | rawsalt, err := hex.DecodeString(salt) |
64 | if err != nil { | 64 | if err != nil { |
65 | return "", "", err | 65 | return "", "", err |
66 | } | 66 | } |
67 | 67 | ||
68 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) | 68 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) |
69 | rawdata = append(rawdata, rawstr...) | 69 | rawdata = append(rawdata, rawstr...) |
70 | rawdata = append(rawdata, rawsalt...) | 70 | rawdata = append(rawdata, rawsalt...) |
71 | 71 | ||
72 | // hash message + salt | 72 | // hash message + salt |
73 | hasher := sha256.New() | 73 | hasher := sha256.New() |
74 | hasher.Write(rawdata) | 74 | hasher.Write(rawdata) |
75 | rawhash := hasher.Sum(nil) | 75 | rawhash := hasher.Sum(nil) |
76 | 76 | ||
77 | hash = hex.EncodeToString(rawhash) | 77 | hash = hex.EncodeToString(rawhash) |
78 | return hash, salt, nil | 78 | return hash, salt, nil |
79 | } | 79 | } |
80 | 80 | ||
81 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. | 81 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. |
82 | // It returns an error if it fails. | 82 | // It returns an error if it fails. |
83 | func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { | 83 | func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { |
84 | t0 := (time.Now()).Unix() | 84 | t0 := (time.Now()).Unix() |
85 | t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() | 85 | t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() |
86 | claims := TokenClaims{ | 86 | claims := TokenClaims{ |
87 | TokenType: "Bearer", | 87 | TokenType: "Bearer", |
88 | Username: username, | 88 | Username: username, |
89 | RoleName: roleName, | 89 | RoleName: roleName, |
90 | RoleID: roleID, | 90 | RoleID: roleID, |
91 | ExpiresIn: t1 - t0, | 91 | ExpiresIn: t1 - t0, |
92 | } | 92 | } |
93 | // initialize jwt.StandardClaims fields (anonymous struct) | 93 | // initialize jwt.StandardClaims fields (anonymous struct) |
94 | claims.IssuedAt = t0 | 94 | claims.IssuedAt = t0 |
95 | claims.ExpiresAt = t1 | 95 | claims.ExpiresAt = t1 |
96 | claims.Issuer = appName | 96 | claims.Issuer = _issuer |
97 | 97 | ||
98 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 98 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) |
99 | token, err := jwtToken.SignedString([]byte(secret)) | 99 | token, err := jwtToken.SignedString([]byte(_secret)) |
100 | if err != nil { | 100 | if err != nil { |
101 | return TokenClaims{}, err | 101 | return TokenClaims{}, err |
102 | } | 102 | } |
103 | claims.Token = token | 103 | claims.Token = token |
104 | return claims, nil | 104 | return claims, nil |
105 | } | 105 | } |
106 | 106 | ||
107 | // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. | 107 | // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. |
108 | // It returns an error if it fails. | 108 | // It returns an error if it fails. |
109 | func RefreshAuthToken(tok string) (TokenClaims, error) { | 109 | func RefreshAuthToken(tok string) (TokenClaims, error) { |
110 | token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) | 110 | token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) |
111 | if err != nil { | 111 | if err != nil { |
112 | if validation, ok := err.(*jwt.ValidationError); ok { | 112 | if validation, ok := err.(*jwt.ValidationError); ok { |
113 | // don't return error if token is expired | 113 | // don't return error if token is expired |
114 | // just extend it | 114 | // just extend it |
115 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { | 115 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { |
116 | return TokenClaims{}, err | 116 | return TokenClaims{}, err |
117 | } | 117 | } |
118 | } else { | 118 | } else { |
119 | return TokenClaims{}, err | 119 | return TokenClaims{}, err |
120 | } | 120 | } |
121 | } | 121 | } |
122 | 122 | ||
123 | // type assertion | 123 | // type assertion |
124 | claims, ok := token.Claims.(*TokenClaims) | 124 | claims, ok := token.Claims.(*TokenClaims) |
125 | if !ok { | 125 | if !ok { |
126 | return TokenClaims{}, errors.New("token is not valid") | 126 | return TokenClaims{}, errors.New("token is not valid") |
127 | } | 127 | } |
128 | 128 | ||
129 | // extend token expiration date | 129 | // extend token expiration date |
130 | return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) | 130 | return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) |
131 | } | 131 | } |
132 | 132 | ||
133 | func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { | 133 | func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { |
134 | // validate token and check expiration date | 134 | // validate token and check expiration date |
135 | claims, err := GetTokenClaims(req) | 135 | claims, err := GetTokenClaims(req) |
136 | if err != nil { | 136 | if err != nil { |
137 | return claims, err | 137 | return claims, err |
138 | } | 138 | } |
139 | 139 | ||
140 | if roles == "" { | 140 | if roles == "" { |
141 | return claims, nil | 141 | return claims, nil |
142 | } | 142 | } |
143 | 143 | ||
144 | // check if token has expired | 144 | // check if token has expired |
145 | if claims.ExpiresAt < (time.Now()).Unix() { | 145 | if claims.ExpiresAt < (time.Now()).Unix() { |
146 | return claims, errors.New("token has expired") | 146 | return claims, errors.New("token has expired") |
147 | } | 147 | } |
148 | 148 | ||
149 | if roles == "*" { | 149 | if roles == "*" { |
150 | return claims, nil | 150 | return claims, nil |
151 | } | 151 | } |
152 | 152 | ||
153 | parts := strings.Split(roles, ",") | 153 | parts := strings.Split(roles, ",") |
154 | for i, _ := range parts { | 154 | for i, _ := range parts { |
155 | r := strings.Trim(parts[i], " ") | 155 | r := strings.Trim(parts[i], " ") |
156 | if claims.RoleName == r { | 156 | if claims.RoleName == r { |
157 | return claims, nil | 157 | return claims, nil |
158 | } | 158 | } |
159 | } | 159 | } |
160 | 160 | ||
161 | return claims, nil | 161 | return claims, nil |
162 | } | 162 | } |
163 | 163 | ||
164 | // GetTokenClaims extracts JWT claims from Authorization header of req. | 164 | // GetTokenClaims extracts JWT claims from Authorization header of req. |
165 | // Returns token claims or an error. | 165 | // Returns token claims or an error. |
166 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { | 166 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { |
167 | // check for and strip 'Bearer' prefix | 167 | // check for and strip 'Bearer' prefix |
168 | var tokstr string | 168 | var tokstr string |
169 | authHead := req.Header.Get("Authorization") | 169 | authHead := req.Header.Get("Authorization") |
170 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { | 170 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { |
171 | tokstr = strings.TrimPrefix(authHead, "Bearer ") | 171 | tokstr = strings.TrimPrefix(authHead, "Bearer ") |
172 | } else { | 172 | } else { |
173 | return &TokenClaims{}, errors.New("authorization header in incomplete") | 173 | return &TokenClaims{}, errors.New("authorization header in incomplete") |
174 | } | 174 | } |
175 | 175 | ||
176 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) | 176 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) |
177 | if err != nil { | 177 | if err != nil { |
178 | return &TokenClaims{}, err | 178 | return &TokenClaims{}, err |
179 | } | 179 | } |
180 | 180 | ||
181 | // type assertion | 181 | // type assertion |
182 | claims, ok := token.Claims.(*TokenClaims) | 182 | claims, ok := token.Claims.(*TokenClaims) |
183 | if !ok || !token.Valid { | 183 | if !ok || !token.Valid { |
184 | return &TokenClaims{}, errors.New("token is not valid") | 184 | return &TokenClaims{}, errors.New("token is not valid") |
185 | } | 185 | } |
186 | 186 | ||
187 | return claims, nil | 187 | return claims, nil |
188 | } | 188 | } |
189 | 189 | ||
190 | // randomSalt returns a string of 32 random characters. | 190 | // randomSalt returns a string of 32 random characters. |
191 | const saltSize = 32 | 191 | const saltSize = 32 |
192 | 192 | ||
193 | func randomSalt() (s string, err error) { | 193 | func randomSalt() (s string, err error) { |
194 | rawsalt := make([]byte, saltSize) | 194 | rawsalt := make([]byte, saltSize) |
195 | 195 | ||
196 | _, err = rand.Read(rawsalt) | 196 | _, err = rand.Read(rawsalt) |
197 | if err != nil { | 197 | if err != nil { |
198 | return "", err | 198 | return "", err |
199 | } | 199 | } |
200 | 200 | ||
201 | s = hex.EncodeToString(rawsalt) | 201 | s = hex.EncodeToString(rawsalt) |
202 | return s, nil | 202 | return s, nil |
203 | } | 203 | } |
204 | 204 | ||
205 | // secretFunc returns byte slice of API secret keyword. | 205 | // secretFunc returns byte slice of API secret keyword. |
206 | func secretFunc(token *jwt.Token) (interface{}, error) { | 206 | func secretFunc(token *jwt.Token) (interface{}, error) { |
207 | return []byte(secret), nil | 207 | return []byte(_secret), nil |
208 | } | 208 | } |
209 | 209 |