151 lines
4.3 KiB
Go
151 lines
4.3 KiB
Go
package db
|
|
|
|
import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"
	"strings"
	"sync"
	"testing"
)
|
|
|
|
var (
|
|
testDBSetupOnce sync.Once
|
|
testDBHost string
|
|
testDBPort string
|
|
testDBUser string
|
|
testDBPassword string
|
|
testDBName string
|
|
// TestDatabase is the database instance for tests
|
|
TestDatabase *Database
|
|
)
|
|
|
|
// TestSetupDB initializes the test database using existing functions
|
|
// It creates the database if it doesn't exist and runs migrations
|
|
// Uses sync.Once to ensure it only runs once across all tests
|
|
func TestSetupDB(t *testing.T) {
|
|
host := os.Getenv("DB_HOST")
|
|
port := os.Getenv("DB_PORT")
|
|
user := os.Getenv("DB_USERNAME")
|
|
password := os.Getenv("DB_PASSWORD")
|
|
dbname := os.Getenv("DB_NAME")
|
|
|
|
if host == "" || port == "" || user == "" || password == "" || dbname == "" {
|
|
t.Skip("Test database environment variables not set")
|
|
}
|
|
|
|
// Store for TestTearDownDB
|
|
testDBHost = host
|
|
testDBPort = port
|
|
testDBUser = user
|
|
testDBPassword = password
|
|
testDBName = dbname
|
|
|
|
// Only run setup once
|
|
testDBSetupOnce.Do(func() {
|
|
// Create the database first (testuser is a superuser in the container)
|
|
createTestDatabase(host, port, dbname, user, password)
|
|
|
|
// Create database instance and run migrations
|
|
var err error
|
|
TestDatabase, err = NewDatabase(host, port, user, password, dbname)
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize test database: %v", err)
|
|
}
|
|
|
|
// Clean up any existing schema to ensure clean state
|
|
ctx := context.Background()
|
|
_, err = TestDatabase.Pool.Exec(ctx, "DROP SCHEMA IF EXISTS public CASCADE; CREATE SCHEMA public;")
|
|
if err != nil {
|
|
t.Logf("Warning: Could not clean schema: %v", err)
|
|
// Continue anyway, migrations might still work
|
|
}
|
|
|
|
// Run migrations
|
|
if err := TestDatabase.RunMigrations(); err != nil {
|
|
// Clean up on failure to prevent nil pointer issues in other tests
|
|
TestDatabase.Close()
|
|
TestDatabase = nil
|
|
t.Fatalf("Failed to run migrations: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// createTestDatabase creates the database named dbname if it does not already
// exist, connecting through the server's "postgres" database with the given
// credentials.
//
// In the test container, POSTGRES_USER is created as a superuser, so the
// connecting role is able to create databases.
//
// It is best-effort: every failure is logged as a warning and the function
// returns without error, leaving the caller's subsequent connection attempt
// to surface a hard failure.
func createTestDatabase(host, port, dbname, user, password string) {
	// Connect to the postgres database to create new database
	connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=postgres sslmode=disable", host, port, user, password)
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		log.Println("Warning: Could not connect to create test database:", err)
		return
	}
	defer db.Close()

	// Check if database exists; sql.ErrNoRows just means it doesn't, leaving
	// dbExists at 0.
	var dbExists int
	err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = $1", dbname).Scan(&dbExists)
	if err != nil && err != sql.ErrNoRows {
		log.Println("Warning: Could not check if database exists:", err)
		return
	}
	if dbExists != 0 {
		return
	}

	// The name must be quoted: the lookup above matches datname exactly, so an
	// unquoted CREATE would case-fold mixed-case names (and reject names such
	// as "my-db"), creating a different database than the one checked for.
	if _, err := db.Exec("CREATE DATABASE " + pgQuoteIdent(dbname)); err != nil {
		log.Println("Warning: Could not create database:", err)
		return
	}
	log.Println("Created test database:", dbname)
}

// pgQuoteIdent returns name as a PostgreSQL quoted identifier: wrapped in
// double quotes, with each embedded double quote doubled.
func pgQuoteIdent(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
|
|
|
|
// TestTearDownDB is the teardown counterpart of TestSetupDB and is
// intentionally a no-op.
//
// The shared pool is deliberately not closed between tests, and TestDatabase
// is deliberately not reset to nil. This avoids "closed pool" errors when
// tests run sequentially against the same instance, and lets them reuse it.
func TestTearDownDB(t *testing.T) {
	// Intentionally empty; see the doc comment for why nothing is closed.
}
|
|
|
|
// TestClearDatabase clears all data from the test database
|
|
// Useful for running tests with a clean slate
|
|
func TestClearDatabase(t *testing.T) {
|
|
if TestDatabase == nil || TestDatabase.Pool == nil {
|
|
t.Skip("Database not initialized")
|
|
}
|
|
|
|
// Clear all tables in reverse order to respect foreign keys
|
|
// Note: This assumes the tables exist and have the expected structure
|
|
// After migration 000005, game table was renamed to soundtrack
|
|
tables := []string{
|
|
"song_list",
|
|
"song",
|
|
"soundtrack",
|
|
"vgmq",
|
|
"sessions",
|
|
}
|
|
|
|
ctx := context.Background()
|
|
for _, table := range tables {
|
|
_, err := TestDatabase.Pool.Exec(ctx, "TRUNCATE TABLE "+table+" CASCADE")
|
|
if err != nil {
|
|
t.Logf("Failed to truncate table %s: %v", table, err)
|
|
}
|
|
}
|
|
|
|
// Reset sequences (renamed from game_id_seq to soundtrack_id_seq in migration 000005)
|
|
var seqErr error
|
|
_, seqErr = TestDatabase.Pool.Exec(ctx, "SELECT setval('soundtrack_id_seq', 1, false)")
|
|
if seqErr != nil {
|
|
t.Logf("Failed to reset soundtrack_id_seq: %v", seqErr)
|
|
}
|
|
}
|