Commit de25e1deb6dc5ea3eac3668d0f05500763dfb3dc
1 parent
0feac50590
Exists in
master
removed redundant role struct
Showing
1 changed file
with
6 additions
and
11 deletions
Show diff stats
auth.go
1 | package webutility | 1 | package webutility |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/rand" | 4 | "crypto/rand" |
5 | "crypto/sha256" | 5 | "crypto/sha256" |
6 | "encoding/hex" | 6 | "encoding/hex" |
7 | "errors" | 7 | "errors" |
8 | "net/http" | 8 | "net/http" |
9 | "strings" | 9 | "strings" |
10 | "time" | 10 | "time" |
11 | 11 | ||
12 | "github.com/dgrijalva/jwt-go" | 12 | "github.com/dgrijalva/jwt-go" |
13 | ) | 13 | ) |
14 | 14 | ||
15 | var appName = "webutility" | 15 | var appName = "webutility" |
16 | var secret = "webutility" | 16 | var secret = "webutility" |
17 | 17 | ||
18 | type Role struct { | ||
19 | Name string `json:"name"` | ||
20 | ID int64 `json:"id"` | ||
21 | } | ||
22 | |||
23 | // TokenClaims are JWT token claims. | 18 | // TokenClaims are JWT token claims. |
24 | type TokenClaims struct { | 19 | type TokenClaims struct { |
25 | // extending a struct | 20 | // extending a struct |
26 | jwt.StandardClaims | 21 | jwt.StandardClaims |
27 | 22 | ||
28 | // custom claims | 23 | // custom claims |
29 | Token string `json:"access_token"` | 24 | Token string `json:"access_token"` |
30 | TokenType string `json:"token_type"` | 25 | TokenType string `json:"token_type"` |
31 | Username string `json:"username"` | 26 | Username string `json:"username"` |
32 | Role string `json:"role"` | 27 | RoleName string `json:"role"` |
33 | RoleID int64 `json:"role_id"` | 28 | RoleID int64 `json:"role_id"` |
34 | ExpiresIn int64 `json:"expires_in"` | 29 | ExpiresIn int64 `json:"expires_in"` |
35 | } | 30 | } |
36 | 31 | ||
37 | func InitJWT(appName, secret string) { | 32 | func InitJWT(appName, secret string) { |
38 | appName = appName | 33 | appName = appName |
39 | secret = secret | 34 | secret = secret |
40 | } | 35 | } |
41 | 36 | ||
42 | // ValidateHash hashes pass and salt and returns comparison result with resultHash | 37 | // ValidateHash hashes pass and salt and returns comparison result with resultHash |
43 | func ValidateHash(pass, salt, resultHash string) (bool, error) { | 38 | func ValidateHash(pass, salt, resultHash string) (bool, error) { |
44 | hash, _, err := CreateHash(pass, salt) | 39 | hash, _, err := CreateHash(pass, salt) |
45 | if err != nil { | 40 | if err != nil { |
46 | return false, err | 41 | return false, err |
47 | } | 42 | } |
48 | res := hash == resultHash | 43 | res := hash == resultHash |
49 | return res, nil | 44 | return res, nil |
50 | } | 45 | } |
51 | 46 | ||
52 | // CreateHash hashes str using SHA256. | 47 | // CreateHash hashes str using SHA256. |
53 | // If the presalt parameter is not provided CreateHash will generate new salt string. | 48 | // If the presalt parameter is not provided CreateHash will generate new salt string. |
54 | // Returns hash and salt strings or an error if it fails. | 49 | // Returns hash and salt strings or an error if it fails. |
55 | func CreateHash(str, presalt string) (hash, salt string, err error) { | 50 | func CreateHash(str, presalt string) (hash, salt string, err error) { |
56 | // chech if message is presalted | 51 | // chech if message is presalted |
57 | if presalt == "" { | 52 | if presalt == "" { |
58 | salt, err = randomSalt() | 53 | salt, err = randomSalt() |
59 | if err != nil { | 54 | if err != nil { |
60 | return "", "", err | 55 | return "", "", err |
61 | } | 56 | } |
62 | } else { | 57 | } else { |
63 | salt = presalt | 58 | salt = presalt |
64 | } | 59 | } |
65 | 60 | ||
66 | // convert strings to raw byte slices | 61 | // convert strings to raw byte slices |
67 | rawstr := []byte(str) | 62 | rawstr := []byte(str) |
68 | rawsalt, err := hex.DecodeString(salt) | 63 | rawsalt, err := hex.DecodeString(salt) |
69 | if err != nil { | 64 | if err != nil { |
70 | return "", "", err | 65 | return "", "", err |
71 | } | 66 | } |
72 | 67 | ||
73 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) | 68 | rawdata := make([]byte, len(rawstr)+len(rawsalt)) |
74 | rawdata = append(rawdata, rawstr...) | 69 | rawdata = append(rawdata, rawstr...) |
75 | rawdata = append(rawdata, rawsalt...) | 70 | rawdata = append(rawdata, rawsalt...) |
76 | 71 | ||
77 | // hash message + salt | 72 | // hash message + salt |
78 | hasher := sha256.New() | 73 | hasher := sha256.New() |
79 | hasher.Write(rawdata) | 74 | hasher.Write(rawdata) |
80 | rawhash := hasher.Sum(nil) | 75 | rawhash := hasher.Sum(nil) |
81 | 76 | ||
82 | hash = hex.EncodeToString(rawhash) | 77 | hash = hex.EncodeToString(rawhash) |
83 | return hash, salt, nil | 78 | return hash, salt, nil |
84 | } | 79 | } |
85 | 80 | ||
86 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. | 81 | // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. |
87 | // It returns an error if it fails. | 82 | // It returns an error if it fails. |
88 | func CreateAuthToken(username string, role Role) (TokenClaims, error) { | 83 | func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { |
89 | t0 := (time.Now()).Unix() | 84 | t0 := (time.Now()).Unix() |
90 | t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() | 85 | t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() |
91 | claims := TokenClaims{ | 86 | claims := TokenClaims{ |
92 | TokenType: "Bearer", | 87 | TokenType: "Bearer", |
93 | Username: username, | 88 | Username: username, |
94 | Role: role.Name, | 89 | RoleName: roleName, |
95 | RoleID: role.ID, | 90 | RoleID: roleID, |
96 | ExpiresIn: t1 - t0, | 91 | ExpiresIn: t1 - t0, |
97 | } | 92 | } |
98 | // initialize jwt.StandardClaims fields (anonymous struct) | 93 | // initialize jwt.StandardClaims fields (anonymous struct) |
99 | claims.IssuedAt = t0 | 94 | claims.IssuedAt = t0 |
100 | claims.ExpiresAt = t1 | 95 | claims.ExpiresAt = t1 |
101 | claims.Issuer = appName | 96 | claims.Issuer = appName |
102 | 97 | ||
103 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 98 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) |
104 | token, err := jwtToken.SignedString([]byte(secret)) | 99 | token, err := jwtToken.SignedString([]byte(secret)) |
105 | if err != nil { | 100 | if err != nil { |
106 | return TokenClaims{}, err | 101 | return TokenClaims{}, err |
107 | } | 102 | } |
108 | claims.Token = token | 103 | claims.Token = token |
109 | return claims, nil | 104 | return claims, nil |
110 | } | 105 | } |
111 | 106 | ||
112 | // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. | 107 | // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. |
113 | // It returns an error if it fails. | 108 | // It returns an error if it fails. |
114 | func RefreshAuthToken(tok string) (TokenClaims, error) { | 109 | func RefreshAuthToken(tok string) (TokenClaims, error) { |
115 | token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) | 110 | token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) |
116 | if err != nil { | 111 | if err != nil { |
117 | if validation, ok := err.(*jwt.ValidationError); ok { | 112 | if validation, ok := err.(*jwt.ValidationError); ok { |
118 | // don't return error if token is expired | 113 | // don't return error if token is expired |
119 | // just extend it | 114 | // just extend it |
120 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { | 115 | if !(validation.Errors&jwt.ValidationErrorExpired != 0) { |
121 | return TokenClaims{}, err | 116 | return TokenClaims{}, err |
122 | } | 117 | } |
123 | } else { | 118 | } else { |
124 | return TokenClaims{}, err | 119 | return TokenClaims{}, err |
125 | } | 120 | } |
126 | } | 121 | } |
127 | 122 | ||
128 | // type assertion | 123 | // type assertion |
129 | claims, ok := token.Claims.(*TokenClaims) | 124 | claims, ok := token.Claims.(*TokenClaims) |
130 | if !ok { | 125 | if !ok { |
131 | return TokenClaims{}, errors.New("token is not valid") | 126 | return TokenClaims{}, errors.New("token is not valid") |
132 | } | 127 | } |
133 | 128 | ||
134 | // extend token expiration date | 129 | // extend token expiration date |
135 | return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) | 130 | return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) |
136 | } | 131 | } |
137 | 132 | ||
138 | func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { | 133 | func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { |
139 | // validate token and check expiration date | 134 | // validate token and check expiration date |
140 | claims, err := GetTokenClaims(req) | 135 | claims, err := GetTokenClaims(req) |
141 | if err != nil { | 136 | if err != nil { |
142 | return claims, err | 137 | return claims, err |
143 | } | 138 | } |
144 | 139 | ||
145 | if roles == "" { | 140 | if roles == "" { |
146 | return claims, nil | 141 | return claims, nil |
147 | } | 142 | } |
148 | 143 | ||
149 | // check if token has expired | 144 | // check if token has expired |
150 | if claims.ExpiresAt < (time.Now()).Unix() { | 145 | if claims.ExpiresAt < (time.Now()).Unix() { |
151 | return claims, errors.New("token has expired") | 146 | return claims, errors.New("token has expired") |
152 | } | 147 | } |
153 | 148 | ||
154 | if roles == "*" { | 149 | if roles == "*" { |
155 | return claims, nil | 150 | return claims, nil |
156 | } | 151 | } |
157 | 152 | ||
158 | parts := strings.Split(roles, ",") | 153 | parts := strings.Split(roles, ",") |
159 | for i, _ := range parts { | 154 | for i, _ := range parts { |
160 | r := strings.Trim(parts[i], " ") | 155 | r := strings.Trim(parts[i], " ") |
161 | if claims.Role == r { | 156 | if claims.RoleName == r { |
162 | return claims, nil | 157 | return claims, nil |
163 | } | 158 | } |
164 | } | 159 | } |
165 | 160 | ||
166 | return claims, nil | 161 | return claims, nil |
167 | } | 162 | } |
168 | 163 | ||
169 | // GetTokenClaims extracts JWT claims from Authorization header of req. | 164 | // GetTokenClaims extracts JWT claims from Authorization header of req. |
170 | // Returns token claims or an error. | 165 | // Returns token claims or an error. |
171 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { | 166 | func GetTokenClaims(req *http.Request) (*TokenClaims, error) { |
172 | // check for and strip 'Bearer' prefix | 167 | // check for and strip 'Bearer' prefix |
173 | var tokstr string | 168 | var tokstr string |
174 | authHead := req.Header.Get("Authorization") | 169 | authHead := req.Header.Get("Authorization") |
175 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { | 170 | if ok := strings.HasPrefix(authHead, "Bearer "); ok { |
176 | tokstr = strings.TrimPrefix(authHead, "Bearer ") | 171 | tokstr = strings.TrimPrefix(authHead, "Bearer ") |
177 | } else { | 172 | } else { |
178 | return &TokenClaims{}, errors.New("authorization header in incomplete") | 173 | return &TokenClaims{}, errors.New("authorization header in incomplete") |
179 | } | 174 | } |
180 | 175 | ||
181 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) | 176 | token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) |
182 | if err != nil { | 177 | if err != nil { |
183 | return &TokenClaims{}, err | 178 | return &TokenClaims{}, err |
184 | } | 179 | } |
185 | 180 | ||
186 | // type assertion | 181 | // type assertion |
187 | claims, ok := token.Claims.(*TokenClaims) | 182 | claims, ok := token.Claims.(*TokenClaims) |
188 | if !ok || !token.Valid { | 183 | if !ok || !token.Valid { |
189 | return &TokenClaims{}, errors.New("token is not valid") | 184 | return &TokenClaims{}, errors.New("token is not valid") |
190 | } | 185 | } |
191 | 186 | ||
192 | return claims, nil | 187 | return claims, nil |
193 | } | 188 | } |
194 | 189 | ||
195 | // randomSalt returns a string of 32 random characters. | 190 | // randomSalt returns a string of 32 random characters. |
196 | const saltSize = 32 | 191 | const saltSize = 32 |
197 | 192 | ||
198 | func randomSalt() (s string, err error) { | 193 | func randomSalt() (s string, err error) { |
199 | rawsalt := make([]byte, saltSize) | 194 | rawsalt := make([]byte, saltSize) |
200 | 195 | ||
201 | _, err = rand.Read(rawsalt) | 196 | _, err = rand.Read(rawsalt) |
202 | if err != nil { | 197 | if err != nil { |
203 | return "", err | 198 | return "", err |
204 | } | 199 | } |
205 | 200 | ||
206 | s = hex.EncodeToString(rawsalt) | 201 | s = hex.EncodeToString(rawsalt) |
207 | return s, nil | 202 | return s, nil |
208 | } | 203 | } |
209 | 204 | ||
210 | // secretFunc returns byte slice of API secret keyword. | 205 | // secretFunc returns byte slice of API secret keyword. |
211 | func secretFunc(token *jwt.Token) (interface{}, error) { | 206 | func secretFunc(token *jwt.Token) (interface{}, error) { |
212 | return []byte(secret), nil | 207 | return []byte(secret), nil |
213 | } | 208 | } |
214 | 209 |