package middleware

import (
	"context"
	"net/http"
	"strings"
	"time"

	"music-server/internal/db/repository"
	"music-server/internal/logging"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/labstack/echo/v5"
	"go.uber.org/zap"
)

// SessionContextKey is the echo.Context key under which TokenAuthMiddleware
// stores the authenticated repository.Session for downstream handlers.
const SessionContextKey = "session"

// sessionCleanupTimeout bounds the background deletion of an expired session
// so a slow or unreachable database cannot keep the goroutine alive forever.
const sessionCleanupTimeout = 5 * time.Second

// TokenAuthMiddleware returns an Echo middleware that validates session tokens.
//
// It expects an "Authorization: Bearer <token>" header and looks the token up
// via repository.Queries.GetSession. The request is rejected with 401 when the
// header is missing or malformed, when the lookup fails, or when the session's
// ExpiresAt lies in the past (in which case the session is also deleted in the
// background). On success the session is stored in the context under
// SessionContextKey and the next handler is invoked.
func TokenAuthMiddleware(pool *pgxpool.Pool) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			authHeader := c.Request().Header.Get("Authorization")
			if authHeader == "" {
				return c.JSON(http.StatusUnauthorized, map[string]string{"error": "Authorization header required"})
			}

			// The auth scheme is case-insensitive (RFC 7235 §2.1). strings.Fields
			// also tolerates repeated whitespace and never yields an empty token.
			fields := strings.Fields(authHeader)
			if len(fields) != 2 || !strings.EqualFold(fields[0], "Bearer") {
				return c.JSON(http.StatusUnauthorized, map[string]string{"error": "Invalid authorization format. Use: Bearer <token>"})
			}
			token := fields[1]

			queries := repository.New(pool)
			session, err := queries.GetSession(c.Request().Context(), token)
			if err != nil {
				// Never log the token itself: it is a live credential.
				logging.GetLogger().Warn("Invalid token attempt",
					zap.String("ip", c.RealIP()),
					zap.Error(err),
				)
				return c.JSON(http.StatusUnauthorized, map[string]string{"error": "Invalid or expired token"})
			}

			// NOTE(review): if ExpiresAt is a nullable type (pgtype.Timestamptz /
			// sql.NullTime), a NULL value has a zero Time and is treated as
			// never expiring here — confirm that is intended.
			if time.Now().After(session.ExpiresAt.Time) {
				// The request context is cancelled once this handler returns, and
				// c must not be used after that point, so derive a detached
				// context up front and hand only that to the goroutine.
				cleanupCtx := context.WithoutCancel(c.Request().Context())
				go func() {
					ctx, cancel := context.WithTimeout(cleanupCtx, sessionCleanupTimeout)
					defer cancel()
					if err := queries.DeleteSession(ctx, token); err != nil {
						logging.GetLogger().Warn("Failed to delete expired session", zap.Error(err))
					}
				}()
				return c.JSON(http.StatusUnauthorized, map[string]string{"error": "Token expired"})
			}

			// Make the session available to downstream middleware and handlers.
			c.Set(SessionContextKey, session)
			return next(c)
		}
	}
}

// TokenIPCheckMiddleware compares the request's client IP with the IP address
// recorded on the session stored by TokenAuthMiddleware.
//
// A mismatch is only logged as a warning; the request is still passed on to
// next. It must be installed after TokenAuthMiddleware: if no
// repository.Session is found under SessionContextKey the request is rejected
// with 401.
func TokenIPCheckMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c *echo.Context) error {
		// Comma-ok assertion: a missing value (nil) or a value of an unexpected
		// type yields ok == false instead of panicking.
		session, ok := c.Get(SessionContextKey).(repository.Session)
		if !ok {
			return c.JSON(http.StatusUnauthorized, map[string]string{"error": "No session in context"})
		}

		if requestIP := c.RealIP(); session.IpAddress != requestIP {
			logging.GetLogger().Warn("Token IP mismatch",
				zap.String("token_ip", session.IpAddress),
				zap.String("request_ip", requestIP),
			)
		}
		return next(c)
	}
}