mirror of
https://github.com/3ybactuk/marketplace-go-service-project.git
synced 2025-10-30 14:03:45 +03:00
[hw-4] add postgres db
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/grpc"
|
||||
@@ -16,11 +17,12 @@ import (
|
||||
"google.golang.org/grpc/reflection"
|
||||
|
||||
"route256/loms/internal/app/server"
|
||||
ordersRepository "route256/loms/internal/domain/repository/orders"
|
||||
stocksRepository "route256/loms/internal/domain/repository/stocks"
|
||||
ordersRepository "route256/loms/internal/domain/repository/orders/sqlc"
|
||||
stocksRepository "route256/loms/internal/domain/repository/stocks/sqlc"
|
||||
"route256/loms/internal/domain/service"
|
||||
"route256/loms/internal/infra/config"
|
||||
mw "route256/loms/internal/infra/grpc/middleware"
|
||||
"route256/loms/internal/infra/postgres"
|
||||
|
||||
pb "route256/pkg/api/loms/v1"
|
||||
)
|
||||
@@ -50,13 +52,16 @@ func NewApp(configPath string) (*App, error) {
|
||||
|
||||
log.WithLevel(zerolog.GlobalLevel()).Msgf("using logging level=`%s`", zerolog.GlobalLevel().String())
|
||||
|
||||
stockRepo, err := stocksRepository.NewInMemoryRepository(100)
|
||||
masterPool, replicaPool, err := getPostgresPools(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stocksRepository.NewInMemoryRepository: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
orderRepo := ordersRepository.NewInMemoryRepository(100)
|
||||
service := service.NewLomsService(orderRepo, stockRepo)
|
||||
stockRepo := stocksRepository.NewStockRepository(masterPool, replicaPool)
|
||||
orderRepo := ordersRepository.NewOrderRepository(masterPool)
|
||||
txManager := postgres.NewTxManager(masterPool, replicaPool)
|
||||
|
||||
service := service.NewLomsService(orderRepo, stockRepo, txManager)
|
||||
controller := server.NewServer(service)
|
||||
|
||||
app := &App{
|
||||
@@ -117,3 +122,34 @@ func (app *App) ListenAndServe() error {
|
||||
|
||||
return gwServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func getPostgresPools(c *config.Config) (masterPool, replicaPool *pgxpool.Pool, err error) {
|
||||
masterConn := fmt.Sprintf(
|
||||
"postgresql://%s:%s@%s:%s/%s?sslmode=disable",
|
||||
c.DatabaseMaster.User,
|
||||
c.DatabaseMaster.Password,
|
||||
c.DatabaseMaster.Host,
|
||||
c.DatabaseMaster.Port,
|
||||
c.DatabaseMaster.DBName,
|
||||
)
|
||||
|
||||
replicaConn := fmt.Sprintf(
|
||||
"postgresql://%s:%s@%s:%s/%s?sslmode=disable",
|
||||
c.DatabaseReplica.User,
|
||||
c.DatabaseReplica.Password,
|
||||
c.DatabaseReplica.Host,
|
||||
c.DatabaseReplica.Port,
|
||||
c.DatabaseReplica.DBName,
|
||||
)
|
||||
|
||||
pools, err := postgres.NewPools(context.Background(), masterConn, replicaConn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(pools) != 2 {
|
||||
return nil, nil, fmt.Errorf("got wrong number of pools when establishing postgres connection")
|
||||
}
|
||||
|
||||
return pools[0], pools[1], nil
|
||||
}
|
||||
|
||||
32
loms/internal/domain/repository/orders/sqlc/db.go
Normal file
32
loms/internal/domain/repository/orders/sqlc/db.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||
}
|
||||
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
5
loms/internal/domain/repository/orders/sqlc/models.go
Normal file
5
loms/internal/domain/repository/orders/sqlc/models.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package sqlc
|
||||
18
loms/internal/domain/repository/orders/sqlc/querier.go
Normal file
18
loms/internal/domain/repository/orders/sqlc/querier.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
OrderAddItem(ctx context.Context, arg *OrderAddItemParams) error
|
||||
OrderCreate(ctx context.Context, arg *OrderCreateParams) (int64, error)
|
||||
OrderGetByID(ctx context.Context, id int64) ([]*OrderGetByIDRow, error)
|
||||
OrderSetStatus(ctx context.Context, arg *OrderSetStatusParams) (int64, error)
|
||||
}
|
||||
|
||||
var _ Querier = (*Queries)(nil)
|
||||
25
loms/internal/domain/repository/orders/sqlc/query.sql
Normal file
25
loms/internal/domain/repository/orders/sqlc/query.sql
Normal file
@@ -0,0 +1,25 @@
|
||||
-- name: OrderCreate :one
|
||||
insert into orders (status_name, user_id)
|
||||
values ($1, $2)
|
||||
returning id;
|
||||
|
||||
-- name: OrderAddItem :exec
|
||||
insert into order_items (order_id, sku, count)
|
||||
select $1, $2, $3;
|
||||
|
||||
-- name: OrderSetStatus :execrows
|
||||
update orders
|
||||
set status_name = $2
|
||||
where id = $1;
|
||||
|
||||
-- name: OrderGetByID :many
|
||||
select
|
||||
o.id as order_id,
|
||||
o.status_name as status,
|
||||
o.user_id,
|
||||
oi.sku,
|
||||
oi.count
|
||||
from orders o
|
||||
left join order_items oi on oi.order_id = o.id
|
||||
where o.id = $1
|
||||
order by oi.id;
|
||||
110
loms/internal/domain/repository/orders/sqlc/query.sql.go
Normal file
110
loms/internal/domain/repository/orders/sqlc/query.sql.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: query.sql
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const orderAddItem = `-- name: OrderAddItem :exec
|
||||
insert into order_items (order_id, sku, count)
|
||||
select $1, $2, $3
|
||||
`
|
||||
|
||||
type OrderAddItemParams struct {
|
||||
OrderID int64
|
||||
Sku int64
|
||||
Count int64
|
||||
}
|
||||
|
||||
func (q *Queries) OrderAddItem(ctx context.Context, arg *OrderAddItemParams) error {
|
||||
_, err := q.db.Exec(ctx, orderAddItem, arg.OrderID, arg.Sku, arg.Count)
|
||||
return err
|
||||
}
|
||||
|
||||
const orderCreate = `-- name: OrderCreate :one
|
||||
insert into orders (status_name, user_id)
|
||||
values ($1, $2)
|
||||
returning id
|
||||
`
|
||||
|
||||
type OrderCreateParams struct {
|
||||
StatusName string
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (q *Queries) OrderCreate(ctx context.Context, arg *OrderCreateParams) (int64, error) {
|
||||
row := q.db.QueryRow(ctx, orderCreate, arg.StatusName, arg.UserID)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
const orderGetByID = `-- name: OrderGetByID :many
|
||||
select
|
||||
o.id as order_id,
|
||||
o.status_name as status,
|
||||
o.user_id,
|
||||
oi.sku,
|
||||
oi.count
|
||||
from orders o
|
||||
left join order_items oi on oi.order_id = o.id
|
||||
where o.id = $1
|
||||
order by oi.id
|
||||
`
|
||||
|
||||
type OrderGetByIDRow struct {
|
||||
OrderID int64
|
||||
Status string
|
||||
UserID int64
|
||||
Sku *int64
|
||||
Count *int64
|
||||
}
|
||||
|
||||
func (q *Queries) OrderGetByID(ctx context.Context, id int64) ([]*OrderGetByIDRow, error) {
|
||||
rows, err := q.db.Query(ctx, orderGetByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []*OrderGetByIDRow
|
||||
for rows.Next() {
|
||||
var i OrderGetByIDRow
|
||||
if err := rows.Scan(
|
||||
&i.OrderID,
|
||||
&i.Status,
|
||||
&i.UserID,
|
||||
&i.Sku,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, &i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const orderSetStatus = `-- name: OrderSetStatus :execrows
|
||||
update orders
|
||||
set status_name = $2
|
||||
where id = $1
|
||||
`
|
||||
|
||||
type OrderSetStatusParams struct {
|
||||
ID int64
|
||||
StatusName string
|
||||
}
|
||||
|
||||
func (q *Queries) OrderSetStatus(ctx context.Context, arg *OrderSetStatusParams) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, orderSetStatus, arg.ID, arg.StatusName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
105
loms/internal/domain/repository/orders/sqlc/repository.go
Normal file
105
loms/internal/domain/repository/orders/sqlc/repository.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"route256/loms/internal/domain/entity"
|
||||
"route256/loms/internal/domain/model"
|
||||
"route256/loms/internal/domain/service"
|
||||
"route256/loms/internal/infra/postgres"
|
||||
)
|
||||
|
||||
type orderRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewOrderRepository(pool *pgxpool.Pool) service.OrderRepository {
|
||||
return &orderRepo{
|
||||
pool: pool,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *orderRepo) GetQuerier(ctx context.Context) *Queries {
|
||||
tx, ok := postgres.TxFromCtx(ctx)
|
||||
if ok {
|
||||
return New(tx)
|
||||
}
|
||||
|
||||
return New(o.pool)
|
||||
}
|
||||
|
||||
func (o *orderRepo) OrderCreate(ctx context.Context, order *entity.Order) (entity.ID, error) {
|
||||
querier := o.GetQuerier(ctx)
|
||||
|
||||
id, err := querier.OrderCreate(ctx, &OrderCreateParams{
|
||||
StatusName: order.Status,
|
||||
UserID: int64(order.UserID),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("querier.OrderCreate: %w", err)
|
||||
}
|
||||
|
||||
for _, item := range order.Items {
|
||||
if err := querier.OrderAddItem(ctx, &OrderAddItemParams{
|
||||
OrderID: id,
|
||||
Sku: int64(item.ID),
|
||||
Count: int64(item.Count),
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("querier.OrderAddItem: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return entity.ID(id), nil
|
||||
}
|
||||
|
||||
func (o *orderRepo) OrderGetByID(ctx context.Context, orderID entity.ID) (*entity.Order, error) {
|
||||
querier := o.GetQuerier(ctx)
|
||||
|
||||
rows, err := querier.OrderGetByID(ctx, int64(orderID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querier.OrderGetByID: %w", err)
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
return nil, model.ErrOrderNotFound
|
||||
}
|
||||
|
||||
items := make([]entity.OrderItem, len(rows))
|
||||
for i, row := range rows {
|
||||
items[i] = entity.OrderItem{
|
||||
ID: entity.Sku(*row.Sku),
|
||||
//nolint:gosec // will not overflow, uint32 is stored as int64
|
||||
Count: uint32(*row.Count),
|
||||
}
|
||||
}
|
||||
|
||||
order := &entity.Order{
|
||||
OrderID: orderID,
|
||||
Status: rows[0].Status,
|
||||
UserID: entity.ID(rows[0].UserID),
|
||||
Items: items,
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (o *orderRepo) OrderSetStatus(ctx context.Context, orderID entity.ID, newStatus string) error {
|
||||
querier := o.GetQuerier(ctx)
|
||||
|
||||
rows, err := querier.OrderSetStatus(ctx, &OrderSetStatusParams{
|
||||
ID: int64(orderID),
|
||||
StatusName: newStatus,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("querier.OrderSetStatus: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return model.ErrOrderNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
32
loms/internal/domain/repository/stocks/sqlc/db.go
Normal file
32
loms/internal/domain/repository/stocks/sqlc/db.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||
}
|
||||
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
11
loms/internal/domain/repository/stocks/sqlc/models.go
Normal file
11
loms/internal/domain/repository/stocks/sqlc/models.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package sqlc
|
||||
|
||||
type Stock struct {
|
||||
Sku int64
|
||||
TotalCount int64
|
||||
Reserved int64
|
||||
}
|
||||
18
loms/internal/domain/repository/stocks/sqlc/querier.go
Normal file
18
loms/internal/domain/repository/stocks/sqlc/querier.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
StockCancel(ctx context.Context, arg *StockCancelParams) (int64, error)
|
||||
StockGetByID(ctx context.Context, sku int64) (*Stock, error)
|
||||
StockReserve(ctx context.Context, arg *StockReserveParams) (int64, error)
|
||||
StockReserveRemove(ctx context.Context, arg *StockReserveRemoveParams) (int64, error)
|
||||
}
|
||||
|
||||
var _ Querier = (*Queries)(nil)
|
||||
24
loms/internal/domain/repository/stocks/sqlc/query.sql
Normal file
24
loms/internal/domain/repository/stocks/sqlc/query.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
-- name: StockGetByID :one
|
||||
select sku, total_count, reserved
|
||||
from stocks
|
||||
where sku = $1
|
||||
limit 1;
|
||||
|
||||
-- name: StockReserve :execrows
|
||||
update stocks
|
||||
set reserved = reserved + $2
|
||||
where sku = $1
|
||||
and total_count >= reserved + $2;
|
||||
|
||||
-- name: StockReserveRemove :execrows
|
||||
update stocks
|
||||
set reserved = reserved - $2,
|
||||
total_count = total_count - $2
|
||||
where sku = $1
|
||||
and reserved >= $2 and total_count >= $2;
|
||||
|
||||
-- name: StockCancel :execrows
|
||||
update stocks
|
||||
set reserved = reserved - $2
|
||||
where sku = $1
|
||||
and reserved >= $2;
|
||||
85
loms/internal/domain/repository/stocks/sqlc/query.sql.go
Normal file
85
loms/internal/domain/repository/stocks/sqlc/query.sql.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: query.sql
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const stockCancel = `-- name: StockCancel :execrows
|
||||
update stocks
|
||||
set reserved = reserved - $2
|
||||
where sku = $1
|
||||
and reserved >= $2
|
||||
`
|
||||
|
||||
type StockCancelParams struct {
|
||||
Sku int64
|
||||
Reserved int64
|
||||
}
|
||||
|
||||
func (q *Queries) StockCancel(ctx context.Context, arg *StockCancelParams) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, stockCancel, arg.Sku, arg.Reserved)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
const stockGetByID = `-- name: StockGetByID :one
|
||||
select sku, total_count, reserved
|
||||
from stocks
|
||||
where sku = $1
|
||||
limit 1
|
||||
`
|
||||
|
||||
func (q *Queries) StockGetByID(ctx context.Context, sku int64) (*Stock, error) {
|
||||
row := q.db.QueryRow(ctx, stockGetByID, sku)
|
||||
var i Stock
|
||||
err := row.Scan(&i.Sku, &i.TotalCount, &i.Reserved)
|
||||
return &i, err
|
||||
}
|
||||
|
||||
const stockReserve = `-- name: StockReserve :execrows
|
||||
update stocks
|
||||
set reserved = reserved + $2
|
||||
where sku = $1
|
||||
and total_count >= reserved + $2
|
||||
`
|
||||
|
||||
type StockReserveParams struct {
|
||||
Sku int64
|
||||
Reserved int64
|
||||
}
|
||||
|
||||
func (q *Queries) StockReserve(ctx context.Context, arg *StockReserveParams) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, stockReserve, arg.Sku, arg.Reserved)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
const stockReserveRemove = `-- name: StockReserveRemove :execrows
|
||||
update stocks
|
||||
set reserved = reserved - $2,
|
||||
total_count = total_count - $2
|
||||
where sku = $1
|
||||
and reserved >= $2 and total_count >= $2
|
||||
`
|
||||
|
||||
type StockReserveRemoveParams struct {
|
||||
Sku int64
|
||||
Reserved int64
|
||||
}
|
||||
|
||||
func (q *Queries) StockReserveRemove(ctx context.Context, arg *StockReserveRemoveParams) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, stockReserveRemove, arg.Sku, arg.Reserved)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
137
loms/internal/domain/repository/stocks/sqlc/repository.go
Normal file
137
loms/internal/domain/repository/stocks/sqlc/repository.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"route256/loms/internal/domain/entity"
|
||||
"route256/loms/internal/domain/model"
|
||||
"route256/loms/internal/domain/service"
|
||||
"route256/loms/internal/infra/postgres"
|
||||
"route256/loms/internal/infra/tools"
|
||||
)
|
||||
|
||||
type stockRepo struct {
|
||||
read *pgxpool.Pool
|
||||
write *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewStockRepository(write, read *pgxpool.Pool) service.StockRepository {
|
||||
return &stockRepo{
|
||||
read: read,
|
||||
write: write,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stockRepo) GetQuerier(ctx context.Context) *Queries {
|
||||
tx, ok := postgres.TxFromCtx(ctx)
|
||||
if ok {
|
||||
return New(tx)
|
||||
}
|
||||
|
||||
return New(s.write)
|
||||
}
|
||||
|
||||
func (s *stockRepo) StockReserve(ctx context.Context, stock *entity.Stock) error {
|
||||
querier := s.GetQuerier(ctx)
|
||||
|
||||
rows, err := querier.StockReserve(ctx, &StockReserveParams{
|
||||
Sku: int64(stock.Item.ID),
|
||||
Reserved: int64(stock.Reserved),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("querier.StockReserve: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return model.ErrNotEnoughStocks
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stockRepo) StockReserveRemove(ctx context.Context, stock *entity.Stock) error {
|
||||
querier := s.GetQuerier(ctx)
|
||||
|
||||
rows, err := querier.StockReserveRemove(ctx, &StockReserveRemoveParams{
|
||||
Sku: int64(stock.Item.ID),
|
||||
Reserved: int64(stock.Reserved),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("querier.StockReserveRemove: %w", err)
|
||||
}
|
||||
|
||||
if rows > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = querier.StockGetByID(ctx, int64(stock.Item.ID))
|
||||
switch {
|
||||
case errors.Is(err, pgx.ErrNoRows):
|
||||
return model.ErrUnknownStock
|
||||
case err != nil:
|
||||
return fmt.Errorf("querier.StockGetByID: %w", err)
|
||||
default:
|
||||
return model.ErrNotEnoughStocks
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stockRepo) StockCancel(ctx context.Context, stock *entity.Stock) error {
|
||||
querier := s.GetQuerier(ctx)
|
||||
|
||||
rows, err := querier.StockCancel(ctx, &StockCancelParams{
|
||||
Sku: int64(stock.Item.ID),
|
||||
Reserved: int64(stock.Reserved),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("querier.StockCancel: %w", err)
|
||||
}
|
||||
|
||||
if rows > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = querier.StockGetByID(ctx, int64(stock.Item.ID))
|
||||
switch {
|
||||
case errors.Is(err, pgx.ErrNoRows):
|
||||
return model.ErrUnknownStock
|
||||
case err != nil:
|
||||
return fmt.Errorf("querier.StockGetByID: %w", err)
|
||||
default:
|
||||
return model.ErrNotEnoughStocks
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stockRepo) StockGetByID(ctx context.Context, sku entity.Sku) (*entity.Stock, error) {
|
||||
querier := s.GetQuerier(ctx)
|
||||
|
||||
stock, err := querier.StockGetByID(ctx, int64(sku))
|
||||
switch {
|
||||
case errors.Is(err, pgx.ErrNoRows):
|
||||
return nil, model.ErrUnknownStock
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("querier.StockGetByID: %w", err)
|
||||
default:
|
||||
count, castErr := tools.SafeCastInt64ToUInt32(stock.TotalCount)
|
||||
if castErr != nil {
|
||||
return nil, castErr
|
||||
}
|
||||
|
||||
reserved, castErr := tools.SafeCastInt64ToUInt32(stock.Reserved)
|
||||
if castErr != nil {
|
||||
return nil, castErr
|
||||
}
|
||||
|
||||
return &entity.Stock{
|
||||
Item: entity.OrderItem{
|
||||
ID: entity.Sku(stock.Sku),
|
||||
Count: count,
|
||||
},
|
||||
Reserved: reserved,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by http://github.com/gojuno/minimock (v3.4.5). DO NOT EDIT.
|
||||
// Code generated by http://github.com/gojuno/minimock (v3.4.0). DO NOT EDIT.
|
||||
|
||||
package mock
|
||||
|
||||
//go:generate minimock -i route256/loms/internal/domain/loms/service.OrderRepository -o order_repository_mock.go -n OrderRepositoryMock -p mock
|
||||
//go:generate minimock -i route256/loms/internal/domain/service.OrderRepository -o order_repository_mock.go -n OrderRepositoryMock -p mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by http://github.com/gojuno/minimock (v3.4.5). DO NOT EDIT.
|
||||
// Code generated by http://github.com/gojuno/minimock (v3.4.0). DO NOT EDIT.
|
||||
|
||||
package mock
|
||||
|
||||
//go:generate minimock -i route256/loms/internal/domain/loms/service.StockRepository -o stock_repository_mock.go -n StockRepositoryMock -p mock
|
||||
//go:generate minimock -i route256/loms/internal/domain/service.StockRepository -o stock_repository_mock.go -n StockRepositoryMock -p mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
@@ -28,15 +28,24 @@ type StockRepository interface {
|
||||
StockGetByID(ctx context.Context, sku entity.Sku) (*entity.Stock, error)
|
||||
}
|
||||
|
||||
type LomsService struct {
|
||||
orders OrderRepository
|
||||
stocks StockRepository
|
||||
type txManager interface {
|
||||
WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error)
|
||||
ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error)
|
||||
WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error)
|
||||
ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error)
|
||||
}
|
||||
|
||||
func NewLomsService(orderRepo OrderRepository, stockRepo StockRepository) *LomsService {
|
||||
type LomsService struct {
|
||||
orders OrderRepository
|
||||
stocks StockRepository
|
||||
txManager txManager
|
||||
}
|
||||
|
||||
func NewLomsService(orderRepo OrderRepository, stockRepo StockRepository, txManager txManager) *LomsService {
|
||||
return &LomsService{
|
||||
orders: orderRepo,
|
||||
stocks: stockRepo,
|
||||
orders: orderRepo,
|
||||
stocks: stockRepo,
|
||||
txManager: txManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,40 +86,44 @@ func (s *LomsService) OrderCreate(ctx context.Context, orderReq *pb.OrderCreateR
|
||||
return int(a.ID - b.ID)
|
||||
})
|
||||
|
||||
id, err := s.orders.OrderCreate(ctx, order)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("orders.OrderCreate: %w", err)
|
||||
}
|
||||
var (
|
||||
orderID entity.ID
|
||||
resErr error
|
||||
)
|
||||
|
||||
order.OrderID = id
|
||||
|
||||
commitedStocks := make([]*entity.Stock, 0, len(order.Items))
|
||||
for _, item := range order.Items {
|
||||
stock := &entity.Stock{
|
||||
Item: item,
|
||||
Reserved: item.Count,
|
||||
err := s.txManager.WriteWithTransaction(ctx, func(txCtx context.Context) error {
|
||||
id, err := s.orders.OrderCreate(txCtx, order)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
order.OrderID = id
|
||||
orderID = id
|
||||
|
||||
if err := s.stocks.StockReserve(ctx, stock); err != nil {
|
||||
s.rollbackStocks(ctx, commitedStocks)
|
||||
committed := make([]*entity.Stock, 0, len(order.Items))
|
||||
for _, it := range order.Items {
|
||||
st := &entity.Stock{Item: it, Reserved: it.Count}
|
||||
if err := s.stocks.StockReserve(txCtx, st); err != nil {
|
||||
s.rollbackStocks(txCtx, committed)
|
||||
|
||||
if statusErr := s.orders.OrderSetStatus(ctx, order.OrderID, pb.OrderStatus_ORDER_STATUS_FAILED.String()); statusErr != nil {
|
||||
log.Error().Err(statusErr).Msg("failed to update status on stock reserve fail")
|
||||
_ = s.orders.OrderSetStatus(txCtx, id,
|
||||
pb.OrderStatus_ORDER_STATUS_FAILED.String())
|
||||
|
||||
resErr = fmt.Errorf("stocks.StockReserve: %w", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("stocks.StockReserve: %w", err)
|
||||
committed = append(committed, st)
|
||||
}
|
||||
|
||||
commitedStocks = append(commitedStocks, stock)
|
||||
}
|
||||
|
||||
if err := s.orders.OrderSetStatus(ctx, order.OrderID, pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String()); err != nil {
|
||||
s.rollbackStocks(ctx, commitedStocks)
|
||||
|
||||
return s.orders.OrderSetStatus(txCtx, id,
|
||||
pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String())
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return order.OrderID, nil
|
||||
if resErr != nil {
|
||||
return 0, resErr
|
||||
}
|
||||
return orderID, nil
|
||||
}
|
||||
|
||||
func (s *LomsService) OrderInfo(ctx context.Context, orderID entity.ID) (*entity.Order, error) {
|
||||
@@ -122,54 +135,66 @@ func (s *LomsService) OrderInfo(ctx context.Context, orderID entity.ID) (*entity
|
||||
}
|
||||
|
||||
func (s *LomsService) OrderPay(ctx context.Context, orderID entity.ID) error {
|
||||
order, err := s.OrderInfo(ctx, orderID)
|
||||
if err != nil {
|
||||
return err
|
||||
if orderID <= 0 {
|
||||
return model.ErrInvalidInput
|
||||
}
|
||||
|
||||
switch order.Status {
|
||||
case pb.OrderStatus_ORDER_STATUS_PAYED.String():
|
||||
return nil
|
||||
case pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String():
|
||||
for _, item := range order.Items {
|
||||
if err := s.stocks.StockReserveRemove(ctx, &entity.Stock{
|
||||
Item: item,
|
||||
Reserved: item.Count,
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("failed to free stock reservation")
|
||||
}
|
||||
return s.txManager.WriteWithTransaction(ctx, func(txCtx context.Context) error {
|
||||
order, err := s.orders.OrderGetByID(txCtx, orderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.orders.OrderSetStatus(ctx, orderID, pb.OrderStatus_ORDER_STATUS_PAYED.String())
|
||||
default:
|
||||
return model.ErrOrderInvalidStatus
|
||||
}
|
||||
switch order.Status {
|
||||
case pb.OrderStatus_ORDER_STATUS_PAYED.String():
|
||||
return nil
|
||||
case pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String():
|
||||
for _, it := range order.Items {
|
||||
if err := s.stocks.StockReserveRemove(txCtx, &entity.Stock{
|
||||
Item: it,
|
||||
Reserved: it.Count,
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("failed to free stock reservation")
|
||||
}
|
||||
}
|
||||
return s.orders.OrderSetStatus(txCtx, orderID,
|
||||
pb.OrderStatus_ORDER_STATUS_PAYED.String())
|
||||
default:
|
||||
return model.ErrOrderInvalidStatus
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LomsService) OrderCancel(ctx context.Context, orderID entity.ID) error {
|
||||
order, err := s.OrderInfo(ctx, orderID)
|
||||
if err != nil {
|
||||
return err
|
||||
if orderID <= 0 {
|
||||
return model.ErrInvalidInput
|
||||
}
|
||||
|
||||
switch order.Status {
|
||||
case pb.OrderStatus_ORDER_STATUS_CANCELLED.String():
|
||||
return nil
|
||||
case pb.OrderStatus_ORDER_STATUS_FAILED.String(), pb.OrderStatus_ORDER_STATUS_PAYED.String():
|
||||
return model.ErrOrderInvalidStatus
|
||||
}
|
||||
|
||||
stocks := make([]*entity.Stock, len(order.Items))
|
||||
for i, item := range order.Items {
|
||||
stocks[i] = &entity.Stock{
|
||||
Item: item,
|
||||
Reserved: item.Count,
|
||||
return s.txManager.WriteWithTransaction(ctx, func(txCtx context.Context) error {
|
||||
order, err := s.orders.OrderGetByID(txCtx, orderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.rollbackStocks(ctx, stocks)
|
||||
switch order.Status {
|
||||
case pb.OrderStatus_ORDER_STATUS_CANCELLED.String():
|
||||
return nil
|
||||
case pb.OrderStatus_ORDER_STATUS_FAILED.String(),
|
||||
pb.OrderStatus_ORDER_STATUS_PAYED.String():
|
||||
return model.ErrOrderInvalidStatus
|
||||
}
|
||||
|
||||
return s.orders.OrderSetStatus(ctx, orderID, pb.OrderStatus_ORDER_STATUS_CANCELLED.String())
|
||||
for _, it := range order.Items {
|
||||
if err := s.stocks.StockCancel(txCtx, &entity.Stock{
|
||||
Item: it,
|
||||
Reserved: it.Count,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.orders.OrderSetStatus(txCtx, orderID,
|
||||
pb.OrderStatus_ORDER_STATUS_CANCELLED.String())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LomsService) StocksInfo(ctx context.Context, sku entity.Sku) (uint32, error) {
|
||||
@@ -182,5 +207,5 @@ func (s *LomsService) StocksInfo(ctx context.Context, sku entity.Sku) (uint32, e
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return stock.Item.Count, nil
|
||||
return stock.Item.Count - stock.Reserved, nil
|
||||
}
|
||||
|
||||
@@ -21,6 +21,24 @@ const (
|
||||
testSku = entity.Sku(199)
|
||||
)
|
||||
|
||||
type mockTxManager struct{}
|
||||
|
||||
func (t *mockTxManager) WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
func (t *mockTxManager) ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
func (t *mockTxManager) WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
func (t *mockTxManager) ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
func TestLomsService_OrderCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -130,8 +148,7 @@ func TestLomsService_OrderCreate(t *testing.T) {
|
||||
OrderCreateMock.Return(1, nil).
|
||||
OrderSetStatusMock.Return(errors.New("unexpected error")),
|
||||
stocks: mock.NewStockRepositoryMock(mc).
|
||||
StockReserveMock.Return(nil).
|
||||
StockCancelMock.Return(nil),
|
||||
StockReserveMock.Return(nil),
|
||||
},
|
||||
args: args{
|
||||
req: goodReq,
|
||||
@@ -145,7 +162,7 @@ func TestLomsService_OrderCreate(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewLomsService(tt.fields.orders, tt.fields.stocks)
|
||||
svc := NewLomsService(tt.fields.orders, tt.fields.stocks, &mockTxManager{})
|
||||
_, err := svc.OrderCreate(ctx, tt.args.req)
|
||||
tt.wantErr(t, err)
|
||||
})
|
||||
@@ -256,24 +273,46 @@ func TestLomsService_OrderPay(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewLomsService(tt.fields.orders, tt.fields.stocks)
|
||||
svc := NewLomsService(tt.fields.orders, tt.fields.stocks, &mockTxManager{})
|
||||
err := svc.OrderPay(ctx, tt.args.id)
|
||||
tt.wantErr(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLomsService_OrderInfo(t *testing.T) {
|
||||
func TestLomsService_OrderInfoBadInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewLomsService(
|
||||
nil,
|
||||
nil,
|
||||
&mockTxManager{},
|
||||
)
|
||||
|
||||
_, err := svc.OrderInfo(context.Background(), 0)
|
||||
require.ErrorIs(t, err, model.ErrInvalidInput)
|
||||
}
|
||||
|
||||
func TestLomsService_OrderInfoSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
order := &entity.Order{
|
||||
OrderID: 123,
|
||||
Status: "payed",
|
||||
UserID: 1337,
|
||||
Items: []entity.OrderItem{},
|
||||
}
|
||||
|
||||
mc := minimock.NewController(t)
|
||||
svc := NewLomsService(
|
||||
mock.NewOrderRepositoryMock(mc),
|
||||
mock.NewStockRepositoryMock(mc),
|
||||
mock.NewOrderRepositoryMock(mc).OrderGetByIDMock.Return(order, nil),
|
||||
nil,
|
||||
&mockTxManager{},
|
||||
)
|
||||
|
||||
err := svc.OrderPay(context.Background(), 0)
|
||||
require.ErrorIs(t, err, model.ErrInvalidInput)
|
||||
gotOrder, err := svc.OrderInfo(context.Background(), 123)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, order, gotOrder)
|
||||
}
|
||||
|
||||
func TestLomsService_OrderCancel(t *testing.T) {
|
||||
@@ -370,7 +409,7 @@ func TestLomsService_OrderCancel(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewLomsService(tt.fields.orders, tt.fields.stocks)
|
||||
svc := NewLomsService(tt.fields.orders, tt.fields.stocks, &mockTxManager{})
|
||||
err := svc.OrderCancel(ctx, tt.args.id)
|
||||
tt.wantErr(t, err)
|
||||
})
|
||||
@@ -437,7 +476,7 @@ func TestLomsService_StocksInfo(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewLomsService(nil, tt.fields.stocks)
|
||||
svc := NewLomsService(nil, tt.fields.stocks, &mockTxManager{})
|
||||
got, err := svc.StocksInfo(ctx, tt.args.sku)
|
||||
tt.wantErr(t, err)
|
||||
if err == nil {
|
||||
|
||||
56
loms/internal/infra/postgres/postgres.go
Normal file
56
loms/internal/infra/postgres/postgres.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// From https://gitlab.ozon.dev/go/classroom-18/students/week-4-workshop/-/blob/master/internal/infra/postgres/postgres.go
|
||||
|
||||
func NewPool(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "pgxpool.ParseConfig")
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "pgxpool.NewWithConfig")
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func NewPools(ctx context.Context, DSNs ...string) ([]*pgxpool.Pool, error) {
|
||||
pools := make([]*pgxpool.Pool, len(DSNs))
|
||||
|
||||
for i, dsn := range DSNs {
|
||||
cfg, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
closeOpened(pools[:i])
|
||||
|
||||
return nil, errors.Wrap(err, "pgxpool.ParseConfig")
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, cfg)
|
||||
if err != nil {
|
||||
closeOpened(pools[:i])
|
||||
|
||||
return nil, errors.Wrap(err, "pgxpool.NewWithConfig")
|
||||
}
|
||||
|
||||
pools[i] = pool
|
||||
}
|
||||
|
||||
return pools, nil
|
||||
}
|
||||
|
||||
func closeOpened(pools []*pgxpool.Pool) {
|
||||
for _, p := range pools {
|
||||
if p != nil {
|
||||
p.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
87
loms/internal/infra/postgres/tx.go
Normal file
87
loms/internal/infra/postgres/tx.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package postgres
|
||||
|
||||
// From https://gitlab.ozon.dev/go/classroom-18/students/week-4-workshop/-/blob/master/internal/infra/postgres/tx.go
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
// Tx транзакция.
|
||||
type Tx pgx.Tx
|
||||
|
||||
type txKey struct{}
|
||||
|
||||
func ctxWithTx(ctx context.Context, tx pgx.Tx) context.Context {
|
||||
return context.WithValue(ctx, txKey{}, tx)
|
||||
}
|
||||
|
||||
func TxFromCtx(ctx context.Context) (pgx.Tx, bool) {
|
||||
tx, ok := ctx.Value(txKey{}).(pgx.Tx)
|
||||
|
||||
return tx, ok
|
||||
}
|
||||
|
||||
type TxManager struct {
|
||||
write *pgxpool.Pool
|
||||
read *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewTxManager(write, read *pgxpool.Pool) *TxManager {
|
||||
return &TxManager{
|
||||
write: write,
|
||||
read: read,
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransaction выполняет fn в транзакции с дефолтным уровнем изоляции.
|
||||
func (m *TxManager) WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return m.withTx(ctx, m.write, pgx.TxOptions{}, fn)
|
||||
}
|
||||
|
||||
func (m *TxManager) ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return m.withTx(ctx, m.read, pgx.TxOptions{}, fn)
|
||||
}
|
||||
|
||||
// WithTransaction выполняет fn в транзакции с уровнем изоляции RepeatableRead.
|
||||
func (m *TxManager) WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return m.withTx(ctx, m.write, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, fn)
|
||||
}
|
||||
|
||||
func (m *TxManager) ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
|
||||
return m.withTx(ctx, m.read, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, fn)
|
||||
}
|
||||
|
||||
// WithTx выполняет fn в транзакции.
|
||||
func (m *TxManager) withTx(ctx context.Context, pool *pgxpool.Pool, options pgx.TxOptions, fn func(ctx context.Context) error) (err error) {
|
||||
var span opentracing.Span
|
||||
span, ctx = opentracing.StartSpanFromContext(ctx, "Transaction")
|
||||
defer span.Finish()
|
||||
|
||||
tx, err := pool.BeginTx(ctx, options)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ctx = ctxWithTx(ctx, tx)
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
// a panic occurred, rollback and repanic
|
||||
_ = tx.Rollback(ctx)
|
||||
panic(p)
|
||||
} else if err != nil {
|
||||
// something went wrong, rollback
|
||||
_ = tx.Rollback(ctx)
|
||||
} else {
|
||||
// all good, commit
|
||||
err = tx.Commit(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
err = fn(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
19
loms/internal/infra/tools/safecast.go
Normal file
19
loms/internal/infra/tools/safecast.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
func SafeCastInt64ToUInt32(num int64) (uint32, error) {
|
||||
if num < 0 {
|
||||
return 0, fmt.Errorf("tried casting signed negative number to unsigned number")
|
||||
}
|
||||
|
||||
if num > math.MaxUint32 {
|
||||
return 0, fmt.Errorf("tried casting larger number than uint32 can store")
|
||||
}
|
||||
|
||||
// the bounds are checked, and cast should be safe.
|
||||
return uint32(num), nil // #nosec G115
|
||||
}
|
||||
Reference in New Issue
Block a user