package db import ( "context" "database/sql" "fmt" "log" "os" "sync" "testing" ) var ( testDBSetupOnce sync.Once testDBHost string testDBPort string testDBUser string testDBPassword string testDBName string // TestDatabase is the database instance for tests TestDatabase *Database ) // TestSetupDB initializes the test database using existing functions // It creates the database if it doesn't exist and runs migrations // Uses sync.Once to ensure it only runs once across all tests func TestSetupDB(t *testing.T) { host := os.Getenv("DB_HOST") port := os.Getenv("DB_PORT") user := os.Getenv("DB_USERNAME") password := os.Getenv("DB_PASSWORD") dbname := os.Getenv("DB_NAME") if host == "" || port == "" || user == "" || password == "" || dbname == "" { t.Skip("Test database environment variables not set") } // Store for TestTearDownDB testDBHost = host testDBPort = port testDBUser = user testDBPassword = password testDBName = dbname // Only run setup once testDBSetupOnce.Do(func() { // Create the database first (testuser is a superuser in the container) createTestDatabase(host, port, dbname, user, password) // Create database instance and run migrations var err error TestDatabase, err = NewDatabase(host, port, user, password, dbname) if err != nil { t.Fatalf("Failed to initialize test database: %v", err) } // Run migrations if err := TestDatabase.RunMigrations(); err != nil { t.Fatalf("Failed to run migrations: %v", err) } }) } // createTestDatabase creates the test database // In the test container, POSTGRES_USER is created as a superuser func createTestDatabase(host, port, dbname, user, password string) { // Connect to the postgres database to create new database connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=postgres sslmode=disable", host, port, user, password) db, err := sql.Open("postgres", connStr) if err != nil { log.Println("Warning: Could not connect to create test database:", err) return } defer db.Close() // Check if database exists var dbExists int err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = $1", dbname).Scan(&dbExists) if err != nil && err != sql.ErrNoRows { log.Println("Warning: Could not check if database exists:", err) return } if dbExists == 0 { // Create database _, err = db.Exec("CREATE DATABASE " + dbname) if err != nil { log.Println("Warning: Could not create database:", err) return } log.Println("Created test database:", dbname) } } // TestTearDownDB closes the test database connection // Note: We don't actually close the pool between tests to avoid // "closed pool" errors when tests run sequentially func TestTearDownDB(t *testing.T) { // CloseDb() // Disabled to prevent pool closure between sequential tests if TestDatabase != nil { TestDatabase.Close() TestDatabase = nil } } // TestClearDatabase clears all data from the test database // Useful for running tests with a clean slate func TestClearDatabase(t *testing.T) { if TestDatabase == nil || TestDatabase.Pool == nil { t.Skip("Database not initialized") } // Clear all tables in reverse order to respect foreign keys // Note: This assumes the tables exist and have the expected structure tables := []string{ "song_list", "song", "game", } ctx := context.Background() for _, table := range tables { _, err := TestDatabase.Pool.Exec(ctx, "TRUNCATE TABLE "+table+" CASCADE") if err != nil { t.Logf("Failed to truncate table %s: %v", table, err) } } // Reset sequences _, err := TestDatabase.Pool.Exec(ctx, "SELECT setval('game_id_seq', 1, false)") if err != nil { t.Logf("Failed to reset game_id_seq: %v", err) } }