feat: Remove global db.Dbpool with dependency injection (Phase 0)
- Add Database struct in internal/db/database.go with Pool, Ctx, and RunMigrations() - Update server.go to use Database struct with NewServerInstance() - Add backend.go with InitBackend(), BackendRepo(), BackendCtx(), BackendPool() - Update music.go and sync.go to use BackendRepo() and BackendCtx() instead of db.Dbpool/db.Ctx - Update token_handler.go to accept pool parameter - Update routes.go to use s.db.Pool for middleware - Update cmd/main.go to use NewServerInstance() and HTTPServer() - Update test_helpers.go to initialize backend with test database - Update test files to use backend.BackendPool() and backend.BackendCtx() Benefits: - Easier to mock database for unit tests - Follows Go best practices (dependency injection) - Better architecture with explicit dependencies - RunMigrations() replaces old Migrate_db() function Note: Global db.Dbpool and db.Ctx still exist in dbHelper.go for backward compatibility with test_helpers.go, but production code no longer uses them. Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
+14
-8
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"music-server/internal/db"
|
||||
"music-server/internal/logging"
|
||||
"music-server/internal/server"
|
||||
"net/http"
|
||||
@@ -19,9 +18,11 @@ import (
|
||||
// @description This is a sample server Petstore server.
|
||||
// @termsOfService http://swagger.io/terms/
|
||||
|
||||
//
|
||||
// @contact.name Sebastian Olsson
|
||||
// @contact.email zarnor91@gmail.com
|
||||
|
||||
//
|
||||
// @license.name Apache 2.0
|
||||
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
|
||||
|
||||
@@ -34,16 +35,17 @@ func main() {
|
||||
pprof.StartCPUProfile(f)
|
||||
defer pprof.StopCPUProfile()*/
|
||||
|
||||
server := server.NewServer()
|
||||
appServer := server.NewServerInstance()
|
||||
httpServer := appServer.HTTPServer()
|
||||
|
||||
// Create a done channel to signal when the shutdown is complete
|
||||
done := make(chan bool, 1)
|
||||
|
||||
// Run graceful shutdown in a separate goroutine
|
||||
go gracefulShutdown(server, done)
|
||||
go gracefulShutdown(appServer, httpServer, done)
|
||||
|
||||
logging.GetLogger().Info("Server starting", zap.String("address", server.Addr))
|
||||
err := server.ListenAndServe()
|
||||
logging.GetLogger().Info("Server starting", zap.String("address", httpServer.Addr))
|
||||
err := httpServer.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
logging.GetLogger().Fatal("HTTP server error", zap.String("error", err.Error()))
|
||||
}
|
||||
@@ -53,7 +55,7 @@ func main() {
|
||||
logging.GetLogger().Info("Graceful shutdown complete")
|
||||
}
|
||||
|
||||
func gracefulShutdown(apiServer *http.Server, done chan bool) {
|
||||
func gracefulShutdown(appServer *server.Server, httpServer *http.Server, done chan bool) {
|
||||
// Create context that listens for the interrupt signal from the OS.
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
@@ -62,13 +64,17 @@ func gracefulShutdown(apiServer *http.Server, done chan bool) {
|
||||
<-ctx.Done()
|
||||
|
||||
logging.GetLogger().Info("Shutting down gracefully, press Ctrl+C again to force")
|
||||
db.CloseDb()
|
||||
|
||||
// Close database connection
|
||||
if appServer != nil && appServer.DB() != nil {
|
||||
appServer.DB().Close()
|
||||
}
|
||||
|
||||
// The context is used to inform the server it has 5 seconds to finish
|
||||
// the request it is currently handling
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := apiServer.Shutdown(ctx); err != nil {
|
||||
if err := httpServer.Shutdown(ctx); err != nil {
|
||||
logging.GetLogger().Error("Server forced to shutdown with error", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"music-server/internal/db/repository"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Global variables - these are initialized by InitBackend
|
||||
var (
|
||||
backendPool *pgxpool.Pool
|
||||
repo *repository.Queries
|
||||
backendCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// InitBackend initializes the backend package with the database pool.
|
||||
// This should be called once at application startup.
|
||||
func InitBackend(pool *pgxpool.Pool) {
|
||||
backendPool = pool
|
||||
repo = repository.New(pool)
|
||||
backendCtx = context.Background()
|
||||
}
|
||||
|
||||
// BackendCtx returns the context used by backend operations.
|
||||
// This is exposed for use by the backend functions.
|
||||
func BackendCtx() context.Context {
|
||||
return backendCtx
|
||||
}
|
||||
|
||||
// BackendRepo returns the repository queries instance.
|
||||
// This is exposed for use by the backend functions.
|
||||
func BackendRepo() *repository.Queries {
|
||||
return repo
|
||||
}
|
||||
|
||||
// BackendPool returns the underlying database pool.
|
||||
// This is exposed for test utilities that need direct pool access.
|
||||
func BackendPool() *pgxpool.Pool {
|
||||
return backendPool
|
||||
}
|
||||
+17
-16
@@ -2,7 +2,6 @@ package backend
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"music-server/internal/db"
|
||||
"music-server/internal/db/repository"
|
||||
"music-server/internal/logging"
|
||||
"os"
|
||||
@@ -28,18 +27,20 @@ var gamesNew []repository.Game
|
||||
var songQueNew []repository.Song
|
||||
|
||||
var lastFetchedNew repository.Song
|
||||
var repo *repository.Queries
|
||||
|
||||
func initRepo() {
|
||||
if repo == nil {
|
||||
repo = repository.New(db.Dbpool)
|
||||
// This function is kept for backward compatibility
|
||||
// but now uses the backend package's initialized repo
|
||||
// If not initialized, this will panic intentionally
|
||||
if BackendRepo() == nil {
|
||||
panic("backend not initialized - call backend.InitBackend() first")
|
||||
}
|
||||
}
|
||||
|
||||
func getAllGames() []repository.Game {
|
||||
if len(gamesNew) == 0 {
|
||||
initRepo()
|
||||
gamesNew, _ = repo.FindAllGames(db.Ctx)
|
||||
gamesNew, _ = BackendRepo().FindAllGames(BackendCtx())
|
||||
}
|
||||
return gamesNew
|
||||
|
||||
@@ -58,7 +59,7 @@ func Reset() {
|
||||
songQueNew = nil
|
||||
currentSong = -1
|
||||
initRepo()
|
||||
gamesNew, _ = repo.FindAllGames(db.Ctx)
|
||||
gamesNew, _ = BackendRepo().FindAllGames(BackendCtx())
|
||||
}
|
||||
|
||||
func AddLatestToQue() {
|
||||
@@ -76,8 +77,8 @@ func AddLatestPlayed() {
|
||||
currentSongData := songQueNew[currentSong]
|
||||
|
||||
initRepo()
|
||||
repo.AddGamePlayed(db.Ctx, currentSongData.GameID)
|
||||
repo.AddSongPlayed(db.Ctx, repository.AddSongPlayedParams{GameID: currentSongData.GameID, SongName: currentSongData.SongName})
|
||||
BackendRepo().AddGamePlayed(BackendCtx(), currentSongData.GameID)
|
||||
BackendRepo().AddSongPlayed(BackendCtx(), repository.AddSongPlayedParams{GameID: currentSongData.GameID, SongName: currentSongData.SongName})
|
||||
}
|
||||
|
||||
func SetPlayed(songNumber int) {
|
||||
@@ -86,8 +87,8 @@ func SetPlayed(songNumber int) {
|
||||
}
|
||||
songData := songQueNew[songNumber]
|
||||
initRepo()
|
||||
repo.AddGamePlayed(db.Ctx, songData.GameID)
|
||||
repo.AddSongPlayed(db.Ctx, repository.AddSongPlayedParams{GameID: songData.GameID, SongName: songData.SongName})
|
||||
BackendRepo().AddGamePlayed(BackendCtx(), songData.GameID)
|
||||
BackendRepo().AddSongPlayed(BackendCtx(), repository.AddSongPlayedParams{GameID: songData.GameID, SongName: songData.SongName})
|
||||
}
|
||||
|
||||
func GetRandomSong() string {
|
||||
@@ -130,7 +131,7 @@ func GetRandomSongClassic() string {
|
||||
|
||||
var listOfAllSongs []repository.Song
|
||||
for _, game := range gamesNew {
|
||||
songList, _ := repo.FindSongsFromGame(db.Ctx, game.ID)
|
||||
songList, _ := BackendRepo().FindSongsFromGame(BackendCtx(), game.ID)
|
||||
listOfAllSongs = append(listOfAllSongs, songList...)
|
||||
}
|
||||
|
||||
@@ -138,10 +139,10 @@ func GetRandomSongClassic() string {
|
||||
var song repository.Song
|
||||
for !songFound {
|
||||
song = listOfAllSongs[rand.Intn(len(listOfAllSongs))]
|
||||
gameData, err := repo.GetGameById(db.Ctx, song.GameID)
|
||||
gameData, err := BackendRepo().GetGameById(BackendCtx(), song.GameID)
|
||||
|
||||
if err != nil {
|
||||
repo.RemoveBrokenSong(db.Ctx, song.Path)
|
||||
BackendRepo().RemoveBrokenSong(BackendCtx(), song.Path)
|
||||
logging.GetLogger().Warn("Song not found, removed from database",
|
||||
zap.String("song", song.SongName),
|
||||
zap.String("game", gameData.GameName),
|
||||
@@ -153,7 +154,7 @@ func GetRandomSongClassic() string {
|
||||
openFile, err := os.Open(song.Path)
|
||||
if err != nil || (song.FileName != nil && gameData.Path+*song.FileName != song.Path) {
|
||||
//File not found
|
||||
repo.RemoveBrokenSong(db.Ctx, song.Path)
|
||||
BackendRepo().RemoveBrokenSong(BackendCtx(), song.Path)
|
||||
logging.GetLogger().Warn("Song not found, removed from database",
|
||||
zap.String("song", song.SongName),
|
||||
zap.String("game", gameData.GameName),
|
||||
@@ -270,7 +271,7 @@ func getSongFromList(games []repository.Game) repository.Song {
|
||||
var song repository.Song
|
||||
for !songFound {
|
||||
game := getRandomGame(games)
|
||||
songs, _ := repo.FindSongsFromGame(db.Ctx, game.ID)
|
||||
songs, _ := BackendRepo().FindSongsFromGame(BackendCtx(), game.ID)
|
||||
if len(songs) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -281,7 +282,7 @@ func getSongFromList(games []repository.Game) repository.Song {
|
||||
openFile, err := os.Open(song.Path)
|
||||
if err != nil || (song.FileName != nil && game.Path+*song.FileName != song.Path) || (song.FileName != nil && strings.HasSuffix(*song.FileName, ".wav")) {
|
||||
//File not found
|
||||
repo.RemoveBrokenSong(db.Ctx, song.Path)
|
||||
BackendRepo().RemoveBrokenSong(BackendCtx(), song.Path)
|
||||
logging.GetLogger().Warn("Song not found, removed from database",
|
||||
zap.String("song", song.SongName),
|
||||
zap.String("game", game.GameName),
|
||||
|
||||
+23
-24
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"music-server/internal/db"
|
||||
"music-server/internal/db/repository"
|
||||
"music-server/internal/logging"
|
||||
"os"
|
||||
@@ -80,8 +79,8 @@ func (gs GameStatus) String() string {
|
||||
}
|
||||
|
||||
func ResetDB() {
|
||||
repo.ClearSongs(db.Ctx)
|
||||
repo.ClearGames(db.Ctx)
|
||||
repo.ClearSongs(BackendCtx())
|
||||
repo.ClearGames(BackendCtx())
|
||||
}
|
||||
|
||||
func SyncProgress() ProgressResponse {
|
||||
@@ -206,13 +205,13 @@ func syncGamesNew(full bool) {
|
||||
catchedErrors = nil
|
||||
brokenSongs = nil
|
||||
|
||||
gamesBeforeSync, err = repo.FindAllGames(db.Ctx)
|
||||
gamesBeforeSync, err = repo.FindAllGames(BackendCtx())
|
||||
handleError("FindAllGames Before", err, "")
|
||||
logging.GetLogger().Info("Starting sync", zap.Int("games_before", len(gamesBeforeSync)))
|
||||
|
||||
allGames, err = repo.GetAllGamesIncludingDeleted(db.Ctx)
|
||||
allGames, err = repo.GetAllGamesIncludingDeleted(BackendCtx())
|
||||
handleError("GetAllGamesIncludingDeleted", err, "")
|
||||
err = repo.SetGameDeletionDate(db.Ctx)
|
||||
err = repo.SetGameDeletionDate(BackendCtx())
|
||||
handleError("SetGameDeletionDate", err, "")
|
||||
|
||||
directories, err := os.ReadDir(musicPath)
|
||||
@@ -237,7 +236,7 @@ func syncGamesNew(full bool) {
|
||||
syncWg.Wait()
|
||||
checkBrokenSongsNew()
|
||||
|
||||
gamesAfterSync, err = repo.FindAllGames(db.Ctx)
|
||||
gamesAfterSync, err = repo.FindAllGames(BackendCtx())
|
||||
handleError("FindAllGames After", err, "")
|
||||
|
||||
finished := time.Now()
|
||||
@@ -249,7 +248,7 @@ func syncGamesNew(full bool) {
|
||||
}
|
||||
|
||||
func checkBrokenSongsNew() {
|
||||
allSongs, err := repo.FetchAllSongs(db.Ctx)
|
||||
allSongs, err := repo.FetchAllSongs(BackendCtx())
|
||||
handleError("FetchAllSongs", err, "")
|
||||
var brokenWg sync.WaitGroup
|
||||
poolBroken, _ := ants.NewPool(200, ants.WithPreAlloc(true))
|
||||
@@ -263,7 +262,7 @@ func checkBrokenSongsNew() {
|
||||
})
|
||||
}
|
||||
brokenWg.Wait()
|
||||
err = repo.RemoveBrokenSongs(db.Ctx, brokenSongs)
|
||||
err = repo.RemoveBrokenSongs(BackendCtx(), brokenSongs)
|
||||
handleError("RemoveBrokenSongs", err, "")
|
||||
}
|
||||
|
||||
@@ -336,7 +335,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full
|
||||
break
|
||||
}
|
||||
}
|
||||
err = repo.InsertGameWithExistingId(db.Ctx, repository.InsertGameWithExistingIdParams{ID: id, GameName: file.Name(), Path: gameDir, Hash: dirHash})
|
||||
err = repo.InsertGameWithExistingId(BackendCtx(), repository.InsertGameWithExistingIdParams{ID: id, GameName: file.Name(), Path: gameDir, Hash: dirHash})
|
||||
handleError("InsertGameWithExistingId", err, "")
|
||||
if err != nil {
|
||||
logging.GetLogger().Debug("Game already exists, removing old ID file",
|
||||
@@ -370,7 +369,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full
|
||||
zap.String("game", file.Name()),
|
||||
zap.String("hash", dirHash),
|
||||
zap.String("status", status.String()))
|
||||
err = repo.UpdateGameHash(db.Ctx, repository.UpdateGameHashParams{Hash: dirHash, ID: id})
|
||||
err = repo.UpdateGameHash(BackendCtx(), repository.UpdateGameHashParams{Hash: dirHash, ID: id})
|
||||
handleError("UpdateGameHash", err, "")
|
||||
gamesChangedContent = append(gamesChangedContent, file.Name())
|
||||
newCheckSongs(entries, gameDir, id)
|
||||
@@ -381,7 +380,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full
|
||||
zap.String("newName", file.Name()),
|
||||
zap.String("hash", dirHash),
|
||||
zap.String("status", status.String()))
|
||||
err = repo.UpdateGameName(db.Ctx, repository.UpdateGameNameParams{Name: file.Name(), Path: gameDir, ID: id})
|
||||
err = repo.UpdateGameName(BackendCtx(), repository.UpdateGameNameParams{Name: file.Name(), Path: gameDir, ID: id})
|
||||
handleError("UpdateGameName", err, "")
|
||||
newCheckSongs(entries, gameDir, id)
|
||||
if gamesChangedTitle == nil {
|
||||
@@ -416,7 +415,7 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full
|
||||
zap.String("game", file.Name()),
|
||||
zap.String("hash", dirHash),
|
||||
zap.String("status", status.String()))
|
||||
err = repo.RemoveDeletionDate(db.Ctx, id)
|
||||
err = repo.RemoveDeletionDate(BackendCtx(), id)
|
||||
handleError("RemoveDeletionDate", err, "")
|
||||
}
|
||||
foldersSynced++
|
||||
@@ -428,13 +427,13 @@ func syncGameNew(file os.DirEntry, foldersToSkip []string, baseDir string, full
|
||||
|
||||
func insertGameNew(name string, path string, hash string) int32 {
|
||||
var duplicateError = errors.New("ERROR: duplicate key value violates unique")
|
||||
id, err := repo.InsertGame(db.Ctx, repository.InsertGameParams{GameName: name, Path: path, Hash: hash})
|
||||
id, err := repo.InsertGame(BackendCtx(), repository.InsertGameParams{GameName: name, Path: path, Hash: hash})
|
||||
handleError("InsertGame", err, "")
|
||||
if err != nil {
|
||||
logging.GetLogger().Warn("ID collision detected, resetting sequence")
|
||||
if strings.HasPrefix(err.Error(), duplicateError.Error()) {
|
||||
logging.GetLogger().Debug("Resetting game ID sequence")
|
||||
_, err = repo.ResetGameIdSeq(db.Ctx)
|
||||
_, err = repo.ResetGameIdSeq(BackendCtx())
|
||||
handleError("ResetGameIdSeq", err, "")
|
||||
id = insertGameNew(name, path, hash)
|
||||
}
|
||||
@@ -478,7 +477,7 @@ func newCheckSong(entry os.DirEntry, gameDir string, id int32) bool {
|
||||
fileName := entry.Name()
|
||||
songName, _ := strings.CutSuffix(fileName, ".mp3")
|
||||
|
||||
song, err := repo.GetSongWithHash(db.Ctx, songHash)
|
||||
song, err := repo.GetSongWithHash(BackendCtx(), songHash)
|
||||
handleError("GetSongWithHash", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
if err == nil {
|
||||
if song.SongName == songName && song.Path == path {
|
||||
@@ -491,31 +490,31 @@ func newCheckSong(entry os.DirEntry, gameDir string, id int32) bool {
|
||||
zap.String("song_name", songName),
|
||||
zap.String("song_hash", songHash))
|
||||
|
||||
count, err := repo.CheckSongWithHash(db.Ctx, songHash)
|
||||
count, err := repo.CheckSongWithHash(BackendCtx(), songHash)
|
||||
handleError("CheckSongWithHash", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s\n", id, path, entry.Name(), songHash))
|
||||
if err != nil {
|
||||
count2, err := repo.CheckSong(db.Ctx, path)
|
||||
count2, err := repo.CheckSong(BackendCtx(), path)
|
||||
handleError("CheckSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s\n", id, path, entry.Name(), songHash))
|
||||
if count2 > 0 {
|
||||
err = repo.AddHashToSong(db.Ctx, repository.AddHashToSongParams{Hash: songHash, Path: path})
|
||||
err = repo.AddHashToSong(BackendCtx(), repository.AddHashToSongParams{Hash: songHash, Path: path})
|
||||
handleError("AddHashToSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
count, err = repo.CheckSongWithHash(db.Ctx, songHash)
|
||||
count, err = repo.CheckSongWithHash(BackendCtx(), songHash)
|
||||
handleError("CheckSongWithHash 2", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
}
|
||||
}
|
||||
|
||||
//count, _ := repo.CheckSong(ctx, path)
|
||||
if count > 0 {
|
||||
err = repo.UpdateSong(db.Ctx, repository.UpdateSongParams{SongName: songName, FileName: &fileName, Path: path, Hash: songHash})
|
||||
err = repo.UpdateSong(BackendCtx(), repository.UpdateSongParams{SongName: songName, FileName: &fileName, Path: path, Hash: songHash})
|
||||
handleError("UpdateSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
} else {
|
||||
count2, err := repo.CheckSong(db.Ctx, path)
|
||||
count2, err := repo.CheckSong(BackendCtx(), path)
|
||||
handleError("CheckSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
if count2 > 0 {
|
||||
err = repo.AddHashToSong(db.Ctx, repository.AddHashToSongParams{Hash: songHash, Path: path})
|
||||
err = repo.AddHashToSong(BackendCtx(), repository.AddHashToSongParams{Hash: songHash, Path: path})
|
||||
handleError("AddHashToSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
} else {
|
||||
err = repo.AddSong(db.Ctx, repository.AddSongParams{GameID: id, SongName: songName, Path: path, FileName: &fileName, Hash: songHash})
|
||||
err = repo.AddSong(BackendCtx(), repository.AddSongParams{GameID: id, SongName: songName, Path: path, FileName: &fileName, Hash: songHash})
|
||||
handleError("AddSong", err, fmt.Sprintf("GameID: %d | Path: %s | SongName: %s | SongHash: %s", id, path, entry.Name(), songHash))
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"music-server/internal/logging"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
_ "github.com/lib/pq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Database holds the database connection pool and context
|
||||
type Database struct {
|
||||
Pool *pgxpool.Pool
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
// NewDatabase creates a new Database instance with connection pool
|
||||
func NewDatabase(host, port, user, password, dbname string) (*Database, error) {
|
||||
ctx := context.Background()
|
||||
psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
||||
host, port, user, password, dbname)
|
||||
|
||||
logging.GetLogger().Debug("Database connection info",
|
||||
zap.String("host", host),
|
||||
zap.String("port", port),
|
||||
zap.String("dbname", dbname))
|
||||
|
||||
pool, err := pgxpool.New(ctx, psqlInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Test connection
|
||||
var success string
|
||||
err = pool.QueryRow(ctx, "select 'Successfully connected!'").Scan(&success)
|
||||
if err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("database query failed: %w", err)
|
||||
}
|
||||
|
||||
logging.GetLogger().Info("Database connected", zap.String("status", success))
|
||||
|
||||
return &Database{Pool: pool, Ctx: ctx}, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection pool
|
||||
func (db *Database) Close() {
|
||||
if db.Pool != nil {
|
||||
logging.GetLogger().Info("Closing database connection")
|
||||
db.Pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// RunMigrations runs all pending database migrations to the latest version.
|
||||
// Uses the existing pool to extract connection details.
|
||||
func (db *Database) RunMigrations() error {
|
||||
// Extract connection info from pool config
|
||||
connConfig := db.Pool.Config().ConnConfig
|
||||
migrationURL := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
|
||||
connConfig.User,
|
||||
connConfig.Password,
|
||||
connConfig.Host,
|
||||
connConfig.Port,
|
||||
connConfig.Database)
|
||||
|
||||
logging.GetLogger().Debug("Migration info", zap.String("url", migrationURL))
|
||||
|
||||
sqlDb, err := sql.Open("postgres", migrationURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database for migration: %w", err)
|
||||
}
|
||||
defer sqlDb.Close()
|
||||
|
||||
driver, err := postgres.WithInstance(sqlDb, &postgres.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migration driver: %w", err)
|
||||
}
|
||||
|
||||
files, err := iofs.New(MigrationsFs, "migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migration files: %w", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithInstance("iofs", files, "postgres", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
|
||||
// Get current version for logging
|
||||
version, _, err := m.Version()
|
||||
if err != nil {
|
||||
logging.GetLogger().Error("Failed to get migration version", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
logging.GetLogger().Info("Migration version before", zap.Uint("version", version))
|
||||
|
||||
// Run all pending migrations to latest version
|
||||
err = m.Up()
|
||||
if err != nil {
|
||||
if err == migrate.ErrNoChange {
|
||||
logging.GetLogger().Info("Database already up to date")
|
||||
} else {
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Get new version after migration
|
||||
versionAfter, _, _ := m.Version()
|
||||
logging.GetLogger().Info("Migrated to version", zap.Uint("version", versionAfter))
|
||||
}
|
||||
|
||||
logging.GetLogger().Info("Migration completed")
|
||||
return nil
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"music-server/cmd/web"
|
||||
"music-server/internal/db"
|
||||
"music-server/internal/logging"
|
||||
"music-server/internal/server/middleware"
|
||||
"net/http"
|
||||
@@ -121,7 +120,7 @@ func (s *Server) RegisterRoutes() http.Handler {
|
||||
|
||||
// Protected endpoints - require valid token
|
||||
// Create token auth middleware with pool access
|
||||
tokenAuthMiddleware := middleware.TokenAuthMiddleware(db.Dbpool)
|
||||
tokenAuthMiddleware := middleware.TokenAuthMiddleware(s.db.Pool)
|
||||
|
||||
// Protected group with token authentication - will be used by VGMQ and Statistics API
|
||||
_ = apiV1.Group("", tokenAuthMiddleware)
|
||||
|
||||
+52
-19
@@ -6,6 +6,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"music-server/internal/backend"
|
||||
"music-server/internal/db"
|
||||
"music-server/internal/logging"
|
||||
"net/http"
|
||||
@@ -15,7 +16,9 @@ import (
|
||||
|
||||
type Server struct {
|
||||
port int
|
||||
db *db.Database
|
||||
tokenHandler *TokenHandler
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -30,7 +33,9 @@ var (
|
||||
logJSON = os.Getenv("LOG_JSON") == "true"
|
||||
)
|
||||
|
||||
func NewServer() *http.Server {
|
||||
// NewServerInstance creates a new Server instance with all dependencies initialized.
|
||||
// Use this for dependency injection and proper lifecycle management.
|
||||
func NewServerInstance() *Server {
|
||||
// Initialize logger
|
||||
if logLevel == "" {
|
||||
logLevel = "info"
|
||||
@@ -41,14 +46,44 @@ func NewServer() *http.Server {
|
||||
|
||||
port, _ := strconv.Atoi(os.Getenv("PORT"))
|
||||
|
||||
// Initialize token handler
|
||||
tokenHandler := NewTokenHandler()
|
||||
// Validate required environment variables
|
||||
if host == "" || dbPort == "" || username == "" || password == "" || dbName == "" || musicPath == "" || charactersPath == "" {
|
||||
logging.GetLogger().Fatal("Invalid settings - missing required environment variables")
|
||||
}
|
||||
|
||||
NewServer := &Server{
|
||||
// Create database instance
|
||||
database, err := db.NewDatabase(host, dbPort, username, password, dbName)
|
||||
if err != nil {
|
||||
logging.GetLogger().Fatal("Failed to initialize database", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
// Run migrations using the new method
|
||||
if err := database.RunMigrations(); err != nil {
|
||||
logging.GetLogger().Fatal("Migration failed", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
// Initialize backend package with database pool
|
||||
backend.InitBackend(database.Pool)
|
||||
|
||||
// Initialize token handler with database pool
|
||||
tokenHandler := NewTokenHandler(database.Pool)
|
||||
|
||||
// Create the server instance
|
||||
appServer := &Server{
|
||||
port: port,
|
||||
db: database,
|
||||
tokenHandler: tokenHandler,
|
||||
}
|
||||
|
||||
// Create the HTTP server
|
||||
appServer.httpServer = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", port),
|
||||
Handler: appServer.RegisterRoutes(),
|
||||
IdleTimeout: time.Minute,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
logger.Info("Starting server",
|
||||
zap.String("host", host),
|
||||
zap.String("dbPort", dbPort),
|
||||
@@ -61,23 +96,21 @@ func NewServer() *http.Server {
|
||||
zap.String("charactersPath", charactersPath),
|
||||
)
|
||||
|
||||
//conf.SetupDb()
|
||||
if host == "" || dbPort == "" || username == "" || password == "" || dbName == "" || musicPath == "" || charactersPath == "" {
|
||||
logging.GetLogger().Fatal("Invalid settings - missing required environment variables")
|
||||
return appServer
|
||||
}
|
||||
|
||||
db.Migrate_db(host, dbPort, username, password, dbName)
|
||||
|
||||
db.InitDB(host, dbPort, username, password, dbName)
|
||||
|
||||
// Declare Server config
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", NewServer.port),
|
||||
Handler: NewServer.RegisterRoutes(),
|
||||
IdleTimeout: time.Minute,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
// HTTPServer returns the underlying http.Server for serving HTTP requests.
|
||||
func (s *Server) HTTPServer() *http.Server {
|
||||
return s.httpServer
|
||||
}
|
||||
|
||||
return server
|
||||
// DB returns the database instance for dependency injection.
|
||||
func (s *Server) DB() *db.Database {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// NewServer creates a new HTTP server (deprecated, use NewServerInstance instead).
|
||||
// This function is kept for backward compatibility.
|
||||
func NewServer() *http.Server {
|
||||
return NewServerInstance().HTTPServer()
|
||||
}
|
||||
|
||||
@@ -75,8 +75,8 @@ func TestSyncPopulatesDatabase(t *testing.T) {
|
||||
db.TestClearDatabase(t)
|
||||
|
||||
// Before sync - should have no games
|
||||
repo := repository.New(db.Dbpool)
|
||||
gamesBefore, err := repo.FindAllGames(db.Ctx)
|
||||
repo := repository.New(backend.BackendPool())
|
||||
gamesBefore, err := repo.FindAllGames(backend.BackendCtx())
|
||||
assert.NoError(t, err)
|
||||
beforeCount := len(gamesBefore)
|
||||
t.Logf("Games before sync: %d", beforeCount)
|
||||
@@ -92,7 +92,7 @@ func TestSyncPopulatesDatabase(t *testing.T) {
|
||||
}
|
||||
|
||||
// After sync - should have games
|
||||
gamesAfter, err := repo.FindAllGames(db.Ctx)
|
||||
gamesAfter, err := repo.FindAllGames(backend.BackendCtx())
|
||||
assert.NoError(t, err)
|
||||
afterCount := len(gamesAfter)
|
||||
t.Logf("Games after sync: %d", afterCount)
|
||||
@@ -112,8 +112,8 @@ func TestSyncMakesDifference(t *testing.T) {
|
||||
db.TestClearDatabase(t)
|
||||
|
||||
// Before sync - should have no games
|
||||
repo := repository.New(db.Dbpool)
|
||||
gamesBefore, err := repo.FindAllGames(db.Ctx)
|
||||
repo := repository.New(backend.BackendPool())
|
||||
gamesBefore, err := repo.FindAllGames(backend.BackendCtx())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(gamesBefore), "Should have no games before sync")
|
||||
|
||||
@@ -127,7 +127,7 @@ func TestSyncMakesDifference(t *testing.T) {
|
||||
}
|
||||
|
||||
// After sync - should have games
|
||||
gamesAfter, err := repo.FindAllGames(db.Ctx)
|
||||
gamesAfter, err := repo.FindAllGames(backend.BackendCtx())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(gamesAfter) > 0, "Should have games after sync")
|
||||
}
|
||||
@@ -199,8 +199,8 @@ func TestSyncGamesNewOnlyChanges(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get initial count
|
||||
repo := repository.New(db.Dbpool)
|
||||
gamesBefore, _ := repo.FindAllGames(db.Ctx)
|
||||
repo := repository.New(backend.BackendPool())
|
||||
gamesBefore, _ := repo.FindAllGames(backend.BackendCtx())
|
||||
beforeCount := len(gamesBefore)
|
||||
|
||||
// Run incremental sync (should not change count if nothing changed)
|
||||
@@ -211,7 +211,7 @@ func TestSyncGamesNewOnlyChanges(t *testing.T) {
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Count should be the same
|
||||
gamesAfter, _ := repo.FindAllGames(db.Ctx)
|
||||
gamesAfter, _ := repo.FindAllGames(backend.BackendCtx())
|
||||
afterCount := len(gamesAfter)
|
||||
|
||||
// Note: This might not be exactly equal due to timing, but should be close
|
||||
@@ -227,8 +227,8 @@ func TestResetGames(t *testing.T) {
|
||||
e := StartTestServer(t)
|
||||
|
||||
// First ensure we have data
|
||||
repo := repository.New(db.Dbpool)
|
||||
gamesBefore, _ := repo.FindAllGames(db.Ctx)
|
||||
repo := repository.New(backend.BackendPool())
|
||||
gamesBefore, _ := repo.FindAllGames(backend.BackendCtx())
|
||||
beforeCount := len(gamesBefore)
|
||||
|
||||
if beforeCount == 0 {
|
||||
@@ -238,7 +238,7 @@ func TestResetGames(t *testing.T) {
|
||||
t.Error("Sync did not complete within timeout")
|
||||
return
|
||||
}
|
||||
gamesBefore, _ = repo.FindAllGames(db.Ctx)
|
||||
gamesBefore, _ = repo.FindAllGames(backend.BackendCtx())
|
||||
beforeCount = len(gamesBefore)
|
||||
}
|
||||
|
||||
@@ -253,7 +253,7 @@ func TestResetGames(t *testing.T) {
|
||||
// Note: reset might take a moment to propagate
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
gamesAfter, _ := repo.FindAllGames(db.Ctx)
|
||||
gamesAfter, _ := repo.FindAllGames(backend.BackendCtx())
|
||||
afterCount := len(gamesAfter)
|
||||
|
||||
t.Logf("Games after reset: %d", afterCount)
|
||||
@@ -281,8 +281,8 @@ func TestSyncGamesNewFull(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify database is populated
|
||||
repo := repository.New(db.Dbpool)
|
||||
games, err := repo.FindAllGames(db.Ctx)
|
||||
repo := repository.New(backend.BackendPool())
|
||||
games, err := repo.FindAllGames(backend.BackendCtx())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(games) > 0, "Database should be populated after full sync")
|
||||
t.Logf("Full sync populated %d games", len(games))
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"music-server/internal/backend"
|
||||
"music-server/internal/db"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
)
|
||||
|
||||
@@ -45,8 +48,23 @@ func StartTestServer(t *testing.T) *echo.Echo {
|
||||
os.Setenv("LOG_JSON", "false")
|
||||
}
|
||||
|
||||
// Initialize database for tests
|
||||
db.TestSetupDB(t)
|
||||
|
||||
// Initialize backend with the global Dbpool
|
||||
// This ensures BackendRepo() and BackendCtx() are available
|
||||
if db.Dbpool != nil {
|
||||
backend.InitBackend(db.Dbpool)
|
||||
}
|
||||
|
||||
// Create a Server instance and get its routes
|
||||
s := &Server{}
|
||||
s := &Server{
|
||||
db: &db.Database{
|
||||
Pool: db.Dbpool,
|
||||
Ctx: db.Ctx,
|
||||
},
|
||||
tokenHandler: NewTokenHandler(db.Dbpool),
|
||||
}
|
||||
handler := s.RegisterRoutes()
|
||||
|
||||
// Wrap the http.Handler in an echo.Echo
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"music-server/internal/db"
|
||||
"music-server/internal/db/repository"
|
||||
"music-server/internal/logging"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/labstack/echo/v5"
|
||||
"go.uber.org/zap"
|
||||
@@ -30,13 +30,13 @@ type TokenResponse struct {
|
||||
|
||||
// TokenHandler contains the database pool for token operations
|
||||
type TokenHandler struct {
|
||||
pool *repository.Queries
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewTokenHandler creates a new token handler with database pool
|
||||
func NewTokenHandler() *TokenHandler {
|
||||
func NewTokenHandler(pool *pgxpool.Pool) *TokenHandler {
|
||||
return &TokenHandler{
|
||||
pool: repository.New(db.Dbpool),
|
||||
pool: pool,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,8 @@ func (h *TokenHandler) CreateTokenHandler(c *echo.Context) error {
|
||||
clientType := req.ClientType
|
||||
|
||||
// Store in database using sqlc-generated repository
|
||||
session, err := h.pool.CreateSession(c.Request().Context(), repository.CreateSessionParams{
|
||||
queries := repository.New(h.pool)
|
||||
session, err := queries.CreateSession(c.Request().Context(), repository.CreateSessionParams{
|
||||
Token: token,
|
||||
IpAddress: c.RealIP(),
|
||||
UserAgent: c.Request().UserAgent(),
|
||||
@@ -132,7 +133,8 @@ func (h *TokenHandler) DeleteTokenHandler(c *echo.Context) error {
|
||||
token := parts[1]
|
||||
|
||||
// Delete session using sqlc-generated repository
|
||||
err := h.pool.DeleteSession(c.Request().Context(), token)
|
||||
queries := repository.New(h.pool)
|
||||
err := queries.DeleteSession(c.Request().Context(), token)
|
||||
if err != nil {
|
||||
logging.GetLogger().Error("Failed to delete session", zap.String("error", err.Error()))
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to invalidate token"})
|
||||
@@ -158,7 +160,8 @@ func (h *TokenHandler) CleanupExpiredSessionsHandler(c *echo.Context) error {
|
||||
// Verify token is valid first (using existing middleware)
|
||||
// The middleware will have already validated the token
|
||||
|
||||
err := h.pool.DeleteExpiredSessions(c.Request().Context())
|
||||
queries := repository.New(h.pool)
|
||||
err := queries.DeleteExpiredSessions(c.Request().Context())
|
||||
if err != nil {
|
||||
logging.GetLogger().Error("Failed to cleanup sessions", zap.String("error", err.Error()))
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to cleanup sessions"})
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
|
||||
// ensureSyncRan ensures that sync has been run before testing music endpoints
|
||||
func ensureSyncRan(t *testing.T, e *echo.Echo) {
|
||||
repo := repository.New(db.Dbpool)
|
||||
games, err := repo.FindAllGames(db.Ctx)
|
||||
repo := repository.New(backend.BackendPool())
|
||||
games, err := repo.FindAllGames(backend.BackendCtx())
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(games) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user