diff --git a/internal/db/dbHelper.go b/internal/db/dbHelper.go index 8bcfc55..0d09e41 100644 --- a/internal/db/dbHelper.go +++ b/internal/db/dbHelper.go @@ -20,6 +20,7 @@ import ( "go.uber.org/zap" ) +// TODO: Remove these global variables once test_helpers.go is fully migrated to use Database struct var Dbpool *pgxpool.Pool var Ctx = context.Background() diff --git a/internal/db/test_helpers.go b/internal/db/test_helpers.go index ffb5475..2cdac70 100644 --- a/internal/db/test_helpers.go +++ b/internal/db/test_helpers.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" "log" @@ -16,6 +17,8 @@ var ( testDBUser string testDBPassword string testDBName string + // TestDatabase is the database instance for tests + TestDatabase *Database ) // TestSetupDB initializes the test database using existing functions @@ -44,9 +47,17 @@ func TestSetupDB(t *testing.T) { // Create the database first (testuser is a superuser in the container) createTestDatabase(host, port, dbname, user, password) - // Now run migrations using the existing function - Migrate_db(host, port, user, password, dbname) - InitDB(host, port, user, password, dbname) + // Create database instance and run migrations + var err error + TestDatabase, err = NewDatabase(host, port, user, password, dbname) + if err != nil { + t.Fatalf("Failed to initialize test database: %v", err) + } + + // Run migrations + if err := TestDatabase.RunMigrations(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } }) } @@ -86,12 +97,16 @@ func createTestDatabase(host, port, dbname, user, password string) { // "closed pool" errors when tests run sequentially func TestTearDownDB(t *testing.T) { // CloseDb() // Disabled to prevent pool closure between sequential tests + if TestDatabase != nil { + TestDatabase.Close() + TestDatabase = nil + } } // TestClearDatabase clears all data from the test database // Useful for running tests with a clean slate func TestClearDatabase(t *testing.T) { - if Dbpool == nil { + if TestDatabase == nil || TestDatabase.Pool == nil { t.Skip("Database not initialized") } @@ -103,15 +118,16 @@ func TestClearDatabase(t *testing.T) { "game", } + ctx := context.Background() for _, table := range tables { - _, err := Dbpool.Exec(Ctx, "TRUNCATE TABLE "+table+" CASCADE") + _, err := TestDatabase.Pool.Exec(ctx, "TRUNCATE TABLE "+table+" CASCADE") if err != nil { t.Logf("Failed to truncate table %s: %v", table, err) } } // Reset sequences - _, err := Dbpool.Exec(Ctx, "SELECT setval('game_id_seq', 1, false)") + _, err := TestDatabase.Pool.Exec(ctx, "SELECT setval('game_id_seq', 1, false)") if err != nil { t.Logf("Failed to reset game_id_seq: %v", err) } diff --git a/internal/server/test_helpers.go b/internal/server/test_helpers.go index e0e4f3e..bd30334 100644 --- a/internal/server/test_helpers.go +++ b/internal/server/test_helpers.go @@ -51,19 +51,16 @@ func StartTestServer(t *testing.T) *echo.Echo { // Initialize database for tests db.TestSetupDB(t) - // Initialize backend with the global Dbpool + // Initialize backend with test database pool // This ensures BackendRepo() and BackendCtx() are available - if db.Dbpool != nil { - backend.InitBackend(db.Dbpool) + if db.TestDatabase != nil && db.TestDatabase.Pool != nil { + backend.InitBackend(db.TestDatabase.Pool) } // Create a Server instance and get its routes s := &Server{ - db: &db.Database{ - Pool: db.Dbpool, - Ctx: db.Ctx, - }, - tokenHandler: NewTokenHandler(db.Dbpool), + db: db.TestDatabase, + tokenHandler: NewTokenHandler(db.TestDatabase.Pool), } handler := s.RegisterRoutes()