126 lines
3.2 KiB
Go
126 lines
3.2 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/microsoft/go-mssqldb"
|
|
)
|
|
|
|
// MssqlDB is the shared SQL Server connection pool. ConnectMSSQL assigns it;
// it holds its zero value (nil) until something sets it.
var (
	MssqlDB *sql.DB
)
|
|
|
|
// envInt reads the environment variable name and parses it as a base-10
// integer, ignoring surrounding whitespace. It returns fallback when the
// variable is unset or blank, is not a valid integer, or is not strictly
// positive.
func envInt(name string, fallback int) int {
	if raw := strings.TrimSpace(os.Getenv(name)); raw != "" {
		if n, err := strconv.Atoi(raw); err == nil && n > 0 {
			return n
		}
	}
	return fallback
}
|
|
|
|
// ensureTimeoutValue returns current (trimmed of surrounding whitespace) when
// it parses as an integer that is at least desired. Otherwise it returns
// desired formatted in base 10. Blank, non-numeric and too-small values are
// all replaced by desired.
func ensureTimeoutValue(current string, desired int) string {
	trimmed := strings.TrimSpace(current)
	if n, err := strconv.Atoi(trimmed); err != nil || n < desired {
		return strconv.Itoa(desired)
	}
	return trimmed
}
|
|
|
|
func ensureMSSQLTimeouts(connString string, connectionTimeoutSec int, dialTimeoutSec int) string {
|
|
raw := strings.TrimSpace(connString)
|
|
if raw == "" {
|
|
return raw
|
|
}
|
|
|
|
if strings.HasPrefix(strings.ToLower(raw), "sqlserver://") {
|
|
u, err := url.Parse(raw)
|
|
if err != nil {
|
|
return raw
|
|
}
|
|
q := u.Query()
|
|
q.Set("connection timeout", ensureTimeoutValue(q.Get("connection timeout"), connectionTimeoutSec))
|
|
q.Set("dial timeout", ensureTimeoutValue(q.Get("dial timeout"), dialTimeoutSec))
|
|
u.RawQuery = q.Encode()
|
|
return u.String()
|
|
}
|
|
|
|
parts := strings.Split(raw, ";")
|
|
foundConnectionTimeout := false
|
|
foundDialTimeout := false
|
|
|
|
for i, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if part == "" {
|
|
continue
|
|
}
|
|
|
|
eq := strings.Index(part, "=")
|
|
if eq <= 0 {
|
|
continue
|
|
}
|
|
|
|
key := strings.ToLower(strings.TrimSpace(part[:eq]))
|
|
value := strings.TrimSpace(part[eq+1:])
|
|
|
|
switch key {
|
|
case "connection timeout":
|
|
foundConnectionTimeout = true
|
|
parts[i] = "connection timeout=" + ensureTimeoutValue(value, connectionTimeoutSec)
|
|
case "dial timeout":
|
|
foundDialTimeout = true
|
|
parts[i] = "dial timeout=" + ensureTimeoutValue(value, dialTimeoutSec)
|
|
}
|
|
}
|
|
|
|
if !foundConnectionTimeout {
|
|
parts = append(parts, "connection timeout="+strconv.Itoa(connectionTimeoutSec))
|
|
}
|
|
if !foundDialTimeout {
|
|
parts = append(parts, "dial timeout="+strconv.Itoa(dialTimeoutSec))
|
|
}
|
|
|
|
return strings.Join(parts, ";")
|
|
}
|
|
|
|
// ConnectMSSQL initializes the MSSQL connection from environment.
|
|
func ConnectMSSQL() error {
|
|
connString := strings.TrimSpace(os.Getenv("MSSQL_CONN"))
|
|
if connString == "" {
|
|
return fmt.Errorf("MSSQL_CONN tanimli degil")
|
|
}
|
|
|
|
connectionTimeoutSec := envInt("MSSQL_CONNECTION_TIMEOUT_SEC", 120)
|
|
dialTimeoutSec := envInt("MSSQL_DIAL_TIMEOUT_SEC", connectionTimeoutSec)
|
|
connString = ensureMSSQLTimeouts(connString, connectionTimeoutSec, dialTimeoutSec)
|
|
|
|
var err error
|
|
MssqlDB, err = sql.Open("sqlserver", connString)
|
|
if err != nil {
|
|
return fmt.Errorf("MSSQL baglanti hatasi: %w", err)
|
|
}
|
|
|
|
MssqlDB.SetMaxOpenConns(envInt("MSSQL_MAX_OPEN_CONNS", 40))
|
|
MssqlDB.SetMaxIdleConns(envInt("MSSQL_MAX_IDLE_CONNS", 40))
|
|
MssqlDB.SetConnMaxLifetime(time.Duration(envInt("MSSQL_CONN_MAX_LIFETIME_MIN", 30)) * time.Minute)
|
|
MssqlDB.SetConnMaxIdleTime(time.Duration(envInt("MSSQL_CONN_MAX_IDLE_MIN", 10)) * time.Minute)
|
|
|
|
if err = MssqlDB.Ping(); err != nil {
|
|
return fmt.Errorf("MSSQL erisilemiyor: %w", err)
|
|
}
|
|
|
|
fmt.Printf("MSSQL baglantisi basarili (connection timeout=%ds, dial timeout=%ds)\n", connectionTimeoutSec, dialTimeoutSec)
|
|
return nil
|
|
}
|
|
|
|
func GetDB() *sql.DB {
|
|
return MssqlDB
|
|
}
|