//go:build maria
// +build maria

package main

/*
#cgo pkg-config: libmariadb
#include <mysql.h>
#include <stdlib.h>

// Helper functions to access row data
static const char* get_field(MYSQL_ROW row, unsigned long idx) {
    return row[idx];
}

static unsigned long get_length(unsigned long *lengths, unsigned long idx) {
    return lengths[idx];
}

// Wrapper for mysql_optionsv since CGO has trouble with variadic functions
static int my_mysql_optionsv(MYSQL *mysql, enum mysql_option option, const void *arg) {
    return mysql_optionsv(mysql, option, arg);
}
*/
import "C"
import (
	"fmt"
	"sync"
	"unsafe"

	"jobarranger2/src/libs/golibs/database"
)

// mariaDatabase is the MariaDB implementation returned by NewDB. It keeps the
// configuration it was built from together with a fixed slice of connections
// that the pool-index methods (IsConnClosed, ConnReconnect, GetConnFromPool)
// address by position.
type mariaDatabase struct {
	// config is the configuration passed to NewDB; it is reused whenever a
	// connection has to be (re)opened.
	config *database.DBConfig
	// connPool holds the connections opened by NewDB.
	connPool []database.DBConnection
}

// NewDB builds a MariaDB-backed database.Database whose pool contains
// config.MaxConCount() connections, each opened eagerly through Reconnect.
//
// It returns database.ErrDBConfigNil when config is nil. When opening any
// connection fails, every connection opened so far is closed before the error
// is returned, so a failed call does not leave server connections behind.
func NewDB(config *database.DBConfig) (database.Database, error) {
	if config == nil {
		return nil, database.ErrDBConfigNil
	}

	var conns []database.DBConnection

	for i := 0; i < config.MaxConCount(); i++ {
		conn := &mariaDBConnection{}

		if err := conn.Reconnect(config); err != nil {
			// Do not leak the connections that were already established.
			for _, opened := range conns {
				opened.Close()
			}

			return nil, err
		}

		conns = append(conns, conn)
	}

	return &mariaDatabase{
		config:   config,
		connPool: conns,
	}, nil
}

// Close closes every connection held in the pool, in pool order.
func (db mariaDatabase) Close() {
	for _, pooled := range db.connPool {
		pooled.Close()
	}
}

// IsConnClosed reports whether the pooled connection at index is closed.
//
// It returns false when the pool is empty or when index does not address a
// pooled connection (index < 0 or index >= pool size), instead of panicking
// on the slice access.
func (db mariaDatabase) IsConnClosed(index int) bool {
	// An empty pool makes every index out of range, so this single check also
	// covers the "no connections" case.
	if index < 0 || index >= len(db.connPool) {
		return false
	}

	return db.connPool[index].IsClosed()
}

// ConnReconnect re-opens the pooled connection at index using the
// configuration the database was created with.
//
// It returns database.ErrNoDBConn when the pool is empty and
// database.ErrIndexOutOfBound when index is negative or not smaller than the
// pool size; otherwise it returns whatever the connection's Reconnect returns.
func (db mariaDatabase) ConnReconnect(index int) error {
	if db.GetPoolSize() <= 0 {
		return database.ErrNoDBConn
	}

	if index < 0 || index >= len(db.connPool) {
		return database.ErrIndexOutOfBound
	}

	return db.connPool[index].Reconnect(db.config)
}

// GetPoolSize returns the number of connections currently held in the pool.
func (db mariaDatabase) GetPoolSize() int {
	size := len(db.connPool)

	return size
}

// GetDBConfig returns the configuration this database was created with.
func (db mariaDatabase) GetDBConfig() *database.DBConfig {
	cfg := db.config

	return cfg
}

