diff --git a/cmd/http-server/main.go b/cmd/http-server/main.go index 5119184..09e2322 100644 --- a/cmd/http-server/main.go +++ b/cmd/http-server/main.go @@ -147,6 +147,6 @@ func main() { c.Redirect(301, "/swagger/index.html") }) - go tokens.StartTokens() + tokens.Init() r.Run("127.0.0.1:3000") } diff --git a/middleware/auth.go b/middleware/auth.go index 416f68d..daccc6b 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,10 +2,10 @@ package middleware import ( "errors" - "log" "net/http" "git.qowevisa.me/Qowevisa/fin-check-api/consts" + "git.qowevisa.me/Qowevisa/fin-check-api/db" "git.qowevisa.me/Qowevisa/fin-check-api/tokens" "git.qowevisa.me/Qowevisa/fin-check-api/types" "github.com/gin-gonic/gin" @@ -20,17 +20,13 @@ func AuthMiddleware() gin.HandlerFunc { c.Abort() return } - if !tokens.ValidateSessionToken(token) { + var session *db.Session + if validated, tmpSession := tokens.ValidateAndGetSessionToken(token); !validated { c.JSON(401, types.ErrorResponse{Message: "Invalid authorization cookie"}) c.Abort() return - } - session, err := tokens.GetSession(token) - if err != nil { - log.Printf("ERROR: tokens.GetSession: %v\n", err) - c.JSON(500, types.ErrorResponse{Message: "Server error"}) - c.Abort() - return + } else { + session = tmpSession } c.Set("UserID", session.UserID) diff --git a/tokens/dispatcher.go b/tokens/dispatcher.go deleted file mode 100644 index 9443735..0000000 --- a/tokens/dispatcher.go +++ /dev/null @@ -1,165 +0,0 @@ -package tokens - -import ( - "crypto/rand" - "encoding/base64" - "errors" - "log" - "strings" - "sync" - "time" -) - -type Token struct { - Id uint - Val string - LastActive time.Time -} - -var ( - ActiveDur = time.Duration(time.Hour) -) - -func (t Token) IsExpired() bool { - return time.Now().Sub(t.LastActive) >= ActiveDur -} - -type TokensMapMu struct { - Initialized bool - Tokmap map[uint]*Token - TokmapRev map[string]*Token - Mu sync.RWMutex -} - -var toks TokensMapMu - -// NOTE: should be launch with a goroutine -// NOTE: it cannot die -func StartTokens() { - if toks.Initialized { - return - } - toks.Tokmap = make(map[uint]*Token) - toks.TokmapRev = make(map[string]*Token) - toks.Initialized = true - for { - // - toks.Mu.Lock() - for id, token := range toks.Tokmap { - if token == nil { - log.Printf("DAFUQ: 001\n") - delete(toks.Tokmap, id) - continue - } - if token.IsExpired() { - val := token.Val - delete(toks.Tokmap, id) - delete(toks.TokmapRev, val) - } - } - toks.Mu.Unlock() - // - time.Sleep(time.Minute) - } -} - -var ( - ERROR_DONT_HAVE_TOKEN = errors.New("Don't have token for this user") - ERROR_ALREADY_HAVE_TOKEN = errors.New("Already have token") -) - -func GetToken(id uint) (*Token, error) { - toks.Mu.Lock() - defer toks.Mu.Unlock() - val, exists := toks.Tokmap[id] - if !exists { - return nil, ERROR_DONT_HAVE_TOKEN - } - val.LastActive = time.Now() - return val, nil -} - -func GetID(token string) (uint, error) { - toks.Mu.RLock() - val, exists := toks.TokmapRev[token] - toks.Mu.RUnlock() - if !exists { - return 0, ERROR_DONT_HAVE_TOKEN - } - return val.Id, nil -} - -func haveToken(id uint) bool { - toks.Mu.RLock() - _, exists := toks.Tokmap[id] - toks.Mu.RUnlock() - return exists -} - -func UpdateLastActive(id uint) error { - if !haveToken(id) { - return ERROR_DONT_HAVE_TOKEN - } - toks.Mu.Lock() - val := toks.Tokmap[id] - val.LastActive = time.Now() - toks.Tokmap[id] = val - toks.Mu.Unlock() - return nil -} - -func haveTokenVal(val string) bool { - toks.Mu.RLock() - _, exists := toks.TokmapRev[val] - toks.Mu.RUnlock() - return exists -} - -func generateRandomString(length int) string { - bytes := make([]byte, length) - if _, err := rand.Read(bytes); err != nil { - log.Printf("generateRandomString: %v", err) - } - return base64.URLEncoding.EncodeToString(bytes) -} - -func generateTokenVal() string { - for { - tok := generateRandomString(32) - trimedToken := strings.Trim(tok, "=") - if !haveTokenVal(trimedToken) { - return trimedToken - } - } -} - -func AddToken(id uint) (*Token, error) { - toks.Mu.RLock() - _, exists := toks.Tokmap[id] - toks.Mu.RUnlock() - if exists { - // return nil, ERROR_ALREADY_HAVE_TOKEN - } - val := generateTokenVal() - token := &Token{ - Id: id, - Val: val, - LastActive: time.Now(), - } - toks.Mu.Lock() - toks.Tokmap[id] = token - toks.TokmapRev[val] = token - toks.Mu.Unlock() - return token, nil -} - -func AmIAllowed(token string) bool { - toks.Mu.Lock() - defer toks.Mu.Unlock() - val, exists := toks.TokmapRev[token] - if !exists { - return false - } - val.LastActive = time.Now() - return true -} diff --git a/tokens/session_cache.go b/tokens/session_cache.go new file mode 100644 index 0000000..e066cb7 --- /dev/null +++ b/tokens/session_cache.go @@ -0,0 +1,95 @@ +package tokens + +import ( + "crypto/rand" + "encoding/base64" + "log" + "strings" + "sync" + "time" + + "git.qowevisa.me/Qowevisa/fin-check-api/db" +) + +type Token struct { + Id uint + Val string + LastActive time.Time +} + +type SessiomMapMu struct { + Initialized bool + SessionMap map[string]*db.Session + Mu sync.RWMutex +} + +var sessionCache SessiomMapMu + +// NOTE: should be launch with a goroutine +// NOTE: it cannot die +func Init() error { + sessionCache.SessionMap = make(map[string]*db.Session) + var dbSessions []*db.Session + if err := db.Connect().Find(&dbSessions).Error; err != nil { + return err + } + log.Printf("you what len(dbSessions) = %d", len(dbSessions)) + for _, dbSession := range dbSessions { + sessionCache.SessionMap[dbSession.ID] = dbSession + } + sessionCache.Initialized = true + return nil +} + +func (s *SessiomMapMu) HaveSession(sessionID string) bool { + s.Mu.RLock() + val, exists := s.SessionMap[sessionID] + log.Printf("HaveSession s.SessionMap[sessionID] = %v | %b\n", val, exists) + s.Mu.RUnlock() + return exists +} + +func (s *SessiomMapMu) AddSession(session *db.Session) { + s.Mu.Lock() + s.SessionMap[session.ID] = session + log.Printf("AddSession s.SessionMap[sessionID] = %v \n", session) + s.Mu.Unlock() +} + +func (s *SessiomMapMu) GetSession(sessionID string) *db.Session { + s.Mu.RLock() + val, exists := s.SessionMap[sessionID] + log.Printf("GetSession s.SessionMap[sessionID] = %v | %b\n", val, exists) + s.Mu.RUnlock() + if !exists { + return nil + } + return val +} + +func generateRandomString(length int) string { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + log.Printf("generateRandomString: %v", err) + } + return base64.URLEncoding.EncodeToString(bytes) +} + +func generateTokenVal() string { + for { + tok := generateRandomString(32) + trimedToken := strings.Trim(tok, "=") + // TODO: do some thing so it can check if user will have the same token + return trimedToken + } +} + +func AddToken(id uint) (*Token, error) { + val := generateTokenVal() + token := &Token{ + Id: id, + Val: val, + LastActive: time.Now(), + } + return token, nil +} diff --git a/tokens/sessions.go b/tokens/sessions.go index 96a867a..7219503 100644 --- a/tokens/sessions.go +++ b/tokens/sessions.go @@ -3,13 +3,14 @@ package tokens import ( "crypto/sha256" "encoding/base64" + "errors" "log" "time" "git.qowevisa.me/Qowevisa/fin-check-api/db" ) -const SESSION_DURATION = 24 * time.Hour +const SESSION_DURATION = (24 * time.Hour) func CreateSessionFromToken(token string, userID uint) error { sessionID := getSessionIDFromToken(token) @@ -19,36 +20,37 @@ func CreateSessionFromToken(token string, userID uint) error { UserID: userID, ExpireAt: time.Now().Add(SESSION_DURATION), } + sessionCache.AddSession(session) if err := dbc.Create(session).Error; err != nil { return err } return nil } -func ValidateSessionToken(token string) bool { +func ValidateAndGetSessionToken(token string) (bool, *db.Session) { sessionID := getSessionIDFromToken(token) dbc := db.Connect() - session := &db.Session{} - if err := dbc.Find(session, db.Session{ID: sessionID}).Error; err != nil { - log.Printf("DBERROR: %v\n", err) - return false - } - if session.ID == "" { - return false + session := sessionCache.GetSession(sessionID) + if session == nil || session.ID == "" { + log.Printf("Internal error TOKENS.SESSIONS.ValidateSessionToken.1\n") + return false, nil } if session.ExpireAt.Unix() < time.Now().Unix() { dbc.Unscoped().Delete(session) - return false + return false, nil } - return session.ID != "" + return session.ID != "", session } +var ( + ERROR_SESSION_NOT_FOUND = errors.New("Can't find session with this token") +) + func GetSession(token string) (*db.Session, error) { sessionID := getSessionIDFromToken(token) - dbc := db.Connect() - session := &db.Session{} - if err := dbc.Find(session, db.Session{ID: sessionID}).Error; err != nil { - return nil, err + session := sessionCache.GetSession(sessionID) + if session == nil { + return nil, ERROR_SESSION_NOT_FOUND } return session, nil }