Commit 65d214f47d49a6a5637598d8fac7523f5fb40fa8

Authored by Marko Tikvić
1 parent 32a277faa6
Exists in master

improved middleware: more control over content type

1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 var _issuer = "webutility" 15 var _issuer = "webutility"
16 var _secret = "webutility" 16 var _secret = "webutility"
17 17
18 // TokenClaims are JWT token claims. 18 // TokenClaims are JWT token claims.
19 type TokenClaims struct { 19 type TokenClaims struct {
20 // extending a struct 20 // extending a struct
21 jwt.StandardClaims 21 jwt.StandardClaims
22 22
23 // custom claims 23 // custom claims
24 Token string `json:"access_token"` 24 Token string `json:"access_token"`
25 TokenType string `json:"token_type"` 25 TokenType string `json:"token_type"`
26 Username string `json:"username"` 26 Username string `json:"username"`
27 RoleName string `json:"role"` 27 RoleName string `json:"role"`
28 RoleID int64 `json:"role_id"` 28 RoleID int64 `json:"role_id"`
29 ExpiresIn int64 `json:"expires_in"` 29 ExpiresIn int64 `json:"expires_in"`
30 } 30 }
31 31
32 // InitJWT ... 32 // InitJWT ...
33 func InitJWT(issuer, secret string) { 33 func InitJWT(issuer, secret string) {
34 _issuer = issuer 34 _issuer = issuer
35 _secret = secret 35 _secret = secret
36 } 36 }
37 37
38 // ValidateHash hashes pass and salt and returns comparison result with resultHash 38 // ValidateHash hashes pass and salt and returns comparison result with resultHash
39 func ValidateHash(pass, salt, resultHash string) (bool, error) { 39 func ValidateHash(pass, salt, resultHash string) (bool, error) {
40 hash, _, err := CreateHash(pass, salt) 40 hash, _, err := CreateHash(pass, salt)
41 if err != nil { 41 if err != nil {
42 return false, err 42 return false, err
43 } 43 }
44 res := hash == resultHash 44 res := hash == resultHash
45 return res, nil 45 return res, nil
46 } 46 }
47 47
48 // CreateHash hashes str using SHA256. 48 // CreateHash hashes str using SHA256.
49 // If the presalt parameter is not provided CreateHash will generate new salt string. 49 // If the presalt parameter is not provided CreateHash will generate new salt string.
50 // Returns hash and salt strings or an error if it fails. 50 // Returns hash and salt strings or an error if it fails.
51 func CreateHash(str, presalt string) (hash, salt string, err error) { 51 func CreateHash(str, presalt string) (hash, salt string, err error) {
52 // chech if message is presalted 52 // chech if message is presalted
53 if presalt == "" { 53 if presalt == "" {
54 salt, err = randomSalt() 54 salt, err = randomSalt()
55 if err != nil { 55 if err != nil {
56 return "", "", err 56 return "", "", err
57 } 57 }
58 } else { 58 } else {
59 salt = presalt 59 salt = presalt
60 } 60 }
61 61
62 // convert strings to raw byte slices 62 // convert strings to raw byte slices
63 rawstr := []byte(str) 63 rawstr := []byte(str)
64 rawsalt, err := hex.DecodeString(salt) 64 rawsalt, err := hex.DecodeString(salt)
65 if err != nil { 65 if err != nil {
66 return "", "", err 66 return "", "", err
67 } 67 }
68 68
69 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 69 rawdata := make([]byte, len(rawstr)+len(rawsalt))
70 rawdata = append(rawdata, rawstr...) 70 rawdata = append(rawdata, rawstr...)
71 rawdata = append(rawdata, rawsalt...) 71 rawdata = append(rawdata, rawsalt...)
72 72
73 // hash message + salt 73 // hash message + salt
74 hasher := sha256.New() 74 hasher := sha256.New()
75 hasher.Write(rawdata) 75 hasher.Write(rawdata)
76 rawhash := hasher.Sum(nil) 76 rawhash := hasher.Sum(nil)
77 77
78 hash = hex.EncodeToString(rawhash) 78 hash = hex.EncodeToString(rawhash)
79 return hash, salt, nil 79 return hash, salt, nil
80 } 80 }
81 81
82 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 82 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
83 // It returns an error if it fails. 83 // It returns an error if it fails.
84 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { 84 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) {
85 t0 := (time.Now()).Unix() 85 t0 := (time.Now()).Unix()
86 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() 86 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix()
87 claims := TokenClaims{ 87 claims := TokenClaims{
88 TokenType: "Bearer", 88 TokenType: "Bearer",
89 Username: username, 89 Username: username,
90 RoleName: roleName, 90 RoleName: roleName,
91 RoleID: roleID, 91 RoleID: roleID,
92 ExpiresIn: t1 - t0, 92 ExpiresIn: t1 - t0,
93 } 93 }
94 // initialize jwt.StandardClaims fields (anonymous struct) 94 // initialize jwt.StandardClaims fields (anonymous struct)
95 claims.IssuedAt = t0 95 claims.IssuedAt = t0
96 claims.ExpiresAt = t1 96 claims.ExpiresAt = t1
97 claims.Issuer = _issuer 97 claims.Issuer = _issuer
98 98
99 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 99 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
100 token, err := jwtToken.SignedString([]byte(_secret)) 100 token, err := jwtToken.SignedString([]byte(_secret))
101 if err != nil { 101 if err != nil {
102 return TokenClaims{}, err 102 return TokenClaims{}, err
103 } 103 }
104 claims.Token = token 104 claims.Token = token
105 return claims, nil 105 return claims, nil
106 } 106 }
107 107
108 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. 108 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date.
109 // It returns an error if it fails. 109 // It returns an error if it fails.
110 func RefreshAuthToken(tok string) (TokenClaims, error) { 110 func RefreshAuthToken(tok string) (TokenClaims, error) {
111 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) 111 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
112 if err != nil { 112 if err != nil {
113 if validation, ok := err.(*jwt.ValidationError); ok { 113 if validation, ok := err.(*jwt.ValidationError); ok {
114 // don't return error if token is expired, just extend it 114 // don't return error if token is expired, just extend it
115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) { 115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
116 return TokenClaims{}, err 116 return TokenClaims{}, err
117 } 117 }
118 } else { 118 } else {
119 return TokenClaims{}, err 119 return TokenClaims{}, err
120 } 120 }
121 } 121 }
122 122
123 // type assertion 123 // type assertion
124 claims, ok := token.Claims.(*TokenClaims) 124 claims, ok := token.Claims.(*TokenClaims)
125 if !ok { 125 if !ok {
126 return TokenClaims{}, errors.New("token is not valid") 126 return TokenClaims{}, errors.New("token is not valid")
127 } 127 }
128 128
129 // extend token expiration date 129 // extend token expiration date
130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) 130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID)
131 } 131 }
132 132
133 // AuthCheck ... 133 // AuthCheck ...
134 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { 134 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
135 // validate token and check expiration date 135 // validate token and check expiration date
136 claims, err := GetTokenClaims(req) 136 claims, err := GetTokenClaims(req)
137 if err != nil { 137 if err != nil {
138 return claims, err 138 return claims, err
139 } 139 }
140 140
141 if roles == "" { 141 if roles == "" {
142 return claims, nil 142 return claims, nil
143 } 143 }
144 144
145 // check if token has expired 145 // check if token has expired
146 if claims.ExpiresAt < (time.Now()).Unix() { 146 if claims.ExpiresAt < (time.Now()).Unix() {
147 return claims, errors.New("token has expired") 147 return claims, errors.New("token has expired")
148 } 148 }
149 149
150 if roles == "*" { 150 if roles == "*" {
151 return claims, nil 151 return claims, nil
152 } 152 }
153 153
154 parts := strings.Split(roles, ",") 154 parts := strings.Split(roles, ",")
155 for i := range parts { 155 for i := range parts {
156 r := strings.Trim(parts[i], " ") 156 r := strings.Trim(parts[i], " ")
157 if claims.RoleName == r { 157 if claims.RoleName == r {
158 return claims, nil 158 return claims, nil
159 } 159 }
160 } 160 }
161 161
162 return claims, errors.New("unauthorized role access") 162 return claims, errors.New("unauthorized role access")
163 } 163 }
164 164
165 // GetTokenClaims extracts JWT claims from Authorization header of req. 165 // GetTokenClaims extracts JWT claims from Authorization header of req.
166 // Returns token claims or an error. 166 // Returns token claims or an error.
167 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 167 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
168 // check for and strip 'Bearer' prefix 168 // check for and strip 'Bearer' prefix
169 var tokstr string 169 var tokstr string
170 authHead := req.Header.Get("Authorization") 170 authHead := req.Header.Get("Authorization")
171 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 171 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
172 tokstr = strings.TrimPrefix(authHead, "Bearer ") 172 tokstr = strings.TrimPrefix(authHead, "Bearer ")
173 } else { 173 } else {
174 return &TokenClaims{}, errors.New("authorization header in incomplete") 174 return &TokenClaims{}, errors.New("authorization header is incomplete")
175 } 175 }
176 176
177 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 177 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
178 if err != nil { 178 if err != nil {
179 return &TokenClaims{}, err 179 return &TokenClaims{}, err
180 } 180 }
181 181
182 // type assertion 182 // type assertion
183 claims, ok := token.Claims.(*TokenClaims) 183 claims, ok := token.Claims.(*TokenClaims)
184 if !ok || !token.Valid { 184 if !ok || !token.Valid {
185 return &TokenClaims{}, errors.New("token is not valid") 185 return &TokenClaims{}, errors.New("token is not valid")
186 } 186 }
187 187
188 return claims, nil 188 return claims, nil
189 } 189 }
190 190
191 // randomSalt returns a string of 32 random characters. 191 // randomSalt returns a string of 32 random characters.
192 func randomSalt() (s string, err error) { 192 func randomSalt() (s string, err error) {
193 const saltSize = 32 193 const saltSize = 32
194 194
195 rawsalt := make([]byte, saltSize) 195 rawsalt := make([]byte, saltSize)
196 196
197 _, err = rand.Read(rawsalt) 197 _, err = rand.Read(rawsalt)
198 if err != nil { 198 if err != nil {
199 return "", err 199 return "", err
200 } 200 }
201 201
202 s = hex.EncodeToString(rawsalt) 202 s = hex.EncodeToString(rawsalt)
203 return s, nil 203 return s, nil
204 } 204 }
205 205
206 // secretFunc returns byte slice of API secret keyword. 206 // secretFunc returns byte slice of API secret keyword.
207 func secretFunc(token *jwt.Token) (interface{}, error) { 207 func secretFunc(token *jwt.Token) (interface{}, error) {
208 return []byte(_secret), nil 208 return []byte(_secret), nil
209 } 209 }
210 210
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "encoding/json" 4 "encoding/json"
5 "fmt" 5 "fmt"
6 "net/http" 6 "net/http"
7 ) 7 )
8 8
9 // StatusRecorder ... 9 // StatusRecorder ...
10 type StatusRecorder struct { 10 type StatusRecorder struct {
11 writer http.ResponseWriter 11 writer http.ResponseWriter
12 status int 12 status int
13 size int 13 size int
14 } 14 }
15 15
16 // NewStatusRecorder ... 16 // NewStatusRecorder ...
17 func NewStatusRecorder(w http.ResponseWriter) *StatusRecorder { 17 func NewStatusRecorder(w http.ResponseWriter) *StatusRecorder {
18 return &StatusRecorder{ 18 return &StatusRecorder{
19 writer: w, 19 writer: w,
20 status: 0, 20 status: 0,
21 size: 0, 21 size: 0,
22 } 22 }
23 } 23 }
24 24
25 // WriteHeader is a wrapper http.ResponseWriter interface 25 // WriteHeader is a wrapper http.ResponseWriter interface
26 func (r *StatusRecorder) WriteHeader(code int) { 26 func (r *StatusRecorder) WriteHeader(code int) {
27 r.status = code 27 r.status = code
28 r.writer.WriteHeader(code) 28 r.writer.WriteHeader(code)
29 } 29 }
30 30
31 // Write is a wrapper for http.ResponseWriter interface 31 // Write is a wrapper for http.ResponseWriter interface
32 func (r *StatusRecorder) Write(in []byte) (int, error) { 32 func (r *StatusRecorder) Write(in []byte) (int, error) {
33 r.size = len(in) 33 r.size = len(in)
34 return r.writer.Write(in) 34 return r.writer.Write(in)
35 } 35 }
36 36
37 // Header is a wrapper for http.ResponseWriter interface 37 // Header is a wrapper for http.ResponseWriter interface
38 func (r *StatusRecorder) Header() http.Header { 38 func (r *StatusRecorder) Header() http.Header {
39 return r.writer.Header() 39 return r.writer.Header()
40 } 40 }
41 41
42 // Status ... 42 // Status ...
43 func (r *StatusRecorder) Status() int { 43 func (r *StatusRecorder) Status() int {
44 return r.status 44 return r.status
45 } 45 }
46 46
47 // Size ... 47 // Size ...
48 func (r *StatusRecorder) Size() int { 48 func (r *StatusRecorder) Size() int {
49 return r.size 49 return r.size
50 } 50 }
51 51
52 // NotFoundHandlerFunc writes HTTP error 404 to w. 52 // NotFoundHandlerFunc writes HTTP error 404 to w.
53 func NotFoundHandlerFunc(w http.ResponseWriter, req *http.Request) { 53 func NotFoundHandlerFunc(w http.ResponseWriter, req *http.Request) {
54 SetDefaultHeaders(w) 54 SetAccessControlHeaders(w)
55 if req.Method == "OPTIONS" { 55 SetContentType(w, "application/json")
56 return
57 }
58 NotFound(w, req, fmt.Sprintf("Resource you requested was not found: %s", req.URL.String())) 56 NotFound(w, req, fmt.Sprintf("Resource you requested was not found: %s", req.URL.String()))
59 } 57 }
60 58
61 // SetContentType ... 59 // SetContentType must be called before SetResponseStatus (w.WriteHeader) (?)
62 func SetContentType(w http.ResponseWriter, ctype string) { 60 func SetContentType(w http.ResponseWriter, ctype string) {
63 w.Header().Set("Content-Type", ctype) 61 w.Header().Set("Content-Type", ctype)
64 } 62 }
65 63
66 // SetResponseStatus ... 64 // SetResponseStatus ...
67 func SetResponseStatus(w http.ResponseWriter, status int) { 65 func SetResponseStatus(w http.ResponseWriter, status int) {
68 w.WriteHeader(status) 66 w.WriteHeader(status)
69 } 67 }
70 68
71 // WriteResponse ... 69 // WriteResponse ...
72 func WriteResponse(w http.ResponseWriter, content []byte) { 70 func WriteResponse(w http.ResponseWriter, content []byte) {
73 w.Write(content) 71 w.Write(content)
74 } 72 }
75 73
76 // SetDefaultHeaders set's default headers for an HTTP response. 74 // SetAccessControlHeaders set's default headers for an HTTP response.
77 func SetDefaultHeaders(w http.ResponseWriter) { 75 func SetAccessControlHeaders(w http.ResponseWriter) {
78 w.Header().Set("Access-Control-Allow-Origin", "*") 76 w.Header().Set("Access-Control-Allow-Origin", "*")
79 w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS") 77 w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS")
80 w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") 78 w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
81 SetContentType(w, "application/json; charset=utf-8")
82 } 79 }
83 80
84 // GetLocale ... 81 // GetLocale ...
85 func GetLocale(req *http.Request, dflt string) string { 82 func GetLocale(req *http.Request, dflt string) string {
86 loc := req.FormValue("locale") 83 loc := req.FormValue("locale")
87 if loc == "" { 84 if loc == "" {
88 return dflt 85 return dflt
89 } 86 }
90 return loc 87 return loc
91 } 88 }
92 89
93 // Success ... 90 // Success ...
94 func Success(w http.ResponseWriter, payload interface{}, code int) { 91 func Success(w http.ResponseWriter, payload interface{}, code int) {
95 w.WriteHeader(code) 92 w.WriteHeader(code)
96 if payload != nil { 93 if payload != nil {
97 json.NewEncoder(w).Encode(payload) 94 json.NewEncoder(w).Encode(payload)
98 } 95 }
99 } 96 }
100 97
101 // OK ... 98 // OK ...
102 func OK(w http.ResponseWriter, payload interface{}) { 99 func OK(w http.ResponseWriter, payload interface{}) {
100 SetContentType(w, "application/json")
103 Success(w, payload, http.StatusOK) 101 Success(w, payload, http.StatusOK)
104 } 102 }
105 103
106 // Created ... 104 // Created ...
107 func Created(w http.ResponseWriter, payload interface{}) { 105 func Created(w http.ResponseWriter, payload interface{}) {
106 SetContentType(w, "application/json")
108 Success(w, payload, http.StatusCreated) 107 Success(w, payload, http.StatusCreated)
109 } 108 }
110 109
111 type weberror struct { 110 type weberror struct {
112 Request string `json:"request"` 111 Request string `json:"request"`
113 Error string `json:"error"` 112 Error string `json:"error"`
114 } 113 }
115 114
116 // Error ... 115 // Error ...
117 func Error(w http.ResponseWriter, r *http.Request, code int, err string) { 116 func Error(w http.ResponseWriter, r *http.Request, code int, err string) {
118 werr := weberror{Error: err, Request: r.Method + " " + r.RequestURI} 117 werr := weberror{Error: err, Request: r.Method + " " + r.RequestURI}
119 w.WriteHeader(code) 118 w.WriteHeader(code)
120 json.NewEncoder(w).Encode(werr) 119 json.NewEncoder(w).Encode(werr)
121 } 120 }
122 121
123 // BadRequest ... 122 // BadRequest ...
124 func BadRequest(w http.ResponseWriter, r *http.Request, err string) { 123 func BadRequest(w http.ResponseWriter, r *http.Request, err string) {
124 SetContentType(w, "application/json")
125 Error(w, r, http.StatusBadRequest, err) 125 Error(w, r, http.StatusBadRequest, err)
126 } 126 }
127 127
128 // Unauthorized ... 128 // Unauthorized ...
129 func Unauthorized(w http.ResponseWriter, r *http.Request, err string) { 129 func Unauthorized(w http.ResponseWriter, r *http.Request, err string) {
130 SetContentType(w, "application/json")
130 Error(w, r, http.StatusUnauthorized, err) 131 Error(w, r, http.StatusUnauthorized, err)
131 } 132 }
132 133
133 // Forbidden ... 134 // Forbidden ...
134 func Forbidden(w http.ResponseWriter, r *http.Request, err string) { 135 func Forbidden(w http.ResponseWriter, r *http.Request, err string) {
136 SetContentType(w, "application/json")
135 Error(w, r, http.StatusForbidden, err) 137 Error(w, r, http.StatusForbidden, err)
136 } 138 }
137 139
138 // NotFound ... 140 // NotFound ...
139 func NotFound(w http.ResponseWriter, r *http.Request, err string) { 141 func NotFound(w http.ResponseWriter, r *http.Request, err string) {
142 SetContentType(w, "application/json")
140 Error(w, r, http.StatusNotFound, err) 143 Error(w, r, http.StatusNotFound, err)
141 } 144 }
142 145
143 // Conflict ... 146 // Conflict ...
144 func Conflict(w http.ResponseWriter, r *http.Request, err string) { 147 func Conflict(w http.ResponseWriter, r *http.Request, err string) {
148 SetContentType(w, "application/json")
145 Error(w, r, http.StatusConflict, err) 149 Error(w, r, http.StatusConflict, err)
146 } 150 }
147 151
148 // InternalServerError ... 152 // InternalServerError ...
149 func InternalServerError(w http.ResponseWriter, r *http.Request, err string) { 153 func InternalServerError(w http.ResponseWriter, r *http.Request, err string) {
154 SetContentType(w, "application/json")
middleware/main.go
1 package middleware 1 package middleware
2 2
3 import ( 3 import (
4 "net/http" 4 "net/http"
5 ) 5 )
6 6
7 func Headers(h http.HandlerFunc) http.HandlerFunc { 7 func Headers(h http.HandlerFunc) http.HandlerFunc {
8 return IgnoreOptionsRequests(ParseForm(h)) 8 return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(h)))
9 } 9 }
10 10
11 func AuthOnly(roles string, h http.HandlerFunc) http.HandlerFunc { 11 func AuthUser(roles string, h http.HandlerFunc) http.HandlerFunc {
12 return IgnoreOptionsRequests(ParseForm(Auth(roles, h))) 12 return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(Auth(roles, h))))
13 } 13 }
14 14
15 func Full(roles string, h http.HandlerFunc) http.HandlerFunc { 15 func AuthUserLogTraffic(roles string, h http.HandlerFunc) http.HandlerFunc {
16 return IgnoreOptionsRequests(ParseForm(LogTraffic(Auth(roles, h)))) 16 return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(LogHTTP(Auth(roles, h)))))
17 } 17 }
18 18
19 func LogTraffic(h http.HandlerFunc) http.HandlerFunc { 19 func LogTraffic(h http.HandlerFunc) http.HandlerFunc {
20 return IgnoreOptionsRequests(ParseForm(LogRequestAndResponse(h))) 20 return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(LogHTTP(h))))
21 } 21 }
22 22
middleware/middleware.go
1 package middleware 1 package middleware
2 2
3 import ( 3 import (
4 "net/http" 4 "net/http"
5 "time" 5 "time"
6 6
7 "git.to-net.rs/marko.tikvic/gologger" 7 "git.to-net.rs/marko.tikvic/gologger"
8 8
9 web "git.to-net.rs/marko.tikvic/webutility" 9 web "git.to-net.rs/marko.tikvic/webutility"
10 ) 10 )
11 11
12 var httpLogger *gologger.Logger 12 var httpLogger *gologger.Logger
13 13
14 func SetAccessControlHeaders(h http.HandlerFunc) http.HandlerFunc {
15 return func(w http.ResponseWriter, req *http.Request) {
16 web.SetAccessControlHeaders(w)
17
18 h(w, req)
19 }
20 }
21
14 // IgnoreOptionsRequests ... 22 // IgnoreOptionsRequests ...
15 func IgnoreOptionsRequests(h http.HandlerFunc) http.HandlerFunc { 23 func IgnoreOptionsRequests(h http.HandlerFunc) http.HandlerFunc {
16 return func(w http.ResponseWriter, req *http.Request) { 24 return func(w http.ResponseWriter, req *http.Request) {
17 web.SetDefaultHeaders(w)
18 if req.Method == http.MethodOptions { 25 if req.Method == http.MethodOptions {
19 return 26 return
20 } 27 }
28
21 h(w, req) 29 h(w, req)
22 } 30 }
23 } 31 }
24 32
25 // ParseForm ... 33 // ParseForm ...
26 func ParseForm(h http.HandlerFunc) http.HandlerFunc { 34 func ParseForm(h http.HandlerFunc) http.HandlerFunc {
27 return func(w http.ResponseWriter, req *http.Request) { 35 return func(w http.ResponseWriter, req *http.Request) {
28 err := req.ParseForm() 36 err := req.ParseForm()
29 if err != nil { 37 if err != nil {
30 web.BadRequest(w, req, err.Error()) 38 web.BadRequest(w, req, err.Error())
31 return 39 return
32 } 40 }
41
33 h(w, req) 42 h(w, req)
34 } 43 }
35 } 44 }
36 45
37 // ParseMultipartForm ... 46 // ParseMultipartForm ...
38 func ParseMultipartForm(h http.HandlerFunc) http.HandlerFunc { 47 func ParseMultipartForm(h http.HandlerFunc) http.HandlerFunc {
39 return func(w http.ResponseWriter, req *http.Request) { 48 return func(w http.ResponseWriter, req *http.Request) {
40 err := req.ParseMultipartForm(32 << 20) 49 err := req.ParseMultipartForm(32 << 20)
41 if err != nil { 50 if err != nil {
42 web.BadRequest(w, req, err.Error()) 51 web.BadRequest(w, req, err.Error())
43 return 52 return
44 } 53 }
54
45 h(w, req) 55 h(w, req)
46 } 56 }
47 } 57 }
48 58
49 // SetLogger ... 59 // SetLogger ...
50 func SetLogger(logger *gologger.Logger) { 60 func SetLogger(logger *gologger.Logger) {
51 httpLogger = logger 61 httpLogger = logger
52 } 62 }
53 63
54 // LogRequestAndResponse ... 64 // LogHTTP ...
55 func LogRequestAndResponse(h http.HandlerFunc) http.HandlerFunc { 65 func LogHTTP(h http.HandlerFunc) http.HandlerFunc {
56 return func(w http.ResponseWriter, req *http.Request) { 66 return func(w http.ResponseWriter, req *http.Request) {
57 if httpLogger != nil { 67 if httpLogger == nil {
58 t1 := time.Now() 68 h(w, req)
69 return
70 }
59 71
60 claims, _ := web.GetTokenClaims(req) 72 t1 := time.Now()
61 in := httpLogger.LogHTTPRequest(req, claims.Username)
62 73
63 rec := web.NewStatusRecorder(w) 74 claims, _ := web.GetTokenClaims(req)
75 in := httpLogger.LogHTTPRequest(req, claims.Username)
64 76
65 h(rec, req) 77 rec := web.NewStatusRecorder(w)
66 78
67 t2 := time.Now() 79 h(rec, req)
68 out := httpLogger.LogHTTPResponse(rec.Status(), t2.Sub(t1), rec.Size())
69 80
70 httpLogger.CombineHTTPLogs(in, out) 81 t2 := time.Now()
71 } else { 82 out := httpLogger.LogHTTPResponse(rec.Status(), t2.Sub(t1), rec.Size())
72 h(w, req) 83
73 } 84 httpLogger.CombineHTTPLogs(in, out)
74 } 85 }
75 } 86 }
76 87
77 // Auth ... 88 // Auth ...
78 func Auth(roles string, h http.HandlerFunc) http.HandlerFunc { 89 func Auth(roles string, h http.HandlerFunc) http.HandlerFunc {
79 return func(w http.ResponseWriter, req *http.Request) { 90 return func(w http.ResponseWriter, req *http.Request) {
80 if _, err := web.AuthCheck(req, roles); err != nil { 91 if _, err := web.AuthCheck(req, roles); err != nil {
81 web.Unauthorized(w, req, err.Error()) 92 web.Unauthorized(w, req, err.Error())
82 return 93 return
83 } 94 }
95
84 h(w, req) 96 h(w, req)