// GetConnFromPool starts a session on the pooled connection at index and
// returns that connection.
//
// It returns database.ErrNoDBConn when the pool is empty and
// database.ErrIndexOutOfBound when index is negative or not smaller than the
// pool size. Any error from StartSession (for example when the connection is
// already in use) is passed through unchanged.
func (db mariaDatabase) GetConnFromPool(index int) (database.DBConnection, error) {
	if db.GetPoolSize() <= 0 {
		return nil, database.ErrNoDBConn
	}

	if index < 0 || index >= len(db.connPool) {
		return nil, database.ErrIndexOutOfBound
	}

	return db.connPool[index].StartSession()
}

// GetConn returns a connection for the caller to use.
//
// With a non-empty pool the request is delegated to database.GetConnFromPool,
// which (per the original author's note) blocks indefinitely until a
// connection becomes available. With an empty pool a brand-new standalone
// connection is opened from the stored configuration.
//
// NOTE(review): the standalone path does not call StartSession, so the
// returned connection has hasSession == false — confirm that callers start a
// session themselves before running queries on it.
func (db *mariaDatabase) GetConn() (database.DBConnection, error) {
	if len(db.connPool) == 0 {
		standalone := &mariaDBConnection{}
		if err := standalone.Reconnect(db.config); err != nil {
			return nil, err
		}

		return standalone, nil
	}

	return database.GetConnFromPool(db)
}

// mariaDBConnection wraps a single MariaDB client handle together with the
// bookkeeping the database.DBConnection methods in this file rely on.
type mariaDBConnection struct {
	// conn is the client handle installed by Reconnect; Close sets it back to
	// nil, which is what IsClosed checks.
	conn *C.MYSQL
	// isInTransaction is set by Begin and cleared by Commit and Rollback.
	isInTransaction bool
	// dbErrCode and dbErrMessage hold the error number (formatted as a decimal
	// string) and message recorded by the most recent failed client call.
	dbErrCode    string
	dbErrMessage string
	// hasSession is set by StartSession and cleared by EndSession; execute and
	// Select refuse to run while it is false.
	hasSession bool
	// mu is acquired with TryLock in StartSession and released in EndSession.
	mu sync.Mutex
}

// newMysqlErr builds an error whose text is customMessage followed by the
// client library's error message and error code.
func newMysqlErr(customMessage, errorMessage, errorCode string) error {
	const layout = "%s, mysql_err_msg: %s, mysql_err_code: %s"

	return fmt.Errorf(layout, customMessage, errorMessage, errorCode)
}

// IsClosed reports whether the connection currently has no client handle.
func (dbConn *mariaDBConnection) IsClosed() bool {
	closed := dbConn.conn == nil

	return closed
}

// StartSession claims the connection for the caller without blocking.
//
// When the connection's mutex can be acquired the session flag is set and the
// connection itself is returned; otherwise database.ErrDBConnLocked is
// returned. The mutex stays held until EndSession releases it.
func (dbConn *mariaDBConnection) StartSession() (database.DBConnection, error) {
	if dbConn.mu.TryLock() {
		dbConn.hasSession = true

		return dbConn, nil
	}

	return nil, database.ErrDBConnLocked
}

// EndSession releases the connection so that another caller can start a
// session on it.
//
// If a client handle is present the connection is reset with
// mysql_reset_connection. A successful reset discards the server-side session
// state, including any open transaction, so the local transaction flag is
// cleared as well; otherwise a session that ended without Commit/Rollback
// would make the next session's Begin fail with ErrDuplicatedDBTransaction.
// When the reset fails the error code and message are recorded, an error is
// returned and the transaction flag is left untouched because the server-side
// state is unknown.
//
// In every case the session flag is cleared and the mutex is released.
func (dbConn *mariaDBConnection) EndSession() error {
	dbConn.mu.TryLock() // to prevent panic error on unlocking unlocked mutex
	defer func() {
		dbConn.hasSession = false // prevents calling db function after EndSession
		dbConn.mu.Unlock()
	}()

	// Without a handle there is no server-side state, hence no transaction.
	if dbConn.conn == nil {
		dbConn.isInTransaction = false

		return nil
	}

	if C.mysql_reset_connection(dbConn.conn) != 0 {
		dbConn.dbErrCode = fmt.Sprintf("%d", C.mysql_errno(dbConn.conn))
		dbConn.dbErrMessage = C.GoString(C.mysql_error(dbConn.conn))

		return newMysqlErr("mysql_reset_connection failed", dbConn.DBErrMessage(), dbConn.DBErrCode())
	}

	// The reset rolled back whatever transaction was still open.
	dbConn.isInTransaction = false

	return nil
}

