diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..a5abefe --- /dev/null +++ b/auth.go @@ -0,0 +1,246 @@ +package webutility + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "net/http" + "strings" + "time" + + "github.com/dgrijalva/jwt-go" +) + +const OneDay = time.Hour * 24 +const OneWeek = OneDay * 7 +const saltSize = 32 + +var appName = "webutility" +var secret = "webutility" + +type Role struct { + Name string `json:"name"` + ID int64 `json:"id"` +} + +// TokenClaims are JWT token claims. +type TokenClaims struct { + // extending a struct + jwt.StandardClaims + + // custom claims + Token string `json:"access_token"` + TokenType string `json:"token_type"` + Username string `json:"username"` + Role string `json:"role"` + RoleID int64 `json:"role_id"` + ExpiresIn int64 `json:"expires_in"` +} + +func InitJWT(appName, secret string) { + appName = appName + secret = secret +} + +// ValidateCredentials hashes pass and salt and returns comparison result with resultHash +func ValidateCredentials(pass, salt, resultHash string) (bool, error) { + hash, _, err := CreateHash(pass, salt) + if err != nil { + return false, err + } + res := hash == resultHash + return res, nil +} + +// CreateHash hashes str using SHA256. +// If the presalt parameter is not provided CreateHash will generate new salt string. +// Returns hash and salt strings or an error if it fails. +func CreateHash(str, presalt string) (hash, salt string, err error) { + // chech if message is presalted + if presalt == "" { + salt, err = randomSalt() + if err != nil { + return "", "", err + } + } else { + salt = presalt + } + + // convert strings to raw byte slices + rawstr := []byte(str) + rawsalt, err := hex.DecodeString(salt) + if err != nil { + return "", "", err + } + + rawdata := make([]byte, len(rawstr)+len(rawsalt)) + rawdata = append(rawdata, rawstr...) + rawdata = append(rawdata, rawsalt...) + + // hash message + salt + hasher := sha256.New() + hasher.Write(rawdata) + rawhash := hasher.Sum(nil) + + hash = hex.EncodeToString(rawhash) + return hash, salt, nil +} + +// CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. +// It returns an error if it fails. +func CreateAuthToken(username string, role Role) (TokenClaims, error) { + t0 := (time.Now()).Unix() + t1 := (time.Now().Add(OneWeek)).Unix() + claims := TokenClaims{ + TokenType: "Bearer", + Username: username, + Role: role.Name, + RoleID: role.ID, + ExpiresIn: t1 - t0, + } + // initialize jwt.StandardClaims fields (anonymous struct) + claims.IssuedAt = t0 + claims.ExpiresAt = t1 + claims.Issuer = appName + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token, err := jwtToken.SignedString([]byte(secret)) + if err != nil { + return TokenClaims{}, err + } + claims.Token = token + return claims, nil +} + +// RefreshAuthToken returns new JWT token with sprolongs JWT token's expiration date for one week. +// It returns new JWT token or an error if it fails. +func RefreshAuthToken(tok string) (TokenClaims, error) { + token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) + if err != nil { + if validation, ok := err.(*jwt.ValidationError); ok { + // don't return error if token is expired + // just extend it + if !(validation.Errors&jwt.ValidationErrorExpired != 0) { + return TokenClaims{}, err + } + } else { + return TokenClaims{}, err + } + } + + // type assertion + claims, ok := token.Claims.(*TokenClaims) + if !ok { + return TokenClaims{}, errors.New("token is not valid") + } + + // extend token expiration date + return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) +} + +// RbacCheck returns true if user that made HTTP request is authorized to +// access the resource it is targeting. +// It exctracts user's role from the JWT token located in Authorization header of +// http.Request and then compares it with the list of supplied roles and returns +// true if there's a match, if "*" is provided or if the authRoles is nil. +// Otherwise it returns false. +func RbacCheck(req *http.Request, authRoles []string) bool { + if authRoles == nil { + return true + } + + // validate token and check expiration date + claims, err := GetTokenClaims(req) + if err != nil { + return false + } + // check if token has expired + if claims.ExpiresAt < (time.Now()).Unix() { + return false + } + + // check if role extracted from token matches + // any of the provided (allowed) ones + for _, r := range authRoles { + if claims.Role == r || r == "*" { + return true + } + } + + return false +} + +// AuthCheck returns token claims and boolean value based on user's rights to access resource specified in req. +// It exctracts user's role from the JWT token located in Authorization header of +// HTTP request and then compares it with the list of supplied (authorized); +// it returns true if there's a match, if "*" is provided or if the authRoles is nil. +func AuthCheck(req *http.Request, authRoles []string) (*TokenClaims, bool) { + if authRoles == nil { + return nil, true + } + + // validate token and check expiration date + claims, err := GetTokenClaims(req) + if err != nil { + return claims, false + } + // check if token has expired + if claims.ExpiresAt < (time.Now()).Unix() { + return claims, false + } + + // check if role extracted from token matches + // any of the provided (allowed) ones + for _, r := range authRoles { + if claims.Role == r || r == "*" { + return claims, true + } + } + + return claims, false +} + +// GetTokenClaims extracts JWT claims from Authorization header of the request. +// Returns token claims or an error. +func GetTokenClaims(req *http.Request) (*TokenClaims, error) { + // check for and strip 'Bearer' prefix + var tokstr string + authHead := req.Header.Get("Authorization") + if ok := strings.HasPrefix(authHead, "Bearer "); ok { + tokstr = strings.TrimPrefix(authHead, "Bearer ") + } else { + return &TokenClaims{}, errors.New("authorization header in incomplete") + } + + token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) + if err != nil { + return &TokenClaims{}, err + } + + // type assertion + claims, ok := token.Claims.(*TokenClaims) + if !ok || !token.Valid { + return &TokenClaims{}, errors.New("token is not valid") + } + + return claims, nil +} + +// randomSalt returns a string of random characters of 'saltSize' length. +func randomSalt() (s string, err error) { + rawsalt := make([]byte, saltSize) + + _, err = rand.Read(rawsalt) + if err != nil { + return "", err + } + + s = hex.EncodeToString(rawsalt) + return s, nil +} + +// secretFunc returns byte slice of API secret keyword. +func secretFunc(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil +} diff --git a/auth_utility.go b/auth_utility.go deleted file mode 100644 index 387392d..0000000 --- a/auth_utility.go +++ /dev/null @@ -1,246 +0,0 @@ -package webutility - -import ( - "crypto/rand" - "crypto/sha256" - "encoding/hex" - "errors" - "net/http" - "strings" - "time" - - "github.com/dgrijalva/jwt-go" -) - -const OneDay = time.Hour * 24 -const OneWeek = OneDay * 7 -const saltSize = 32 - -var appName = "webutility" -var secret = "webutility" - -type Role struct { - Name string `json:"name"` - ID int64 `json:"id"` -} - -// TokenClaims are JWT token claims. -type TokenClaims struct { - // extending a struct - jwt.StandardClaims - - // custom claims - Token string `json:"access_token"` - TokenType string `json:"token_type"` - Username string `json:"username"` - Role string `json:"role"` - RoleID int64 `json:"role_id"` - ExpiresIn int64 `json:"expires_in"` -} - -func InitJWT(appName, secret string) { - appName = appName - secret = secret -} - -// ValidateCredentials hashes pass and salt and returns comparison result with resultHash -func ValidateCredentials(pass, salt, resultHash string) (bool, error) { - hash, _, err := CreateHash(pass, salt) - if err != nil { - return false, err - } - res := hash == resultHash - return res, nil -} - -// CreateHash hashes str using SHA256. -// If the presalt parameter is not provided CreateHash will generate new salt string. -// Returns hash and salt strings or an error if it fails. -func CreateHash(str, presalt string) (hash, salt string, err error) { - // chech if message is presalted - if presalt == "" { - salt, err = randomSalt() - if err != nil { - return "", "", err - } - } else { - salt = presalt - } - - // convert strings to raw byte slices - rawstr := []byte(str) - rawsalt, err := hex.DecodeString(salt) - if err != nil { - return "", "", err - } - - rawdata := make([]byte, len(rawstr)+len(rawsalt)) - rawdata = append(rawdata, rawstr...) - rawdata = append(rawdata, rawsalt...) - - // hash message + salt - hasher := sha256.New() - hasher.Write(rawdata) - rawhash := hasher.Sum(nil) - - hash = hex.EncodeToString(rawhash) - return hash, salt, nil -} - -// CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. -// It returns an error if it fails. -func CreateAuthToken(username string, role Role) (TokenClaims, error) { - t0 := (time.Now()).Unix() - t1 := (time.Now().Add(OneWeek)).Unix() - claims := TokenClaims{ - TokenType: "Bearer", - Username: username, - Role: role.Name, - RoleID: role.ID, - ExpiresIn: t1 - t0, - } - // initialize jwt.StandardClaims fields (anonymous struct) - claims.IssuedAt = t0 - claims.ExpiresAt = t1 - claims.Issuer = appName - - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - token, err := jwtToken.SignedString([]byte(secret)) - if err != nil { - return TokenClaims{}, err - } - claims.Token = token - return claims, nil -} - -// RefreshAuthToken returns new JWT token with sprolongs JWT token's expiration date for one week. -// It returns new JWT token or an error if it fails. -func RefreshAuthToken(tok string) (TokenClaims, error) { - token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) - if err != nil { - if validation, ok := err.(*jwt.ValidationError); ok { - // don't return error if token is expired - // just extend it - if !(validation.Errors&jwt.ValidationErrorExpired != 0) { - return TokenClaims{}, err - } - } else { - return TokenClaims{}, err - } - } - - // type assertion - claims, ok := token.Claims.(*TokenClaims) - if !ok { - return TokenClaims{}, errors.New("token is not valid") - } - - // extend token expiration date - return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) -} - -// RbacCheck returns true if user that made HTTP request is authorized to -// access the resource it is targeting. -// It exctracts user's role from the JWT token located in Authorization header of -// http.Request and then compares it with the list of supplied roles and returns -// true if there's a match, if "*" is provided or if the authRoles is nil. -// Otherwise it returns false. -func RbacCheck(req *http.Request, authRoles []string) bool { - if authRoles == nil { - return true - } - - // validate token and check expiration date - claims, err := GetTokenClaims(req) - if err != nil { - return false - } - // check if token has expired - if claims.ExpiresAt < (time.Now()).Unix() { - return false - } - - // check if role extracted from token matches - // any of the provided (allowed) ones - for _, r := range authRoles { - if claims.Role == r || r == "*" { - return true - } - } - - return false -} - -// ProcessRBAC returns token claims and boolean value based on user's rights to access resource specified in req. -// It exctracts user's role from the JWT token located in Authorization header of -// HTTP request and then compares it with the list of supplied (authorized); -// it returns true if there's a match, if "*" is provided or if the authRoles is nil. -func ProcessRBAC(req *http.Request, authRoles []string) (*TokenClaims, bool) { - if authRoles == nil { - return nil, true - } - - // validate token and check expiration date - claims, err := GetTokenClaims(req) - if err != nil { - return claims, false - } - // check if token has expired - if claims.ExpiresAt < (time.Now()).Unix() { - return claims, false - } - - // check if role extracted from token matches - // any of the provided (allowed) ones - for _, r := range authRoles { - if claims.Role == r || r == "*" { - return claims, true - } - } - - return claims, false -} - -// GetTokenClaims extracts JWT claims from Authorization header of the request. -// Returns token claims or an error. -func GetTokenClaims(req *http.Request) (*TokenClaims, error) { - // check for and strip 'Bearer' prefix - var tokstr string - authHead := req.Header.Get("Authorization") - if ok := strings.HasPrefix(authHead, "Bearer "); ok { - tokstr = strings.TrimPrefix(authHead, "Bearer ") - } else { - return &TokenClaims{}, errors.New("authorization header in incomplete") - } - - token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) - if err != nil { - return &TokenClaims{}, err - } - - // type assertion - claims, ok := token.Claims.(*TokenClaims) - if !ok || !token.Valid { - return &TokenClaims{}, errors.New("token is not valid") - } - - return claims, nil -} - -// randomSalt returns a string of random characters of 'saltSize' length. -func randomSalt() (s string, err error) { - rawsalt := make([]byte, saltSize) - - _, err = rand.Read(rawsalt) - if err != nil { - return "", err - } - - s = hex.EncodeToString(rawsalt) - return s, nil -} - -// secretFunc returns byte slice of API secret keyword. -func secretFunc(token *jwt.Token) (interface{}, error) { - return []byte(secret), nil -} diff --git a/format.go b/format.go new file mode 100644 index 0000000..b40dbf3 --- /dev/null +++ b/format.go @@ -0,0 +1,50 @@ +package webutility + +import ( + "fmt" + "time" +) + +// UnixToDate converts given Unix time to local time in format and returns result: +// YYYY-MM-DD hh:mm:ss +zzzz UTC +func UnixToDate(unix int64) time.Time { + return time.Unix(unix, 0) +} + +// DateToUnix converts given date in Unix timestamp. +func DateToUnix(date interface{}) int64 { + if date != nil { + t, ok := date.(time.Time) + if !ok { + return 0 + } + return t.Unix() + + } + return 0 +} + +// EqualQuotes encapsulates given string in SQL 'equal' statement and returns result. +// Example: "hello" -> " = 'hello'" +func EqualQuotes(stmt string) string { + if stmt != "" { + stmt = fmt.Sprintf(" = '%s'", stmt) + } + return stmt +} + +func EqualString(stmt string) string { + if stmt != "" { + stmt = fmt.Sprintf(" = %s", stmt) + } + return stmt +} + +// LikeQuotes encapsulates given string in SQL 'like' statement and returns result. +// Example: "hello" -> " LIKE UPPER('%hello%')" +func LikeQuotes(stmt string) string { + if stmt != "" { + stmt = fmt.Sprintf(" LIKE UPPER('%s%s%s')", "%", stmt, "%") + } + return stmt +} diff --git a/format_utility.go b/format_utility.go deleted file mode 100644 index b40dbf3..0000000 --- a/format_utility.go +++ /dev/null @@ -1,50 +0,0 @@ -package webutility - -import ( - "fmt" - "time" -) - -// UnixToDate converts given Unix time to local time in format and returns result: -// YYYY-MM-DD hh:mm:ss +zzzz UTC -func UnixToDate(unix int64) time.Time { - return time.Unix(unix, 0) -} - -// DateToUnix converts given date in Unix timestamp. -func DateToUnix(date interface{}) int64 { - if date != nil { - t, ok := date.(time.Time) - if !ok { - return 0 - } - return t.Unix() - - } - return 0 -} - -// EqualQuotes encapsulates given string in SQL 'equal' statement and returns result. -// Example: "hello" -> " = 'hello'" -func EqualQuotes(stmt string) string { - if stmt != "" { - stmt = fmt.Sprintf(" = '%s'", stmt) - } - return stmt -} - -func EqualString(stmt string) string { - if stmt != "" { - stmt = fmt.Sprintf(" = %s", stmt) - } - return stmt -} - -// LikeQuotes encapsulates given string in SQL 'like' statement and returns result. -// Example: "hello" -> " LIKE UPPER('%hello%')" -func LikeQuotes(stmt string) string { - if stmt != "" { - stmt = fmt.Sprintf(" LIKE UPPER('%s%s%s')", "%", stmt, "%") - } - return stmt -} diff --git a/http.go b/http.go new file mode 100644 index 0000000..07a79ab --- /dev/null +++ b/http.go @@ -0,0 +1,165 @@ +package webutility + +import ( + "encoding/json" + "net/http" +) + +type webError struct { + Request string `json:"request"` + Error string `json:"error"` +} + +// NotFoundHandler writes HTTP error 404 to w. +func NotFoundHandler(w http.ResponseWriter, req *http.Request) { + SetDefaultHeaders(w) + if req.Method == "OPTIONS" { + return + } + NotFound(w, req, "Not found") +} + +// SetDefaultHeaders set's default headers for an HTTP response. +func SetDefaultHeaders(w http.ResponseWriter) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", `Accept, Content-Type, + Content-Length, Accept-Encoding, X-CSRF-Token, Authorization`) + w.Header().Set("Content-Type", "application/json; charset=utf-8") +} + +func ReqLocale(req *http.Request, dflt string) string { + loc := req.FormValue("locale") + if loc == "" { + return dflt + } + return loc +} + +// 2xx +func Success(w http.ResponseWriter, payload *Payload, code int) { + w.WriteHeader(code) + if payload != nil { + json.NewEncoder(w).Encode(*payload) + } +} + +// 200 +func OK(w http.ResponseWriter, payload *Payload) { + Success(w, payload, http.StatusOK) +} + +// 201 +func Created(w http.ResponseWriter, payload *Payload) { + Success(w, payload, http.StatusCreated) +} + +// 4xx; 5xx +func Error(w http.ResponseWriter, r *http.Request, code int, err string) { + werr := webError{Error: err, Request: r.Method + " " + r.RequestURI} + w.WriteHeader(code) + json.NewEncoder(w).Encode(werr) +} + +// 400 +func BadRequest(w http.ResponseWriter, r *http.Request, err string) { + Error(w, r, http.StatusBadRequest, err) +} + +// 404 +func NotFound(w http.ResponseWriter, r *http.Request, err string) { + Error(w, r, http.StatusNotFound, err) +} + +// 401 +func Unauthorized(w http.ResponseWriter, r *http.Request, err string) { + Error(w, r, http.StatusUnauthorized, err) +} + +// 403 +func Forbidden(w http.ResponseWriter, r *http.Request, err string) { + Error(w, r, http.StatusForbidden, err) +} + +// 403 +func Conflict(w http.ResponseWriter, r *http.Request, err string) { + Error(w, r, http.StatusConflict, err) +} + +// 500 +func InternalServerError(w http.ResponseWriter, r *http.Request, err string) { + Error(w, r, http.StatusInternalServerError, err) +} + +/// +/// Old API +/// + +const ( + templateHttpErr500_EN = "An internal server error has occurred." + templateHttpErr500_RS = "Došlo je do greške na serveru." + templateHttpErr400_EN = "Bad request." + templateHttpErr400_RS = "Neispravan zahtev." + templateHttpErr404_EN = "Resource not found." + templateHttpErr404_RS = "Resurs nije pronadjen." + templateHttpErr401_EN = "Unauthorized request." + templateHttpErr401_RS = "Neautorizovan zahtev." +) + +type httpError struct { + Error []HttpErrorDesc `json:"error"` + Request string `json:"request"` +} + +type HttpErrorDesc struct { + Lang string `json:"lang"` + Desc string `json:"description"` +} + +// DeliverPayload encodes payload as JSON to w. +func DeliverPayload(w http.ResponseWriter, payload Payload) { + // Don't write status OK in the headers here. Leave it up for the caller. + // E.g. Status 201. + json.NewEncoder(w).Encode(payload) + payload.Data = nil +} + +// ErrorResponse writes HTTP error to w. +func ErrorResponse(w http.ResponseWriter, r *http.Request, code int, desc []HttpErrorDesc) { + err := httpError{desc, r.Method + " " + r.RequestURI} + w.WriteHeader(code) + json.NewEncoder(w).Encode(err) +} + +// NotFoundResponse writes HTTP error 404 to w. +func NotFoundResponse(w http.ResponseWriter, req *http.Request) { + ErrorResponse(w, req, http.StatusNotFound, []HttpErrorDesc{ + {"en", templateHttpErr404_EN}, + {"rs", templateHttpErr404_RS}, + }) +} + +// BadRequestResponse writes HTTP error 400 to w. +func BadRequestResponse(w http.ResponseWriter, req *http.Request) { + ErrorResponse(w, req, http.StatusBadRequest, []HttpErrorDesc{ + {"en", templateHttpErr400_EN}, + {"rs", templateHttpErr400_RS}, + }) +} + +// InternalSeverErrorResponse writes HTTP error 500 to w. +func InternalServerErrorResponse(w http.ResponseWriter, req *http.Request) { + ErrorResponse(w, req, http.StatusInternalServerError, []HttpErrorDesc{ + {"en", templateHttpErr500_EN}, + {"rs", templateHttpErr500_RS}, + }) +} + +// UnauthorizedError writes HTTP error 401 to w. +func UnauthorizedResponse(w http.ResponseWriter, req *http.Request) { + w.Header().Set("WWW-Authenticate", "Bearer") + ErrorResponse(w, req, http.StatusUnauthorized, []HttpErrorDesc{ + {"en", templateHttpErr401_EN}, + {"rs", templateHttpErr401_RS}, + }) +} diff --git a/http_utility.go b/http_utility.go deleted file mode 100644 index 07a79ab..0000000 --- a/http_utility.go +++ /dev/null @@ -1,165 +0,0 @@ -package webutility - -import ( - "encoding/json" - "net/http" -) - -type webError struct { - Request string `json:"request"` - Error string `json:"error"` -} - -// NotFoundHandler writes HTTP error 404 to w. -func NotFoundHandler(w http.ResponseWriter, req *http.Request) { - SetDefaultHeaders(w) - if req.Method == "OPTIONS" { - return - } - NotFound(w, req, "Not found") -} - -// SetDefaultHeaders set's default headers for an HTTP response. -func SetDefaultHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", `Accept, Content-Type, - Content-Length, Accept-Encoding, X-CSRF-Token, Authorization`) - w.Header().Set("Content-Type", "application/json; charset=utf-8") -} - -func ReqLocale(req *http.Request, dflt string) string { - loc := req.FormValue("locale") - if loc == "" { - return dflt - } - return loc -} - -// 2xx -func Success(w http.ResponseWriter, payload *Payload, code int) { - w.WriteHeader(code) - if payload != nil { - json.NewEncoder(w).Encode(*payload) - } -} - -// 200 -func OK(w http.ResponseWriter, payload *Payload) { - Success(w, payload, http.StatusOK) -} - -// 201 -func Created(w http.ResponseWriter, payload *Payload) { - Success(w, payload, http.StatusCreated) -} - -// 4xx; 5xx -func Error(w http.ResponseWriter, r *http.Request, code int, err string) { - werr := webError{Error: err, Request: r.Method + " " + r.RequestURI} - w.WriteHeader(code) - json.NewEncoder(w).Encode(werr) -} - -// 400 -func BadRequest(w http.ResponseWriter, r *http.Request, err string) { - Error(w, r, http.StatusBadRequest, err) -} - -// 404 -func NotFound(w http.ResponseWriter, r *http.Request, err string) { - Error(w, r, http.StatusNotFound, err) -} - -// 401 -func Unauthorized(w http.ResponseWriter, r *http.Request, err string) { - Error(w, r, http.StatusUnauthorized, err) -} - -// 403 -func Forbidden(w http.ResponseWriter, r *http.Request, err string) { - Error(w, r, http.StatusForbidden, err) -} - -// 403 -func Conflict(w http.ResponseWriter, r *http.Request, err string) { - Error(w, r, http.StatusConflict, err) -} - -// 500 -func InternalServerError(w http.ResponseWriter, r *http.Request, err string) { - Error(w, r, http.StatusInternalServerError, err) -} - -/// -/// Old API -/// - -const ( - templateHttpErr500_EN = "An internal server error has occurred." - templateHttpErr500_RS = "Došlo je do greške na serveru." - templateHttpErr400_EN = "Bad request." - templateHttpErr400_RS = "Neispravan zahtev." - templateHttpErr404_EN = "Resource not found." - templateHttpErr404_RS = "Resurs nije pronadjen." - templateHttpErr401_EN = "Unauthorized request." - templateHttpErr401_RS = "Neautorizovan zahtev." -) - -type httpError struct { - Error []HttpErrorDesc `json:"error"` - Request string `json:"request"` -} - -type HttpErrorDesc struct { - Lang string `json:"lang"` - Desc string `json:"description"` -} - -// DeliverPayload encodes payload as JSON to w. -func DeliverPayload(w http.ResponseWriter, payload Payload) { - // Don't write status OK in the headers here. Leave it up for the caller. - // E.g. Status 201. - json.NewEncoder(w).Encode(payload) - payload.Data = nil -} - -// ErrorResponse writes HTTP error to w. -func ErrorResponse(w http.ResponseWriter, r *http.Request, code int, desc []HttpErrorDesc) { - err := httpError{desc, r.Method + " " + r.RequestURI} - w.WriteHeader(code) - json.NewEncoder(w).Encode(err) -} - -// NotFoundResponse writes HTTP error 404 to w. -func NotFoundResponse(w http.ResponseWriter, req *http.Request) { - ErrorResponse(w, req, http.StatusNotFound, []HttpErrorDesc{ - {"en", templateHttpErr404_EN}, - {"rs", templateHttpErr404_RS}, - }) -} - -// BadRequestResponse writes HTTP error 400 to w. -func BadRequestResponse(w http.ResponseWriter, req *http.Request) { - ErrorResponse(w, req, http.StatusBadRequest, []HttpErrorDesc{ - {"en", templateHttpErr400_EN}, - {"rs", templateHttpErr400_RS}, - }) -} - -// InternalSeverErrorResponse writes HTTP error 500 to w. -func InternalServerErrorResponse(w http.ResponseWriter, req *http.Request) { - ErrorResponse(w, req, http.StatusInternalServerError, []HttpErrorDesc{ - {"en", templateHttpErr500_EN}, - {"rs", templateHttpErr500_RS}, - }) -} - -// UnauthorizedError writes HTTP error 401 to w. -func UnauthorizedResponse(w http.ResponseWriter, req *http.Request) { - w.Header().Set("WWW-Authenticate", "Bearer") - ErrorResponse(w, req, http.StatusUnauthorized, []HttpErrorDesc{ - {"en", templateHttpErr401_EN}, - {"rs", templateHttpErr401_RS}, - }) -} diff --git a/json.go b/json.go new file mode 100644 index 0000000..f9b88f0 --- /dev/null +++ b/json.go @@ -0,0 +1,339 @@ +package webutility + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" + "time" + + "git.to-net.rs/marko.tikvic/gologger" +) + +var ( + mu = &sync.Mutex{} + metadata = make(map[string]Payload) + updateQue = make(map[string][]byte) + + metadataDB *sql.DB + activeProject string + + inited bool + driver string +) + +var logger *gologger.Logger + +func init() { + var err error + logger, err = gologger.New("webutility", gologger.MaxLogSize100KB) + if err != nil { + fmt.Printf("webutility: %s\n", err.Error()) + } +} + +type LangMap map[string]map[string]string + +type Field struct { + Parameter string `json:"param"` + Type string `json:"type"` + Visible bool `json:"visible"` + Editable bool `json:"editable"` +} + +type CorrelationField struct { + Result string `json:"result"` + Elements []string `json:"elements"` + Type string `json:"type"` +} + +type Translation struct { + Language string `json:"language"` + FieldsLabels map[string]string `json:"fieldsLabels"` +} + +type Payload struct { + Method string `json:"method"` + Params map[string]string `json:"params"` + Lang []Translation `json:"lang"` + Fields []Field `json:"fields"` + Correlations []CorrelationField `json:"correlationFields"` + IdField string `json:"idField"` + + // Data holds JSON payload. + // It can't be used for itteration. + Data interface{} `json:"data"` +} + +// NewPayload returs a payload sceleton for entity described with etype. +func NewPayload(r *http.Request, etype string) Payload { + pload := metadata[etype] + pload.Method = r.Method + " " + r.RequestURI + return pload +} + +// DecodeJSON decodes JSON data from r to v. +// Returns an error if it fails. +func DecodeJSON(r io.Reader, v interface{}) error { + return json.NewDecoder(r).Decode(v) +} + +// InitPayloadsMetadata loads all payloads' information into 'metadata' variable. +func InitPayloadsMetadata(drv string, db *sql.DB, project string) error { + if drv != "ora" && drv != "mysql" { + return errors.New("driver not supported") + } + driver = drv + metadataDB = db + activeProject = project + + mu.Lock() + defer mu.Unlock() + err := initMetadata(project) + if err != nil { + return err + } + inited = true + + return nil +} + +func EnableHotloading(interval int) { + if interval > 0 { + go hotload(interval) + } +} + +func GetMetadataForAllEntities() map[string]Payload { + return metadata +} + +func GetMetadataForEntity(t string) (Payload, bool) { + p, ok := metadata[t] + return p, ok +} + +func QueEntityModelUpdate(entityType string, v interface{}) { + updateQue[entityType], _ = json.Marshal(v) +} + +func UpdateEntityModels(command string) (total, upd, add int, err error) { + if command != "force" && command != "missing" { + return total, 0, 0, errors.New("webutility: unknown command: " + command) + } + + if !inited { + return 0, 0, 0, errors.New("webutility: metadata not initialized but update was tried.") + } + + total = len(updateQue) + + toUpdate := make([]string, 0) + toAdd := make([]string, 0) + + for k, _ := range updateQue { + if _, exists := metadata[k]; exists { + if command == "force" { + toUpdate = append(toUpdate, k) + } + } else { + toAdd = append(toAdd, k) + } + } + + var uStmt *sql.Stmt + if driver == "ora" { + uStmt, err = metadataDB.Prepare("update entities set entity_model = :1 where entity_type = :2") + if err != nil { + return + } + } else if driver == "mysql" { + uStmt, err = metadataDB.Prepare("update entities set entity_model = ? where entity_type = ?") + if err != nil { + return + } + } + for _, k := range toUpdate { + //fmt.Printf("Updating: %s\n", k) + //fmt.Printf("New model: %s\n", updateQue[k]) + _, err = uStmt.Exec(string(updateQue[k]), k) + if err != nil { + logger.Log("webutility: %v\n", err) + return + } + upd++ + } + + blankPayload, _ := json.Marshal(Payload{}) + var iStmt *sql.Stmt + if driver == "ora" { + iStmt, err = metadataDB.Prepare("insert into entities(projekat, metadata, entity_type, entity_model) values(:1, :2, :3, :4)") + if err != nil { + return + } + } else if driver == "mysql" { + iStmt, err = metadataDB.Prepare("insert into entities(projekat, metadata, entity_type, entity_model) values(?, ?, ?, ?)") + if err != nil { + return + } + } + for _, k := range toAdd { + _, err = iStmt.Exec(activeProject, string(blankPayload), k, string(updateQue[k])) + if err != nil { + logger.Log("webutility: %v\n", err) + return + } + metadata[k] = Payload{} + add++ + } + + return total, upd, add, nil +} + +func initMetadata(project string) error { + rows, err := metadataDB.Query(`select + entity_type, + metadata + from entities + where projekat = ` + fmt.Sprintf("'%s'", project)) + if err != nil { + return err + } + defer rows.Close() + + count := 0 + success := 0 + if len(metadata) > 0 { + metadata = nil + } + metadata = make(map[string]Payload) + for rows.Next() { + var name, load string + rows.Scan(&name, &load) + + p := Payload{} + err := json.Unmarshal([]byte(load), &p) + if err != nil { + logger.Log("webutility: couldn't init: '%s' metadata: %s:\n%s\n", name, err.Error(), load) + } else { + success++ + metadata[name] = p + } + count++ + } + perc := float32(success) / float32(count) * 100.0 + logger.Log("webutility: loaded %d/%d (%.1f%%) entities\n", success, count, perc) + + return nil +} + +func hotload(n int) { + entityScan := make(map[string]int64) + firstCheck := true + for { + time.Sleep(time.Duration(n) * time.Second) + rows, err := metadataDB.Query(`select + ora_rowscn, + entity_type + from entities where projekat = ` + fmt.Sprintf("'%s'", activeProject)) + if err != nil { + logger.Log("webutility: hotload failed: %v\n", err) + time.Sleep(time.Duration(n) * time.Second) + continue + } + + var toRefresh []string + for rows.Next() { + var scanID int64 + var entity string + rows.Scan(&scanID, &entity) + oldID, ok := entityScan[entity] + if !ok || oldID != scanID { + entityScan[entity] = scanID + toRefresh = append(toRefresh, entity) + } + } + rows.Close() + + if rows.Err() != nil { + logger.Log("webutility: hotload rset error: %v\n", rows.Err()) + time.Sleep(time.Duration(n) * time.Second) + continue + } + + if len(toRefresh) > 0 && !firstCheck { + mu.Lock() + refreshMetadata(toRefresh) + mu.Unlock() + } + if firstCheck { + firstCheck = false + } + } +} + +func refreshMetadata(entities []string) { + for _, e := range entities { + fmt.Printf("refreshing %s\n", e) + rows, err := metadataDB.Query(`select + metadata + from entities + where projekat = ` + fmt.Sprintf("'%s'", activeProject) + + ` and entity_type = ` + fmt.Sprintf("'%s'", e)) + + if err != nil { + logger.Log("webutility: refresh: prep: %v\n", err) + rows.Close() + continue + } + + for rows.Next() { + var load string + rows.Scan(&load) + p := Payload{} + err := json.Unmarshal([]byte(load), &p) + if err != nil { + logger.Log("webutility: couldn't refresh: '%s' metadata: %s\n%s\n", e, err.Error(), load) + } else { + metadata[e] = p + } + } + rows.Close() + } +} + +/* +func ModifyMetadataForEntity(entityType string, p *Payload) error { + md, err := json.Marshal(*p) + if err != nil { + return err + } + + mu.Lock() + defer mu.Unlock() + _, err = metadataDB.PrepAndExe(`update entities set + metadata = :1 + where projekat = :2 + and entity_type = :3`, + string(md), + activeProject, + entityType) + if err != nil { + return err + } + return nil +} + +func DeleteEntityModel(entityType string) error { + _, err := metadataDB.PrepAndExe("delete from entities where entity_type = :1", entityType) + if err == nil { + mu.Lock() + delete(metadata, entityType) + mu.Unlock() + } + return err +} +*/ diff --git a/json_utility.go b/json_utility.go deleted file mode 100644 index 7af4abf..0000000 --- a/json_utility.go +++ /dev/null @@ -1,339 +0,0 @@ -package webutility - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "sync" - "time" - - "git.to-net.rs/marko.tikvic/gologger" -) - -var ( - mu = &sync.Mutex{} - metadata = make(map[string]Payload) - updateQue = make(map[string][]byte) - - metadataDB *sql.DB - activeProject string - - inited bool - driver string -) - -var logger *gologger.Logger - -func init() { - var err error - logger, err = gologger.New("webutility", gologger.MaxLogSize100KB) - if err != nil { - fmt.Printf("webutility: %s\n", err.Error()) - } -} - -type LangMap map[string]map[string]string - -type Field struct { - Parameter string `json:"param"` - Type string `json:"type"` - Visible bool `json:"visible"` - Editable bool `json:"editable"` -} - -type CorrelationField struct { - Result string `json:"result"` - Elements []string `json:"elements"` - Type string `json:"type"` -} - -type Translation struct { - Language string `json:"language"` - FieldsLabels map[string]string `json:"fieldsLabels"` -} - -type Payload struct { - Method string `json:"method"` - Params map[string]string `json:"params"` - Lang []Translation `json:"lang"` - Fields []Field `json:"fields"` - Correlations []CorrelationField `json:"correlationFields"` - IdField string `json:"idField"` - - // Data holds JSON payload. - // It can't be used for itteration. - Data interface{} `json:"data"` -} - -// NewPayload returs a payload sceleton for entity described with etype. -func NewPayload(r *http.Request, etype string) Payload { - pload := metadata[etype] - pload.Method = r.Method + " " + r.RequestURI - return pload -} - -// DecodeJSON decodes JSON data from r to v. -// Returns an error if it fails. -func DecodeJSON(r io.Reader, v interface{}) error { - return json.NewDecoder(r).Decode(v) -} - -// InitPayloadsMetadata loads all payloads' information into 'metadata' variable. -func InitPayloadsMetadata(drv string, db *sql.DB, project string) error { - if drv != "ora" && drv != "mysql" { - return errors.New("driver not supported") - } - driver = drv - metadataDB = db - activeProject = project - - mu.Lock() - defer mu.Unlock() - err := initMetadata(project) - if err != nil { - return err - } - inited = true - - return nil -} - -func EnableHotloading(interval int) { - if interval > 0 { - go hotload(interval) - } -} - -func GetMetadataForAllEntities() map[string]Payload { - return metadata -} - -func GetMetadataForEntity(t string) (Payload, bool) { - p, ok := metadata[t] - return p, ok -} - -func QueEntityModelUpdate(entityType string, v interface{}) { - updateQue[entityType], _ = json.Marshal(v) -} - -func initMetadata(project string) error { - rows, err := metadataDB.Query(`select - entity_type, - metadata - from entities - where projekat = ` + fmt.Sprintf("'%s'", project)) - if err != nil { - return err - } - defer rows.Close() - - count := 0 - success := 0 - if len(metadata) > 0 { - metadata = nil - } - metadata = make(map[string]Payload) - for rows.Next() { - var name, load string - rows.Scan(&name, &load) - - p := Payload{} - err := json.Unmarshal([]byte(load), &p) - if err != nil { - logger.Log("webutility: couldn't init: '%s' metadata: %s\n%s\n", name, err.Error(), load) - } else { - success++ - metadata[name] = p - } - count++ - } - perc := float32(success/count) * 100.0 - logger.Log("webutility: loaded %d/%d (%.1f%%) entities\n", success, count, perc) - - return nil -} - -func hotload(n int) { - entityScan := make(map[string]int64) - firstCheck := true - for { - time.Sleep(time.Duration(n) * time.Second) - rows, err := metadataDB.Query(`select - ora_rowscn, - entity_type - from entities where projekat = ` + fmt.Sprintf("'%s'", activeProject)) - if err != nil { - logger.Log("webutility: hotload failed: %v\n", err) - time.Sleep(time.Duration(n) * time.Second) - continue - } - - var toRefresh []string - for rows.Next() { - var scanID int64 - var entity string - rows.Scan(&scanID, &entity) - oldID, ok := entityScan[entity] - if !ok || oldID != scanID { - entityScan[entity] = scanID - toRefresh = append(toRefresh, entity) - } - } - rows.Close() - - if rows.Err() != nil { - logger.Log("webutility: hotload rset error: %v\n", rows.Err()) - time.Sleep(time.Duration(n) * time.Second) - continue - } - - if len(toRefresh) > 0 && !firstCheck { - mu.Lock() - refreshMetadata(toRefresh) - mu.Unlock() - } - if firstCheck { - firstCheck = false - } - } -} - -func refreshMetadata(entities []string) { - for _, e := range entities { - fmt.Printf("refreshing %s\n", e) - rows, err := metadataDB.Query(`select - metadata - from entities - where projekat = ` + fmt.Sprintf("'%s'", activeProject) + - ` and entity_type = ` + fmt.Sprintf("'%s'", e)) - - if err != nil { - logger.Log("webutility: refresh: prep: %v\n", err) - rows.Close() - continue - } - - for rows.Next() { - var load string - rows.Scan(&load) - p := Payload{} - err := json.Unmarshal([]byte(load), &p) - if err != nil { - logger.Log("webutility: couldn't refresh: '%s' metadata: %s\n%s\n", e, err.Error(), load) - } else { - metadata[e] = p - } - } - rows.Close() - } -} - -func UpdateEntityModels(command string) (total, upd, add int, err error) { - if command != "force" && command != "missing" { - return total, 0, 0, errors.New("webutility: unknown command: " + command) - } - - if !inited { - return 0, 0, 0, errors.New("webutility: metadata not initialized but update was tried.") - } - - total = len(updateQue) - - toUpdate := make([]string, 0) - toAdd := make([]string, 0) - - for k, _ := range updateQue { - if _, exists := metadata[k]; exists { - if command == "force" { - toUpdate = append(toUpdate, k) - } - } else { - toAdd = append(toAdd, k) - } - } - - var uStmt *sql.Stmt - if driver == "ora" { - uStmt, err = metadataDB.Prepare("update entities set entity_model = :1 where entity_type = :2") - if err != nil { - return - } - } else if driver == "mysql" { - uStmt, err = metadataDB.Prepare("update entities set entity_model = ? where entity_type = ?") - if err != nil { - return - } - } - for _, k := range toUpdate { - //fmt.Printf("Updating: %s\n", k) - //fmt.Printf("New model: %s\n", updateQue[k]) - _, err = uStmt.Exec(string(updateQue[k]), k) - if err != nil { - logger.Log("webutility: %v\n", err) - return - } - upd++ - } - - blankPayload, _ := json.Marshal(Payload{}) - var iStmt *sql.Stmt - if driver == "ora" { - iStmt, err = metadataDB.Prepare("INSERT INTO ENTITIES(PROJEKAT, METADATA, ENTITY_TYPE, ENTITY_MODEL) VALUES(:1, :2, :3, :4)") - if err != nil { - return - } - } else if driver == "mysql" { - iStmt, err = metadataDB.Prepare("INSERT INTO ENTITIES(PROJEKAT, METADATA, ENTITY_TYPE, ENTITY_MODEL) VALUES(?, ?, ?, ?)") - if err != nil { - return - } - } - for _, k := range toAdd { - _, err = iStmt.Exec(activeProject, string(blankPayload), k, string(updateQue[k])) - if err != nil { - logger.Log("webutility: %v\n", err) - return - } - metadata[k] = Payload{} - add++ - } - - return total, upd, add, nil -} - -/* -func ModifyMetadataForEntity(entityType string, p *Payload) error { - md, err := json.Marshal(*p) - if err != nil { - return err - } - - mu.Lock() - defer mu.Unlock() - _, err = metadataDB.PrepAndExe(`update entities set - metadata = :1 - where projekat = :2 - and entity_type = :3`, - string(md), - activeProject, - entityType) - if err != nil { - return err - } - return nil -} - -func DeleteEntityModel(entityType string) error { - _, err := metadataDB.PrepAndExe("delete from entities where entity_type = :1", entityType) - if err == nil { - mu.Lock() - delete(metadata, entityType) - mu.Unlock() - } - return err -} -*/ diff --git a/locale_utility.go b/locale_utility.go deleted file mode 100644 index 298f4b9..0000000 --- a/locale_utility.go +++ /dev/null @@ -1,66 +0,0 @@ -package webutility - -import ( - "encoding/json" - "errors" - "io/ioutil" -) - -type Dictionary struct { - locales map[string]map[string]string - supported []string - defaultLocale string -} - -func NewDictionary() Dictionary { - return Dictionary{ - locales: map[string]map[string]string{}, - } -} - -func (d *Dictionary) AddLocale(loc, filePath string) error { - file, err := ioutil.ReadFile(filePath) - if err != nil { - return err - } - - var data interface{} - err = json.Unmarshal(file, &data) - if err != nil { - return err - } - - l := map[string]string{} - for k, v := range data.(map[string]interface{}) { - l[k] = v.(string) - } - d.locales[loc] = l - d.supported = append(d.supported, loc) - - return nil -} - -func (d *Dictionary) Translate(loc, key string) string { - return d.locales[loc][key] -} - -func (d *Dictionary) HasLocale(loc string) bool { - for _, v := range d.supported { - if v == loc { - return true - } - } - return false -} - -func (d *Dictionary) SetDefaultLocale(loc string) error { - if !d.HasLocale(loc) { - return errors.New("dictionary does not contain translations for " + loc) - } - d.defaultLocale = loc - return nil -} - -func (d *Dictionary) GetDefaultLocale() string { - return d.defaultLocale -} diff --git a/localization.go b/localization.go new file mode 100644 index 0000000..298f4b9 --- /dev/null +++ b/localization.go @@ -0,0 +1,66 @@ +package webutility + +import ( + "encoding/json" + "errors" + "io/ioutil" +) + +type Dictionary struct { + locales map[string]map[string]string + supported []string + defaultLocale string +} + +func NewDictionary() Dictionary { + return Dictionary{ + locales: map[string]map[string]string{}, + } +} + +func (d *Dictionary) AddLocale(loc, filePath string) error { + file, err := ioutil.ReadFile(filePath) + if err != nil { + return err + } + + var data interface{} + err = json.Unmarshal(file, &data) + if err != nil { + return err + } + + l := map[string]string{} + for k, v := range data.(map[string]interface{}) { + l[k] = v.(string) + } + d.locales[loc] = l + d.supported = append(d.supported, loc) + + return nil +} + +func (d *Dictionary) Translate(loc, key string) string { + return d.locales[loc][key] +} + +func (d *Dictionary) HasLocale(loc string) bool { + for _, v := range d.supported { + if v == loc { + return true + } + } + return false +} + +func (d *Dictionary) SetDefaultLocale(loc string) error { + if !d.HasLocale(loc) { + return errors.New("dictionary does not contain translations for " + loc) + } + d.defaultLocale = loc + return nil +} + +func (d *Dictionary) GetDefaultLocale() string { + return d.defaultLocale +}