From 24a911133336a313e710d9253038202393ecfb84 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Mon, 1 Jun 2026 18:50:05 +0200 Subject: [PATCH] feat: Remove global db.Dbpool with dependency injection (Phase 0) - Add Database struct in internal/db/database.go with Pool, Ctx, and RunMigrations() - Update server.go to use Database struct with NewServerInstance() - Add backend.go with InitBackend(), BackendRepo(), BackendCtx(), BackendPool() - Update music.go and sync.go to use BackendRepo() and BackendCtx() instead of db.Dbpool/db.Ctx - Update token_handler.go to accept pool parameter - Update routes.go to use s.db.Pool for middleware - Update cmd/main.go to use NewServerInstance() and HTTPServer() - Update test_helpers.go to initialize backend with test database - Update test files to use backend.BackendPool() and backend.BackendCtx() Benefits: - Easier to mock database for unit tests - Follows Go best practices (dependency injection) - Better architecture with explicit dependencies - RunMigrations() replaces old Migrate_db() function Note: Global db.Dbpool and db.Ctx still exist in dbHelper.go for backward compatibility with test_helpers.go, but production code no longer uses them. Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe --- cmd/main.go | 22 +++-- internal/backend/backend.go | 40 ++++++++ internal/backend/music.go | 33 ++++--- internal/backend/sync.go | 47 +++++---- internal/db/database.go | 121 +++++++++++++++++++++++ internal/server/routes.go | 3 +- internal/server/server.go | 83 +++++++++++----- internal/server/sync_handler_test.go | 30 +++--- internal/server/test_helpers.go | 20 +++- internal/server/token_handler.go | 17 ++-- internal/server/zz_music_handler_test.go | 4 +- 11 files changed, 320 insertions(+), 100 deletions(-) create mode 100644 internal/backend/backend.go create mode 100644 internal/db/database.go diff --git a/cmd/main.go b/cmd/main.go index 8c47fdc..e3f2571 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "music-server/internal/db" "music-server/internal/logging" "music-server/internal/server" "net/http" @@ -19,9 +18,11 @@ import ( // @description This is a sample server Petstore server. // @termsOfService http://swagger.io/terms/ +// // @contact.name Sebastian Olsson // @contact.email zarnor91@gmail.com +// // @license.name Apache 2.0 // @license.url http://www.apache.org/licenses/LICENSE-2.0.html @@ -34,16 +35,17 @@ func main() { pprof.StartCPUProfile(f) defer pprof.StopCPUProfile()*/ - server := server.NewServer() + appServer := server.NewServerInstance() + httpServer := appServer.HTTPServer() // Create a done channel to signal when the shutdown is complete done := make(chan bool, 1) // Run graceful shutdown in a separate goroutine - go gracefulShutdown(server, done) + go gracefulShutdown(appServer, httpServer, done) - logging.GetLogger().Info("Server starting", zap.String("address", server.Addr)) - err := server.ListenAndServe() + logging.GetLogger().Info("Server starting", zap.String("address", httpServer.Addr)) + err := httpServer.ListenAndServe() if err != nil && err != http.ErrServerClosed { logging.GetLogger().Fatal("HTTP server error", zap.String("error", err.Error())) } @@ -53,7 +55,7 @@ func main() { logging.GetLogger().Info("Graceful shutdown complete") } -func gracefulShutdown(apiServer *http.Server, done chan bool) { +func gracefulShutdown(appServer *server.Server, httpServer *http.Server, done chan bool) { // Create context that listens for the interrupt signal from the OS. ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() @@ -62,13 +64,17 @@ func gracefulShutdown(apiServer *http.Server, done chan bool) { <-ctx.Done() logging.GetLogger().Info("Shutting down gracefully, press Ctrl+C again to force") - db.CloseDb() + + // Close database connection + if appServer != nil && appServer.DB() != nil { + appServer.DB().Close() + } // The context is used to inform the server it has 5 seconds to finish // the request it is currently handling ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := apiServer.Shutdown(ctx); err != nil { + if err := httpServer.Shutdown(ctx); err != nil { logging.GetLogger().Error("Server forced to shutdown with error", zap.String("error", err.Error())) } diff --git a/internal/backend/backend.go b/internal/backend/backend.go new file mode 100644 index 0000000..88833f4 --- /dev/null +++ b/internal/backend/backend.go @@ -0,0 +1,40 @@ +package backend + +import ( + "context" + "music-server/internal/db/repository" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Global variables - these are initialized by InitBackend +var ( + backendPool *pgxpool.Pool + repo *repository.Queries + backendCtx context.Context = context.Background() +) + +// InitBackend initializes the backend package with the database pool. +// This should be called once at application startup. +func InitBackend(pool *pgxpool.Pool) { + backendPool = pool + repo = repository.New(pool) + backendCtx = context.Background() +} + +// BackendCtx returns the context used by backend operations. +// This is exposed for use by the backend functions. +func BackendCtx() context.Context { + return backendCtx +} + +// BackendRepo returns the repository queries instance. +// This is exposed for use by the backend functions. +func BackendRepo() *repository.Queries { + return repo +} + +// BackendPool returns the underlying database pool. +// This is exposed for test utilities that need direct pool access. +func BackendPool() *pgxpool.Pool { + return backendPool +} diff --git a/internal/backend/music.go b/internal/backend/music.go index b298975..c9717b4 100644 --- a/internal/backend/music.go +++ b/internal/backend/music.go @@ -2,7 +2,6 @@ package backend import ( "math/rand" - "music-server/internal/db" "music-server/internal/db/repository" "music-server/internal/logging" "os" @@ -28,18 +27,20 @@ var gamesNew []repository.Game var songQueNew []repository.Song var lastFetchedNew repository.Song -var repo *repository.Queries func initRepo() { - if repo == nil { - repo = repository.New(db.Dbpool) + // This function is kept for backward compatibility + // but now uses the backend package's initialized repo + // If not initialized, this will panic intentionally + if BackendRepo() == nil { + panic("backend not initialized - call backend.InitBackend() first") } } func getAllGames() []repository.Game { if len(gamesNew) == 0 { initRepo() - gamesNew, _ = repo.FindAllGames(db.Ctx) + gamesNew, _ = BackendRepo().FindAllGames(BackendCtx()) } return gamesNew @@ -58,7 +59,7 @@ func Reset() { songQueNew = nil currentSong = -1 initRepo() - gamesNew, _ = repo.FindAllGames(db.Ctx) + gamesNew, _ = BackendRepo().FindAllGames(BackendCtx()) } func AddLatestToQue() { @@ -76,8 +77,8 @@ func AddLatestPlayed() { currentSongData := songQueNew[currentSong] initRepo() - repo.AddGamePlayed(db.Ctx, currentSongData.GameID) - repo.AddSongPlayed(db.Ctx, repository.AddSongPlayedParams{GameID: currentSongData.GameID, SongName: currentSongData.SongName}) + BackendRepo().AddGamePlayed(BackendCtx(), currentSongData.GameID) + BackendRepo().AddSongPlayed(BackendCtx(), repository.AddSongPlayedParams{GameID: currentSongData.GameID, SongName: currentSongData.SongName}) } func SetPlayed(songNumber int) { @@ -86,8 +87,8 @@ func SetPlayed(songNumber int) { } songData := songQueNew[songNumber] initRepo() - repo.AddGamePlayed(db.Ctx, songData.GameID) - repo.AddSongPlayed(db.Ctx, repository.AddSongPlayedParams{GameID: songData.GameID, SongName: songData.SongName}) + BackendRepo().AddGamePlayed(BackendCtx(), songData.GameID) + BackendRepo().AddSongPlayed(BackendCtx(), repository.AddSongPlayedParams{GameID: songData.GameID, SongName: songData.SongName}) } func GetRandomSong() string { @@ -130,7 +131,7 @@ func GetRandomSongClassic() string { var listOfAllSongs []repository.Song for _, game := range gamesNew { - songList, _ := repo.FindSongsFromGame(db.Ctx, game.ID) + songList, _ := BackendRepo().FindSongsFromGame(BackendCtx(), game.ID) listOfAllSongs = append(listOfAllSongs, songList...) } @@ -138,10 +139,10 @@ func GetRandomSongClassic() string { var song repository.Song for !songFound { song = listOfAllSongs[rand.Intn(len(listOfAllSongs))] - gameData, err := repo.GetGameById(db.Ctx, song.GameID) + gameData, err := BackendRepo().GetGameById(BackendCtx(), song.GameID) if err != nil { - repo.RemoveBrokenSong(db.Ctx, song.Path) + BackendRepo().RemoveBrokenSong(BackendCtx(), song.Path) logging.GetLogger().Warn("Song not found, removed from database", zap.String("song", song.SongName), zap.String("game", gameData.GameName), @@ -153,7 +154,7 @@ func GetRandomSongClassic() string { openFile, err := os.Open(song.Path) if err != nil || (song.FileName != nil && gameData.Path+*song.FileName != song.Path) { //File not found - repo.RemoveBrokenSong(db.Ctx, song.Path) + BackendRepo().RemoveBrokenSong(BackendCtx(), song.Path) logging.GetLogger().Warn("Song not found, removed from database", zap.String("song", song.SongName), zap.String("game", gameData.GameName), @@ -270,7 +271,7 @@ func getSongFromList(games []repository.Game) repository.Song { var song repository.Song for !songFound { game := getRandomGame(games) - songs, _ := repo.FindSongsFromGame(db.Ctx, game.ID) + songs, _ := BackendRepo().FindSongsFromGame(BackendCtx(), game.ID) if len(songs) == 0 { continue } @@ -281,7 +282,7 @@ func getSongFromList(games []repository.Game) repository.Song { openFile, err := os.Open(song.Path) if err != nil || (song.FileName != nil && game.Path+*song.FileName != song.Path) || (song.FileName != nil && strings.HasSuffix(*song.FileName, ".wav")) { //File not found - repo.RemoveBrokenSong(db.Ctx, song.Path) + BackendRepo().RemoveBrokenSong(BackendCtx(), song.Path) logging.GetLogger().Warn("Song not found, removed from database", zap.String("song", song.SongName), zap.String("game", game.GameName), diff --git a/internal/backend/sync.go b/internal/backend/sync.go index 8a18822..42d7621 100644 --- a/internal/backend/sync.go +++ b/internal/backend/sync.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/fs" - "music-server/internal/db" "music-server/internal/db/repository" "music-server/internal/logging" "os" @@ -80,8 +79,8 @@ func (gs GameStatus) String() string { } func ResetDB() { - repo.ClearSongs(db.Ctx) - repo.ClearGames(db.Ctx) + repo.ClearSongs(BackendCtx()) + repo.ClearGames(BackendCtx()) } func SyncProgress() ProgressResponse { @@ -206,13 +205,13 @@ func syncGamesNew(full bool) { catchedErrors = nil brokenSongs = nil - gamesBeforeSync, err = repo.FindAllGames(db.Ctx) + gamesBeforeSync, err = repo.FindAllGames(BackendCtx()) handleError("FindAllGames Before", err, "") logging.GetLogger().Info("Starting sync", zap.Int("games_before", len(gamesBeforeSync))) - allGames, err = repo.GetAllGamesIncludingDeleted(db.Ctx) + allGames, err = repo.GetAllGamesIncludingDeleted(BackendCtx()) handleError("GetAllGamesIncludingDeleted", err, "") - err = repo.SetGameDeletionDate(db.Ctx) + err = repo.SetGameDeletionDate(BackendCtx()) handleError("SetGameDeletionDate", err, "") directories, err := os.ReadDir(musicPath) @@ -237,7 +236,7 @@ func syncGamesNew(full bool) { syncWg.Wait() checkBrokenSongsNew() - gamesAfterSync, err = repo.FindAllGames(db.Ctx) + gamesAfterSync, err = repo.FindAllGames(BackendCtx()) handleError("FindAllGames After", err, "") finished := time.Now() @@ -249,7 +248,7 @@ func syncGamesNew(full bool) { } func checkBrokenSongsNew() { - allSongs, err := repo.FetchAllSongs(db.Ctx) + allSongs, err := repo.FetchAllSongs(BackendCtx()) handleError("FetchAllSongs", err, "") var brokenWg sync.WaitGroup poolBroken, _ := ants.NewPool(200, ants.WithPreAlloc(true)) @@ -263,7 +262,7 @@ func checkBrokenSongsNew() { }) } brokenWg.Wait() - err = repo.RemoveBrokenSongs(db.Ctx, brokenSongs) + err = repo.RemoveBrokenSongs(BackendCtx(), brokenSongs) handleError("RemoveBrokenSongs", err, "") } @@ -336,7 +335,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full break } } - err = repo.InsertGameWithExistingId(db.Ctx, repository.InsertGameWithExistingIdParams{ID: id, GameName: file.Name(), Path: gameDir, Hash: dirHash}) + err = repo.InsertGameWithExistingId(BackendCtx(), repository.InsertGameWithExistingIdParams{ID: id, GameName: file.Name(), Path: gameDir, Hash: dirHash}) handleError("InsertGameWithExistingId", err, "") if err != nil { logging.GetLogger().Debug("Game already exists, removing old ID file", @@ -370,7 +369,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full zap.String("game", file.Name()), zap.String("hash", dirHash), zap.String("status", status.String())) - err = repo.UpdateGameHash(db.Ctx, repository.UpdateGameHashParams{Hash: dirHash, ID: id}) + err = repo.UpdateGameHash(BackendCtx(), repository.UpdateGameHashParams{Hash: dirHash, ID: id}) handleError("UpdateGameHash", err, "") gamesChangedContent = append(gamesChangedContent, file.Name()) newCheckSongs(entries, gameDir, id) @@ -381,7 +380,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full zap.String("newName", file.Name()), zap.String("hash", dirHash), zap.String("status", status.String())) - err = repo.UpdateGameName(db.Ctx, repository.UpdateGameNameParams{Name: file.Name(), Path: gameDir, ID: id}) + err = repo.UpdateGameName(BackendCtx(), repository.UpdateGameNameParams{Name: file.Name(), Path: gameDir, ID: id}) handleError("UpdateGameName", err, "") newCheckSongs(entries, gameDir, id) if gamesChangedTitle == nil { @@ -416,7 +415,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full zap.String("game", file.Name()), zap.String("hash", dirHash), zap.String("status", status.String())) - err = repo.RemoveDeletionDate(db.Ctx, id) + err = repo.RemoveDeletionDate(BackendCtx(), id) handleError("RemoveDeletionDate", err, "") } foldersSynced++ @@ -428,13 +427,13 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full func insertGameNew(name string, path string, hash string) int32 { var duplicateError = errors.New("ERROR: duplicate key value violates unique") - id, err := repo.InsertGame(db.Ctx, repository.InsertGameParams{GameName: name, Path: path, Hash: hash}) + id, err := repo.InsertGame(BackendCtx(), repository.InsertGameParams{GameName: name, Path: path, Hash: hash}) handleError("InsertGame", err, "") if err != nil { logging.GetLogger().Warn("ID collision detected, resetting sequence") if strings.HasPrefix(err.Error(), duplicateError.Error()) { logging.GetLogger().Debug("Resetting game ID sequence") - _, err = repo.ResetGameIdSeq(db.Ctx) + _, err = repo.ResetGameIdSeq(BackendCtx()) handleError("ResetGameIdSeq", err, "") id = insertGameNew(name, path, hash) } @@ -478,7 +477,7 @@ func newCheckSong(entry os.DirEntry, gameDir string, id int32) bool { fileName := entry.Name() songName, _ := strings.CutSuffix(fileName, ".mp3") - song, err := repo.GetSongWithHash(db.Ctx, songHash) + song, err := repo.GetSongWithHash(BackendCtx(), songHash) handleError("GetSongWithHash", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) if err == nil { if song.SongName == songName && song.Path == path { @@ -491,31 +490,31 @@ func newCheckSong(entry os.DirEntry, gameDir string, id int32) bool { zap.String("song_name", songName), zap.String("song_hash", songHash)) - count, err := repo.CheckSongWithHash(db.Ctx, songHash) + count, err := repo.CheckSongWithHash(BackendCtx(), songHash) handleError("CheckSongWithHash", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s\n", id, path, entry.Name(), songHash)) if err != nil { - count2, err := repo.CheckSong(db.Ctx, path) + count2, err := repo.CheckSong(BackendCtx(), path) handleError("CheckSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s\n", id, path, entry.Name(), songHash)) if count2 > 0 { - err = repo.AddHashToSong(db.Ctx, repository.AddHashToSongParams{Hash: songHash, Path: path}) + err = repo.AddHashToSong(BackendCtx(), repository.AddHashToSongParams{Hash: songHash, Path: path}) handleError("AddHashToSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) - count, err = repo.CheckSongWithHash(db.Ctx, songHash) + count, err = repo.CheckSongWithHash(BackendCtx(), songHash) handleError("CheckSongWithHash 2", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) } } //count, _ := repo.CheckSong(ctx, path) if count > 0 { - err = repo.UpdateSong(db.Ctx, repository.UpdateSongParams{SongName: songName, FileName: &fileName, Path: path, Hash: songHash}) + err = repo.UpdateSong(BackendCtx(), repository.UpdateSongParams{SongName: songName, FileName: &fileName, Path: path, Hash: songHash}) handleError("UpdateSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) } else { - count2, err := repo.CheckSong(db.Ctx, path) + count2, err := repo.CheckSong(BackendCtx(), path) handleError("CheckSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) if count2 > 0 { - err = repo.AddHashToSong(db.Ctx, repository.AddHashToSongParams{Hash: songHash, Path: path}) + err = repo.AddHashToSong(BackendCtx(), repository.AddHashToSongParams{Hash: songHash, Path: path}) handleError("AddHashToSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) } else { - err = repo.AddSong(db.Ctx, repository.AddSongParams{GameID: id, SongName: songName, Path: path, FileName: &fileName, Hash: songHash}) + err = repo.AddSong(BackendCtx(), repository.AddSongParams{GameID: id, SongName: songName, Path: path, FileName: &fileName, Hash: songHash}) handleError("AddSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash)) } diff --git a/internal/db/database.go b/internal/db/database.go new file mode 100644 index 0000000..c37301b --- /dev/null +++ b/internal/db/database.go @@ -0,0 +1,121 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + + "music-server/internal/logging" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/postgres" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" + "github.com/jackc/pgx/v5/pgxpool" + _ "github.com/lib/pq" + "go.uber.org/zap" +) + +// Database holds the database connection pool and context +type Database struct { + Pool *pgxpool.Pool + Ctx context.Context +} + +// NewDatabase creates a new Database instance with connection pool +func NewDatabase(host, port, user, password, dbname string) (*Database, error) { + ctx := context.Background() + psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbname) + + logging.GetLogger().Debug("Database connection info", + zap.String("host", host), + zap.String("port", port), + zap.String("dbname", dbname)) + + pool, err := pgxpool.New(ctx, psqlInfo) + if err != nil { + return nil, fmt.Errorf("unable to connect to database: %w", err) + } + + // Test connection + var success string + err = pool.QueryRow(ctx, "select 'Successfully connected!'").Scan(&success) + if err != nil { + pool.Close() + return nil, fmt.Errorf("database query failed: %w", err) + } + + logging.GetLogger().Info("Database connected", zap.String("status", success)) + + return &Database{Pool: pool, Ctx: ctx}, nil +} + +// Close closes the database connection pool +func (db *Database) Close() { + if db.Pool != nil { + logging.GetLogger().Info("Closing database connection") + db.Pool.Close() + } +} + +// RunMigrations runs all pending database migrations to the latest version. +// Uses the existing pool to extract connection details. +func (db *Database) RunMigrations() error { + // Extract connection info from pool config + connConfig := db.Pool.Config().ConnConfig + migrationURL := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + connConfig.User, + connConfig.Password, + connConfig.Host, + connConfig.Port, + connConfig.Database) + + logging.GetLogger().Debug("Migration info", zap.String("url", migrationURL)) + + sqlDb, err := sql.Open("postgres", migrationURL) + if err != nil { + return fmt.Errorf("failed to open database for migration: %w", err) + } + defer sqlDb.Close() + + driver, err := postgres.WithInstance(sqlDb, &postgres.Config{}) + if err != nil { + return fmt.Errorf("failed to create migration driver: %w", err) + } + + files, err := iofs.New(MigrationsFs, "migrations") + if err != nil { + return fmt.Errorf("failed to create migration files: %w", err) + } + + m, err := migrate.NewWithInstance("iofs", files, "postgres", driver) + if err != nil { + return fmt.Errorf("failed to create migrator: %w", err) + } + + // Get current version for logging + version, _, err := m.Version() + if err != nil { + logging.GetLogger().Error("Failed to get migration version", zap.String("error", err.Error())) + } + + logging.GetLogger().Info("Migration version before", zap.Uint("version", version)) + + // Run all pending migrations to latest version + err = m.Up() + if err != nil { + if err == migrate.ErrNoChange { + logging.GetLogger().Info("Database already up to date") + } else { + return fmt.Errorf("migration failed: %w", err) + } + } else { + // Get new version after migration + versionAfter, _, _ := m.Version() + logging.GetLogger().Info("Migrated to version", zap.Uint("version", versionAfter)) + } + + logging.GetLogger().Info("Migration completed") + return nil +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 3afe3e9..6a1558f 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -2,7 +2,6 @@ package server import ( "music-server/cmd/web" - "music-server/internal/db" "music-server/internal/logging" "music-server/internal/server/middleware" "net/http" @@ -126,7 +125,7 @@ func (s *Server) RegisterRoutes() http.Handler { // Protected endpoints - require valid token // Create token auth middleware with pool access - tokenAuthMiddleware := middleware.TokenAuthMiddleware(db.Dbpool) + tokenAuthMiddleware := middleware.TokenAuthMiddleware(s.db.Pool) // Protected group with token authentication - will be used by VGMQ and Statistics API _ = apiV1.Group("", tokenAuthMiddleware) diff --git a/internal/server/server.go b/internal/server/server.go index b48b17d..b0e1b85 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "strconv" "time" + "music-server/internal/backend" "music-server/internal/db" "music-server/internal/logging" "net/http" @@ -15,7 +16,9 @@ import ( type Server struct { port int + db *db.Database tokenHandler *TokenHandler + httpServer *http.Server } var ( @@ -30,7 +33,9 @@ var ( logJSON = os.Getenv("LOG_JSON") == "true" ) -func NewServer() *http.Server { +// NewServerInstance creates a new Server instance with all dependencies initialized. +// Use this for dependency injection and proper lifecycle management. +func NewServerInstance() *Server { // Initialize logger if logLevel == "" { logLevel = "info" @@ -40,15 +45,45 @@ func NewServer() *http.Server { logger := logging.GetLogger() port, _ := strconv.Atoi(os.Getenv("PORT")) - - // Initialize token handler - tokenHandler := NewTokenHandler() - - NewServer := &Server{ + + // Validate required environment variables + if host == "" || dbPort == "" || username == "" || password == "" || dbName == "" || musicPath == "" || charactersPath == "" { + logging.GetLogger().Fatal("Invalid settings - missing required environment variables") + } + + // Create database instance + database, err := db.NewDatabase(host, dbPort, username, password, dbName) + if err != nil { + logging.GetLogger().Fatal("Failed to initialize database", zap.String("error", err.Error())) + } + + // Run migrations using the new method + if err := database.RunMigrations(); err != nil { + logging.GetLogger().Fatal("Migration failed", zap.String("error", err.Error())) + } + + // Initialize backend package with database pool + backend.InitBackend(database.Pool) + + // Initialize token handler with database pool + tokenHandler := NewTokenHandler(database.Pool) + + // Create the server instance + appServer := &Server{ port: port, + db: database, tokenHandler: tokenHandler, } + // Create the HTTP server + appServer.httpServer = &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: appServer.RegisterRoutes(), + IdleTimeout: time.Minute, + ReadTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + } + logger.Info("Starting server", zap.String("host", host), zap.String("dbPort", dbPort), @@ -61,23 +96,21 @@ func NewServer() *http.Server { zap.String("charactersPath", charactersPath), ) - //conf.SetupDb() - if host == "" || dbPort == "" || username == "" || password == "" || dbName == "" || musicPath == "" || charactersPath == "" { - logging.GetLogger().Fatal("Invalid settings - missing required environment variables") - } - - db.Migrate_db(host, dbPort, username, password, dbName) - - db.InitDB(host, dbPort, username, password, dbName) - - // Declare Server config - server := &http.Server{ - Addr: fmt.Sprintf(":%d", NewServer.port), - Handler: NewServer.RegisterRoutes(), - IdleTimeout: time.Minute, - ReadTimeout: 10 * time.Second, - WriteTimeout: 30 * time.Second, - } - - return server + return appServer +} + +// HTTPServer returns the underlying http.Server for serving HTTP requests. +func (s *Server) HTTPServer() *http.Server { + return s.httpServer +} + +// DB returns the database instance for dependency injection. +func (s *Server) DB() *db.Database { + return s.db +} + +// NewServer creates a new HTTP server (deprecated, use NewServerInstance instead). +// This function is kept for backward compatibility. +func NewServer() *http.Server { + return NewServerInstance().HTTPServer() } diff --git a/internal/server/sync_handler_test.go b/internal/server/sync_handler_test.go index 7606515..15ecb6d 100644 --- a/internal/server/sync_handler_test.go +++ b/internal/server/sync_handler_test.go @@ -75,8 +75,8 @@ func TestSyncPopulatesDatabase(t *testing.T) { db.TestClearDatabase(t) // Before sync - should have no games - repo := repository.New(db.Dbpool) - gamesBefore, err := repo.FindAllGames(db.Ctx) + repo := repository.New(backend.BackendPool()) + gamesBefore, err := repo.FindAllGames(backend.BackendCtx()) assert.NoError(t, err) beforeCount := len(gamesBefore) t.Logf("Games before sync: %d", beforeCount) @@ -92,7 +92,7 @@ func TestSyncPopulatesDatabase(t *testing.T) { } // After sync - should have games - gamesAfter, err := repo.FindAllGames(db.Ctx) + gamesAfter, err := repo.FindAllGames(backend.BackendCtx()) assert.NoError(t, err) afterCount := len(gamesAfter) t.Logf("Games after sync: %d", afterCount) @@ -112,8 +112,8 @@ func TestSyncMakesDifference(t *testing.T) { db.TestClearDatabase(t) // Before sync - should have no games - repo := repository.New(db.Dbpool) - gamesBefore, err := repo.FindAllGames(db.Ctx) + repo := repository.New(backend.BackendPool()) + gamesBefore, err := repo.FindAllGames(backend.BackendCtx()) assert.NoError(t, err) assert.Equal(t, 0, len(gamesBefore), "Should have no games before sync") @@ -127,7 +127,7 @@ func TestSyncMakesDifference(t *testing.T) { } // After sync - should have games - gamesAfter, err := repo.FindAllGames(db.Ctx) + gamesAfter, err := repo.FindAllGames(backend.BackendCtx()) assert.NoError(t, err) assert.True(t, len(gamesAfter) > 0, "Should have games after sync") } @@ -199,8 +199,8 @@ func TestSyncGamesNewOnlyChanges(t *testing.T) { } // Get initial count - repo := repository.New(db.Dbpool) - gamesBefore, _ := repo.FindAllGames(db.Ctx) + repo := repository.New(backend.BackendPool()) + gamesBefore, _ := repo.FindAllGames(backend.BackendCtx()) beforeCount := len(gamesBefore) // Run incremental sync (should not change count if nothing changed) @@ -211,7 +211,7 @@ func TestSyncGamesNewOnlyChanges(t *testing.T) { time.Sleep(2 * time.Second) // Count should be the same - gamesAfter, _ := repo.FindAllGames(db.Ctx) + gamesAfter, _ := repo.FindAllGames(backend.BackendCtx()) afterCount := len(gamesAfter) // Note: This might not be exactly equal due to timing, but should be close @@ -227,8 +227,8 @@ func TestResetGames(t *testing.T) { e := StartTestServer(t) // First ensure we have data - repo := repository.New(db.Dbpool) - gamesBefore, _ := repo.FindAllGames(db.Ctx) + repo := repository.New(backend.BackendPool()) + gamesBefore, _ := repo.FindAllGames(backend.BackendCtx()) beforeCount := len(gamesBefore) if beforeCount == 0 { @@ -238,7 +238,7 @@ func TestResetGames(t *testing.T) { t.Error("Sync did not complete within timeout") return } - gamesBefore, _ = repo.FindAllGames(db.Ctx) + gamesBefore, _ = repo.FindAllGames(backend.BackendCtx()) beforeCount = len(gamesBefore) } @@ -253,7 +253,7 @@ func TestResetGames(t *testing.T) { // Note: reset might take a moment to propagate time.Sleep(1 * time.Second) - gamesAfter, _ := repo.FindAllGames(db.Ctx) + gamesAfter, _ := repo.FindAllGames(backend.BackendCtx()) afterCount := len(gamesAfter) t.Logf("Games after reset: %d", afterCount) @@ -281,8 +281,8 @@ func TestSyncGamesNewFull(t *testing.T) { } // Verify database is populated - repo := repository.New(db.Dbpool) - games, err := repo.FindAllGames(db.Ctx) + repo := repository.New(backend.BackendPool()) + games, err := repo.FindAllGames(backend.BackendCtx()) assert.NoError(t, err) assert.True(t, len(games) > 0, "Database should be populated after full sync") t.Logf("Full sync populated %d games", len(games)) diff --git a/internal/server/test_helpers.go b/internal/server/test_helpers.go index 77f0b28..e0e4f3e 100644 --- a/internal/server/test_helpers.go +++ b/internal/server/test_helpers.go @@ -8,6 +8,9 @@ import ( "testing" "time" + "music-server/internal/backend" + "music-server/internal/db" + "github.com/labstack/echo/v5" ) @@ -45,8 +48,23 @@ func StartTestServer(t *testing.T) *echo.Echo { os.Setenv("LOG_JSON", "false") } + // Initialize database for tests + db.TestSetupDB(t) + + // Initialize backend with the global Dbpool + // This ensures BackendRepo() and BackendCtx() are available + if db.Dbpool != nil { + backend.InitBackend(db.Dbpool) + } + // Create a Server instance and get its routes - s := &Server{} + s := &Server{ + db: &db.Database{ + Pool: db.Dbpool, + Ctx: db.Ctx, + }, + tokenHandler: NewTokenHandler(db.Dbpool), + } handler := s.RegisterRoutes() // Wrap the http.Handler in an echo.Echo diff --git a/internal/server/token_handler.go b/internal/server/token_handler.go index 910d147..e234e67 100644 --- a/internal/server/token_handler.go +++ b/internal/server/token_handler.go @@ -7,10 +7,10 @@ import ( "strings" "time" - "music-server/internal/db" "music-server/internal/db/repository" "music-server/internal/logging" + "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgtype" "github.com/labstack/echo/v5" "go.uber.org/zap" @@ -30,13 +30,13 @@ type TokenResponse struct { // TokenHandler contains the database pool for token operations type TokenHandler struct { - pool *repository.Queries + pool *pgxpool.Pool } // NewTokenHandler creates a new token handler with database pool -func NewTokenHandler() *TokenHandler { +func NewTokenHandler(pool *pgxpool.Pool) *TokenHandler { return &TokenHandler{ - pool: repository.New(db.Dbpool), + pool: pool, } } @@ -84,7 +84,8 @@ func (h *TokenHandler) CreateTokenHandler(c *echo.Context) error { clientType := req.ClientType // Store in database using sqlc-generated repository - session, err := h.pool.CreateSession(c.Request().Context(), repository.CreateSessionParams{ + queries := repository.New(h.pool) + session, err := queries.CreateSession(c.Request().Context(), repository.CreateSessionParams{ Token: token, IpAddress: c.RealIP(), UserAgent: c.Request().UserAgent(), @@ -132,7 +133,8 @@ func (h *TokenHandler) DeleteTokenHandler(c *echo.Context) error { token := parts[1] // Delete session using sqlc-generated repository - err := h.pool.DeleteSession(c.Request().Context(), token) + queries := repository.New(h.pool) + err := queries.DeleteSession(c.Request().Context(), token) if err != nil { logging.GetLogger().Error("Failed to delete session", zap.String("error", err.Error())) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to invalidate token"}) @@ -158,7 +160,8 @@ func (h *TokenHandler) CleanupExpiredSessionsHandler(c *echo.Context) error { // Verify token is valid first (using existing middleware) // The middleware will have already validated the token - err := h.pool.DeleteExpiredSessions(c.Request().Context()) + queries := repository.New(h.pool) + err := queries.DeleteExpiredSessions(c.Request().Context()) if err != nil { logging.GetLogger().Error("Failed to cleanup sessions", zap.String("error", err.Error())) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to cleanup sessions"}) diff --git a/internal/server/zz_music_handler_test.go b/internal/server/zz_music_handler_test.go index a3d2363..bd7e92d 100644 --- a/internal/server/zz_music_handler_test.go +++ b/internal/server/zz_music_handler_test.go @@ -15,8 +15,8 @@ import ( // ensureSyncRan ensures that sync has been run before testing music endpoints func ensureSyncRan(t *testing.T, e *echo.Echo) { - repo := repository.New(db.Dbpool) - games, err := repo.FindAllGames(db.Ctx) + repo := repository.New(backend.BackendPool()) + games, err := repo.FindAllGames(backend.BackendCtx()) assert.NoError(t, err) if len(games) == 0 {