// Reconnect opens a new connection to the server described by config and
// installs it as this connection's client handle. It can be used in DB
// connection retry cases.
//
// TLS handling, applied only when config.TLSMode() is non-empty:
//   - mode "required" sets MYSQL_OPT_SSL_ENFORCE;
//   - any other mode sets MYSQL_OPT_SSL_VERIFY_SERVER_CERT;
//   - the CA file, certificate file, key file and cipher are each set when
//     their configured value is non-empty.
//
// Errors:
//   - database.ErrDBConfigNil when config is nil;
//   - an error built by newMysqlErr when the handle cannot be allocated, when
//     an option cannot be set, or when mysql_real_connect fails (in the last
//     case the error code and message are also recorded on the connection).
//
// On every failure path the freshly allocated handle is closed and the
// handle previously installed on dbConn (if any) is left untouched. On
// success the previous handle is closed before being replaced and the
// transaction flag is cleared, since a new connection has no open
// transaction.
func (dbConn *mariaDBConnection) Reconnect(config *database.DBConfig) error {
	if config == nil {
		return database.ErrDBConfigNil
	}

	conn := C.mysql_init(nil)
	if conn == nil {
		return newMysqlErr("mysql_init failed", "", "")
	}

	// optionFailed closes the not-yet-connected handle so that no option
	// failure path can leak it, then builds the error for the caller.
	optionFailed := func(message string) error {
		C.mysql_close(conn)

		return newMysqlErr(message, "", "")
	}

	// Convert Go strings to C strings
	host := C.CString(config.Hostname())
	user := C.CString(config.User())
	passwd := C.CString(config.Password())
	dbName := C.CString(config.DBName())
	dbSocket := C.CString(config.MysqlDBSocket())
	dbPort := C.uint(config.Port())

	// clean C pointers
	defer func() {
		C.free(unsafe.Pointer(host))
		C.free(unsafe.Pointer(user))
		C.free(unsafe.Pointer(passwd))
		C.free(unsafe.Pointer(dbName))
		C.free(unsafe.Pointer(dbSocket))
	}()

	if tlsMode := config.TLSMode(); tlsMode != "" {
		switch tlsMode {
		case "required":
			enforceTLS := C.my_bool(1)
			if C.my_mysql_optionsv(conn, C.MYSQL_OPT_SSL_ENFORCE, unsafe.Pointer(&enforceTLS)) != 0 {
				return optionFailed("C.MYSQL_OPT_SSL_ENFORCE failed")
			}
		default:
			verify := C.my_bool(1)
			if C.my_mysql_optionsv(conn, C.MYSQL_OPT_SSL_VERIFY_SERVER_CERT, unsafe.Pointer(&verify)) != 0 {
				return optionFailed("C.MYSQL_OPT_SSL_VERIFY_SERVER_CERT failed")
			}
		}

		// The C copies of the option strings are released when Reconnect
		// returns, i.e. after the connection attempt below.
		if path := config.TLSCaFile(); path != "" {
			caFile := C.CString(path)
			defer C.free(unsafe.Pointer(caFile))

			if C.my_mysql_optionsv(conn, C.MYSQL_OPT_SSL_CA, unsafe.Pointer(caFile)) != 0 {
				return optionFailed("MYSQL_OPT_SSL_CA failed")
			}
		}
		if path := config.TLSCertFile(); path != "" {
			certFile := C.CString(path)
			defer C.free(unsafe.Pointer(certFile))

			if C.my_mysql_optionsv(conn, C.MYSQL_OPT_SSL_CERT, unsafe.Pointer(certFile)) != 0 {
				return optionFailed("MYSQL_OPT_SSL_CERT failed")
			}
		}
		if path := config.TLSKeyFile(); path != "" {
			keyFile := C.CString(path)
			defer C.free(unsafe.Pointer(keyFile))

			if C.my_mysql_optionsv(conn, C.MYSQL_OPT_SSL_KEY, unsafe.Pointer(keyFile)) != 0 {
				return optionFailed("MYSQL_OPT_SSL_KEY failed")
			}
		}
		if cipher := config.TLSCipher(); cipher != "" {
			cCipher := C.CString(cipher)
			defer C.free(unsafe.Pointer(cCipher))

			if C.mysql_options(conn, C.MYSQL_OPT_SSL_CIPHER, unsafe.Pointer(cCipher)) != 0 {
				return optionFailed("MYSQL_OPT_SSL_CIPHER failed")
			}
		}
	}

	// Establish connection
	if C.mysql_real_connect(conn, host, user, passwd, dbName, dbPort, dbSocket, C.CLIENT_MULTI_STATEMENTS) == nil {
		dbConn.dbErrCode = fmt.Sprintf("%d", C.mysql_errno(conn))
		dbConn.dbErrMessage = C.GoString(C.mysql_error(conn))

		C.mysql_close(conn)
		return newMysqlErr("mysql_real_connect failed", dbConn.DBErrMessage(), dbConn.dbErrCode)
	}

	// Release the handle that is being replaced instead of dropping it.
	if dbConn.conn != nil {
		C.mysql_close(dbConn.conn)
	}

	dbConn.conn = conn
	dbConn.isInTransaction = false

	return nil
}

