package postgres

// From https://gitlab.ozon.dev/go/classroom-18/students/week-4-workshop/-/blob/master/internal/infra/postgres/tx.go

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/opentracing/opentracing-go"
)

// Tx is a database transaction. It is a named interface type with the same
// method set as pgx.Tx.
type Tx pgx.Tx

// txKey is the context key under which the active transaction is stored.
// Being an unexported struct type, it cannot collide with context keys
// defined in other packages.
type txKey struct{}

// ctxWithTx returns a copy of ctx that carries tx.
func ctxWithTx(ctx context.Context, tx pgx.Tx) context.Context {
	return context.WithValue(ctx, txKey{}, tx)
}

// TxFromCtx returns the transaction stored in ctx by one of the TxManager
// methods and reports whether one was present. Code running inside the fn
// passed to a TxManager method uses it to issue queries in that transaction.
func TxFromCtx(ctx context.Context) (pgx.Tx, bool) {
	tx, ok := ctx.Value(txKey{}).(pgx.Tx)
	return tx, ok
}

// TxManager runs functions inside database transactions. Write* methods start
// the transaction on the write pool, Read* methods on the read pool.
type TxManager struct {
	write *pgxpool.Pool
	read  *pgxpool.Pool
}

// NewTxManager returns a TxManager that starts Write* transactions on write
// and Read* transactions on read.
func NewTxManager(write, read *pgxpool.Pool) *TxManager {
	return &TxManager{
		write: write,
		read:  read,
	}
}

// WriteWithTransaction runs fn in a transaction on the write pool using the
// default transaction options (the server's default isolation level).
// fn receives a context carrying the transaction (see TxFromCtx). The
// transaction is committed if fn returns nil and rolled back if fn returns an
// error or panics. The error from starting the transaction, from fn, or from
// Commit is returned.
func (m *TxManager) WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
	return m.withTx(ctx, m.write, pgx.TxOptions{}, fn)
}

// ReadWithTransaction is WriteWithTransaction, but the transaction is started
// on the read pool.
// NOTE: the transaction is not opened read-only (pgx.TxOptions.AccessMode is
// left unset); any read-only guarantee comes from the pool itself.
func (m *TxManager) ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
	return m.withTx(ctx, m.read, pgx.TxOptions{}, fn)
}

// WriteWithRepeatableRead runs fn in a REPEATABLE READ transaction on the
// write pool. Commit/rollback semantics match WriteWithTransaction.
func (m *TxManager) WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) error {
	return m.withTx(ctx, m.write, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, fn)
}

// ReadWithRepeatableRead runs fn in a REPEATABLE READ transaction on the read
// pool. Commit/rollback semantics match WriteWithTransaction.
func (m *TxManager) ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) error {
	return m.withTx(ctx, m.read, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, fn)
}

// withTx runs fn inside a transaction started on pool with the given options.
func (m *TxManager) withTx(ctx context.Context, pool *pgxpool.Pool, options pgx.TxOptions, fn func(ctx context.Context) error) (err error) { var span opentracing.Span span, ctx = opentracing.StartSpanFromContext(ctx, "Transaction") defer span.Finish() tx, err := pool.BeginTx(ctx, options) if err != nil { return } ctx = ctxWithTx(ctx, tx) defer func() { if p := recover(); p != nil { // a panic occurred, rollback and repanic _ = tx.Rollback(ctx) panic(p) } else if err != nil { // something went wrong, rollback _ = tx.Rollback(ctx) } else { // all good, commit err = tx.Commit(ctx) } }() err = fn(ctx) return }