// Close closes the client handle, if there is one, and forgets it so that
// IsClosed reports true afterwards. Calling it on an already closed
// connection does nothing.
func (dbConn *mariaDBConnection) Close() {
	if dbConn.conn == nil {
		return
	}

	C.mysql_close(dbConn.conn)
	dbConn.conn = nil
}

// Begin starts a transaction by running "begin;".
//
// It returns database.ErrDuplicatedDBTransaction when a transaction is
// already marked as open, or the error from running the statement. The
// transaction flag is set only when the statement succeeds.
func (dbConn *mariaDBConnection) Begin() error {
	if dbConn.isInTransaction {
		return database.ErrDuplicatedDBTransaction
	}

	_, err := dbConn.execute("begin;")
	if err == nil {
		dbConn.isInTransaction = true
	}

	return err
}

// Commit ends the current transaction by running "commit;".
//
// It returns database.ErrNoDBTransaction when no transaction is marked as
// open, or the error from running the statement. The transaction flag is
// cleared only when the statement succeeds.
func (dbConn *mariaDBConnection) Commit() error {
	if !dbConn.isInTransaction {
		return database.ErrNoDBTransaction
	}

	_, err := dbConn.execute("commit;")
	if err == nil {
		dbConn.isInTransaction = false
	}

	return err
}

// Rollback abandons the current transaction by running "rollback;".
//
// It returns database.ErrNoDBTransaction when no transaction is marked as
// open, or the error from running the statement. The transaction flag is
// cleared only when the statement succeeds.
func (dbConn *mariaDBConnection) Rollback() error {
	if !dbConn.isInTransaction {
		return database.ErrNoDBTransaction
	}

	_, err := dbConn.execute("rollback;")
	if err == nil {
		dbConn.isInTransaction = false
	}

	return err
}

// IsAlive reports whether the connection has a client handle and that handle
// answers mysql_ping successfully (mysql_ping returns 0 on success).
func (dbConn *mariaDBConnection) IsAlive() bool {
	return dbConn.conn != nil && C.mysql_ping(dbConn.conn) == 0
}

// execute formats the statement text with fmt.Sprintf, sends it to the server
// and drains every result it produces.
//
// For each result: if it carries columns it is stored and freed immediately;
// otherwise its affected-row count is added to the running total. Draining
// continues until mysql_next_result reports that no further result exists.
//
// It returns database.ErrNoDBSession when no session is active,
// database.ErrDBConnNil when there is no client handle, and an error built by
// newMysqlErr when the query or the advance to the next result fails (the
// error code and message are recorded on the connection; for the latter the
// rows counted so far are returned alongside the error).
func (dbConn *mariaDBConnection) execute(format string, arg ...any) (int, error) {
	switch {
	case !dbConn.hasSession:
		// prevent calling functions after EndSession()
		return 0, database.ErrNoDBSession
	case dbConn.conn == nil:
		return 0, database.ErrDBConnNil
	}

	query := C.CString(fmt.Sprintf(format, arg...))
	defer C.free(unsafe.Pointer(query))

	if C.mysql_query(dbConn.conn, query) != 0 {
		dbConn.dbErrCode = fmt.Sprintf("%d", C.mysql_errno(dbConn.conn))
		dbConn.dbErrMessage = C.GoString(C.mysql_error(dbConn.conn))

		return 0, newMysqlErr("mysql_query failed", dbConn.dbErrMessage, dbConn.dbErrCode)
	}

	total := 0
	for {
		if C.mysql_field_count(dbConn.conn) == 0 {
			total += int(C.mysql_affected_rows(dbConn.conn))
		} else if stored := C.mysql_store_result(dbConn.conn); stored != nil {
			C.mysql_free_result(stored)
		}

		next := C.mysql_next_result(dbConn.conn)

		// positive: error on getting next result
		if next > 0 {
			dbConn.dbErrCode = fmt.Sprintf("%d", C.mysql_errno(dbConn.conn))
			dbConn.dbErrMessage = C.GoString(C.mysql_error(dbConn.conn))

			return total, newMysqlErr("mysql_next_result failed", dbConn.dbErrMessage, dbConn.dbErrCode)
		}

		// negative: no more result; zero: another result is ready
		if next < 0 {
			return total, nil
		}
	}
}

// Execute runs a formatted statement inside the current transaction and
// returns the number of affected rows reported by execute.
//
// It returns database.ErrNoDBTransaction when no transaction is marked as
// open; statements are never run outside a transaction through this method.
func (dbConn *mariaDBConnection) Execute(format string, arg ...any) (int, error) {
	if dbConn.isInTransaction {
		return dbConn.execute(format, arg...)
	}

	return 0, database.ErrNoDBTransaction
}

// Select formats the query text with fmt.Sprintf, runs it and returns the
// stored result set wrapped in a MysqlDBresult, together with the column
// names (in field order) and the total row count.
//
// It returns database.ErrNoDBSession when no session is active,
// database.ErrDBConnNil when there is no client handle, and an error built by
// newMysqlErr when mysql_query fails or mysql_store_result yields no result
// (the error code and message are recorded on the connection in both cases).
// The returned result holds C memory until its Free method is called.
func (dbConn *mariaDBConnection) Select(format string, arg ...any) (database.DBresult, error) {
	switch {
	case !dbConn.hasSession:
		// prevent calling functions after EndSession()
		return nil, database.ErrNoDBSession
	case dbConn.conn == nil:
		return nil, database.ErrDBConnNil
	}

	query := C.CString(fmt.Sprintf(format, arg...))
	defer C.free(unsafe.Pointer(query))

	if C.mysql_query(dbConn.conn, query) != 0 {
		dbConn.dbErrCode = fmt.Sprintf("%d", C.mysql_errno(dbConn.conn))
		dbConn.dbErrMessage = C.GoString(C.mysql_error(dbConn.conn))

		return nil, newMysqlErr("mysql_query failed", dbConn.dbErrMessage, dbConn.dbErrCode)
	}

	stored := C.mysql_store_result(dbConn.conn)
	if stored == nil {
		dbConn.dbErrCode = fmt.Sprintf("%d", C.mysql_errno(dbConn.conn))
		dbConn.dbErrMessage = C.GoString(C.mysql_error(dbConn.conn))

		return nil, newMysqlErr("mysql_store_result failed", dbConn.dbErrMessage, dbConn.dbErrCode)
	}

	// One name per field, looked up by field index.
	names := make([]string, int(C.mysql_num_fields(stored)))
	for i := range names {
		field := C.mysql_fetch_field_direct(stored, C.uint(i))
		names[i] = C.GoString(field.name)
	}

	rows := uint64(C.mysql_num_rows(stored))

	return &MysqlDBresult{
		mysqlResult:  stored,
		columnNames:  names,
		totalRowNo:   rows,
		currentRowNo: 0,
	}, nil
}

// DBErrCode returns the error number, formatted as a decimal string, that was
// recorded by the most recent failed client call on this connection. It is
// empty when no failure has been recorded.
func (dbConn *mariaDBConnection) DBErrCode() string {
	return dbConn.dbErrCode
}

// DBErrMessage returns the error message recorded by the most recent failed
// client call on this connection. It is empty when no failure has been
// recorded.
func (dbConn *mariaDBConnection) DBErrMessage() string {
	return dbConn.dbErrMessage
}

// MysqlDBresult is the result set returned by Select. It wraps a stored
// client-library result and must be released with Free.
type MysqlDBresult struct {
	// mysqlResult is the stored result handle; Free sets it to nil.
	mysqlResult *C.MYSQL_RES
	// totalRowNo is the row count taken from mysql_num_rows when the result
	// was created.
	totalRowNo uint64
	// columnNames holds the column names in field order; Fetch uses them as
	// map keys.
	columnNames []string
	// currentRowNo counts the rows handed out by Fetch so far.
	currentRowNo uint64
	// mu serializes Fetch, HasNextRow and Free.
	mu sync.Mutex
}

// Fetch returns the next row of the result set as a map from column name to
// column value. A NULL column is returned as the empty string.
//
// It returns database.ErrNoTableRows when there is no further row, and also
// when the result has already been released with Free, so that a released
// (nil) handle is never passed to the client library.
func (dbResult *MysqlDBresult) Fetch() (map[string]string, error) {
	dbResult.mu.Lock()
	defer dbResult.mu.Unlock()

	// Free sets mysqlResult to nil; treat a released result as exhausted.
	if dbResult.mysqlResult == nil {
		return nil, database.ErrNoTableRows
	}

	row := C.mysql_fetch_row(dbResult.mysqlResult)
	if row == nil {
		return nil, database.ErrNoTableRows // No more rows
	}

	// get column values' length
	lengths := C.mysql_fetch_lengths(dbResult.mysqlResult)
	if lengths == nil {
		// mysql_fetch_lengths is nil only when there is no rows to fetch
		return nil, database.ErrNoTableRows
	}

	rowData := make(map[string]string, len(dbResult.columnNames))
	for idx, colName := range dbResult.columnNames {
		cField := C.get_field(row, C.ulong(idx))
		cLength := C.get_length(lengths, C.ulong(idx))

		if cField == nil {
			rowData[colName] = ""
		} else {
			// Copy exactly cLength bytes so values containing NUL bytes are
			// not truncated.
			rowData[colName] = C.GoStringN(cField, C.int(cLength))
		}
	}

	// Advance to next row
	dbResult.currentRowNo++

	return rowData, nil
}

// HasNextRow reports whether fewer rows have been fetched so far than the
// result set contained when it was created.
func (dbResult *MysqlDBresult) HasNextRow() bool {
	dbResult.mu.Lock()
	remaining := dbResult.currentRowNo < dbResult.totalRowNo
	dbResult.mu.Unlock()

	return remaining
}

// Free releases the stored result held by the client library and forgets the
// handle. Calling it again after the result has been released does nothing.
func (dbResult *MysqlDBresult) Free() {
	dbResult.mu.Lock()
	defer dbResult.mu.Unlock()

	if dbResult.mysqlResult == nil {
		return
	}

	C.mysql_free_result(dbResult.mysqlResult)
	dbResult.mysqlResult = nil
}
