[hw-4] add postgres db

This commit is contained in:
Никита Шубин
2025-06-26 12:08:46 +00:00
parent 3ebaad5558
commit 77ed9fcf85
46 changed files with 1582 additions and 369 deletions

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
@@ -16,11 +17,12 @@ import (
"google.golang.org/grpc/reflection"
"route256/loms/internal/app/server"
ordersRepository "route256/loms/internal/domain/repository/orders"
stocksRepository "route256/loms/internal/domain/repository/stocks"
ordersRepository "route256/loms/internal/domain/repository/orders/sqlc"
stocksRepository "route256/loms/internal/domain/repository/stocks/sqlc"
"route256/loms/internal/domain/service"
"route256/loms/internal/infra/config"
mw "route256/loms/internal/infra/grpc/middleware"
"route256/loms/internal/infra/postgres"
pb "route256/pkg/api/loms/v1"
)
@@ -50,13 +52,16 @@ func NewApp(configPath string) (*App, error) {
log.WithLevel(zerolog.GlobalLevel()).Msgf("using logging level=`%s`", zerolog.GlobalLevel().String())
stockRepo, err := stocksRepository.NewInMemoryRepository(100)
masterPool, replicaPool, err := getPostgresPools(c)
if err != nil {
return nil, fmt.Errorf("stocksRepository.NewInMemoryRepository: %w", err)
return nil, err
}
orderRepo := ordersRepository.NewInMemoryRepository(100)
service := service.NewLomsService(orderRepo, stockRepo)
stockRepo := stocksRepository.NewStockRepository(masterPool, replicaPool)
orderRepo := ordersRepository.NewOrderRepository(masterPool)
txManager := postgres.NewTxManager(masterPool, replicaPool)
service := service.NewLomsService(orderRepo, stockRepo, txManager)
controller := server.NewServer(service)
app := &App{
@@ -117,3 +122,34 @@ func (app *App) ListenAndServe() error {
return gwServer.ListenAndServe()
}
func getPostgresPools(c *config.Config) (masterPool, replicaPool *pgxpool.Pool, err error) {
masterConn := fmt.Sprintf(
"postgresql://%s:%s@%s:%s/%s?sslmode=disable",
c.DatabaseMaster.User,
c.DatabaseMaster.Password,
c.DatabaseMaster.Host,
c.DatabaseMaster.Port,
c.DatabaseMaster.DBName,
)
replicaConn := fmt.Sprintf(
"postgresql://%s:%s@%s:%s/%s?sslmode=disable",
c.DatabaseReplica.User,
c.DatabaseReplica.Password,
c.DatabaseReplica.Host,
c.DatabaseReplica.Port,
c.DatabaseReplica.DBName,
)
pools, err := postgres.NewPools(context.Background(), masterConn, replicaConn)
if err != nil {
return nil, nil, err
}
if len(pools) != 2 {
return nil, nil, fmt.Errorf("got wrong number of pools when establishing postgres connection")
}
return pools[0], pools[1], nil
}

View File

@@ -0,0 +1,32 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
package sqlc
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@@ -0,0 +1,5 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
package sqlc

View File

@@ -0,0 +1,18 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
package sqlc
import (
"context"
)
type Querier interface {
OrderAddItem(ctx context.Context, arg *OrderAddItemParams) error
OrderCreate(ctx context.Context, arg *OrderCreateParams) (int64, error)
OrderGetByID(ctx context.Context, id int64) ([]*OrderGetByIDRow, error)
OrderSetStatus(ctx context.Context, arg *OrderSetStatusParams) (int64, error)
}
var _ Querier = (*Queries)(nil)

View File

@@ -0,0 +1,25 @@
-- name: OrderCreate :one
insert into orders (status_name, user_id)
values ($1, $2)
returning id;
-- name: OrderAddItem :exec
insert into order_items (order_id, sku, count)
select $1, $2, $3;
-- name: OrderSetStatus :execrows
update orders
set status_name = $2
where id = $1;
-- name: OrderGetByID :many
select
o.id as order_id,
o.status_name as status,
o.user_id,
oi.sku,
oi.count
from orders o
left join order_items oi on oi.order_id = o.id
where o.id = $1
order by oi.id;

View File

@@ -0,0 +1,110 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
// source: query.sql
package sqlc
import (
"context"
)
const orderAddItem = `-- name: OrderAddItem :exec
insert into order_items (order_id, sku, count)
select $1, $2, $3
`
type OrderAddItemParams struct {
OrderID int64
Sku int64
Count int64
}
func (q *Queries) OrderAddItem(ctx context.Context, arg *OrderAddItemParams) error {
_, err := q.db.Exec(ctx, orderAddItem, arg.OrderID, arg.Sku, arg.Count)
return err
}
const orderCreate = `-- name: OrderCreate :one
insert into orders (status_name, user_id)
values ($1, $2)
returning id
`
type OrderCreateParams struct {
StatusName string
UserID int64
}
func (q *Queries) OrderCreate(ctx context.Context, arg *OrderCreateParams) (int64, error) {
row := q.db.QueryRow(ctx, orderCreate, arg.StatusName, arg.UserID)
var id int64
err := row.Scan(&id)
return id, err
}
const orderGetByID = `-- name: OrderGetByID :many
select
o.id as order_id,
o.status_name as status,
o.user_id,
oi.sku,
oi.count
from orders o
left join order_items oi on oi.order_id = o.id
where o.id = $1
order by oi.id
`
type OrderGetByIDRow struct {
OrderID int64
Status string
UserID int64
Sku *int64
Count *int64
}
func (q *Queries) OrderGetByID(ctx context.Context, id int64) ([]*OrderGetByIDRow, error) {
rows, err := q.db.Query(ctx, orderGetByID, id)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*OrderGetByIDRow
for rows.Next() {
var i OrderGetByIDRow
if err := rows.Scan(
&i.OrderID,
&i.Status,
&i.UserID,
&i.Sku,
&i.Count,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const orderSetStatus = `-- name: OrderSetStatus :execrows
update orders
set status_name = $2
where id = $1
`
type OrderSetStatusParams struct {
ID int64
StatusName string
}
func (q *Queries) OrderSetStatus(ctx context.Context, arg *OrderSetStatusParams) (int64, error) {
result, err := q.db.Exec(ctx, orderSetStatus, arg.ID, arg.StatusName)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}

View File

@@ -0,0 +1,105 @@
package sqlc
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
"route256/loms/internal/domain/entity"
"route256/loms/internal/domain/model"
"route256/loms/internal/domain/service"
"route256/loms/internal/infra/postgres"
)
type orderRepo struct {
pool *pgxpool.Pool
}
func NewOrderRepository(pool *pgxpool.Pool) service.OrderRepository {
return &orderRepo{
pool: pool,
}
}
func (o *orderRepo) GetQuerier(ctx context.Context) *Queries {
tx, ok := postgres.TxFromCtx(ctx)
if ok {
return New(tx)
}
return New(o.pool)
}
func (o *orderRepo) OrderCreate(ctx context.Context, order *entity.Order) (entity.ID, error) {
querier := o.GetQuerier(ctx)
id, err := querier.OrderCreate(ctx, &OrderCreateParams{
StatusName: order.Status,
UserID: int64(order.UserID),
})
if err != nil {
return 0, fmt.Errorf("querier.OrderCreate: %w", err)
}
for _, item := range order.Items {
if err := querier.OrderAddItem(ctx, &OrderAddItemParams{
OrderID: id,
Sku: int64(item.ID),
Count: int64(item.Count),
}); err != nil {
return 0, fmt.Errorf("querier.OrderAddItem: %w", err)
}
}
return entity.ID(id), nil
}
func (o *orderRepo) OrderGetByID(ctx context.Context, orderID entity.ID) (*entity.Order, error) {
querier := o.GetQuerier(ctx)
rows, err := querier.OrderGetByID(ctx, int64(orderID))
if err != nil {
return nil, fmt.Errorf("querier.OrderGetByID: %w", err)
}
if len(rows) == 0 {
return nil, model.ErrOrderNotFound
}
items := make([]entity.OrderItem, len(rows))
for i, row := range rows {
items[i] = entity.OrderItem{
ID: entity.Sku(*row.Sku),
//nolint:gosec // will not overflow, uint32 is stored as int64
Count: uint32(*row.Count),
}
}
order := &entity.Order{
OrderID: orderID,
Status: rows[0].Status,
UserID: entity.ID(rows[0].UserID),
Items: items,
}
return order, nil
}
func (o *orderRepo) OrderSetStatus(ctx context.Context, orderID entity.ID, newStatus string) error {
querier := o.GetQuerier(ctx)
rows, err := querier.OrderSetStatus(ctx, &OrderSetStatusParams{
ID: int64(orderID),
StatusName: newStatus,
})
if err != nil {
return fmt.Errorf("querier.OrderSetStatus: %w", err)
}
if rows == 0 {
return model.ErrOrderNotFound
}
return nil
}

View File

@@ -0,0 +1,32 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
package sqlc
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@@ -0,0 +1,11 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
package sqlc
type Stock struct {
Sku int64
TotalCount int64
Reserved int64
}

View File

@@ -0,0 +1,18 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
package sqlc
import (
"context"
)
type Querier interface {
StockCancel(ctx context.Context, arg *StockCancelParams) (int64, error)
StockGetByID(ctx context.Context, sku int64) (*Stock, error)
StockReserve(ctx context.Context, arg *StockReserveParams) (int64, error)
StockReserveRemove(ctx context.Context, arg *StockReserveRemoveParams) (int64, error)
}
var _ Querier = (*Queries)(nil)

View File

@@ -0,0 +1,24 @@
-- name: StockGetByID :one
select sku, total_count, reserved
from stocks
where sku = $1
limit 1;
-- name: StockReserve :execrows
update stocks
set reserved = reserved + $2
where sku = $1
and total_count >= reserved + $2;
-- name: StockReserveRemove :execrows
update stocks
set reserved = reserved - $2,
total_count = total_count - $2
where sku = $1
and reserved >= $2 and total_count >= $2;
-- name: StockCancel :execrows
update stocks
set reserved = reserved - $2
where sku = $1
and reserved >= $2;

View File

@@ -0,0 +1,85 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
// source: query.sql
package sqlc
import (
"context"
)
const stockCancel = `-- name: StockCancel :execrows
update stocks
set reserved = reserved - $2
where sku = $1
and reserved >= $2
`
type StockCancelParams struct {
Sku int64
Reserved int64
}
func (q *Queries) StockCancel(ctx context.Context, arg *StockCancelParams) (int64, error) {
result, err := q.db.Exec(ctx, stockCancel, arg.Sku, arg.Reserved)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const stockGetByID = `-- name: StockGetByID :one
select sku, total_count, reserved
from stocks
where sku = $1
limit 1
`
func (q *Queries) StockGetByID(ctx context.Context, sku int64) (*Stock, error) {
row := q.db.QueryRow(ctx, stockGetByID, sku)
var i Stock
err := row.Scan(&i.Sku, &i.TotalCount, &i.Reserved)
return &i, err
}
const stockReserve = `-- name: StockReserve :execrows
update stocks
set reserved = reserved + $2
where sku = $1
and total_count >= reserved + $2
`
type StockReserveParams struct {
Sku int64
Reserved int64
}
func (q *Queries) StockReserve(ctx context.Context, arg *StockReserveParams) (int64, error) {
result, err := q.db.Exec(ctx, stockReserve, arg.Sku, arg.Reserved)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const stockReserveRemove = `-- name: StockReserveRemove :execrows
update stocks
set reserved = reserved - $2,
total_count = total_count - $2
where sku = $1
and reserved >= $2 and total_count >= $2
`
type StockReserveRemoveParams struct {
Sku int64
Reserved int64
}
func (q *Queries) StockReserveRemove(ctx context.Context, arg *StockReserveRemoveParams) (int64, error) {
result, err := q.db.Exec(ctx, stockReserveRemove, arg.Sku, arg.Reserved)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}

View File

@@ -0,0 +1,137 @@
package sqlc
import (
"context"
"errors"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"route256/loms/internal/domain/entity"
"route256/loms/internal/domain/model"
"route256/loms/internal/domain/service"
"route256/loms/internal/infra/postgres"
"route256/loms/internal/infra/tools"
)
type stockRepo struct {
read *pgxpool.Pool
write *pgxpool.Pool
}
func NewStockRepository(write, read *pgxpool.Pool) service.StockRepository {
return &stockRepo{
read: read,
write: write,
}
}
func (s *stockRepo) GetQuerier(ctx context.Context) *Queries {
tx, ok := postgres.TxFromCtx(ctx)
if ok {
return New(tx)
}
return New(s.write)
}
func (s *stockRepo) StockReserve(ctx context.Context, stock *entity.Stock) error {
querier := s.GetQuerier(ctx)
rows, err := querier.StockReserve(ctx, &StockReserveParams{
Sku: int64(stock.Item.ID),
Reserved: int64(stock.Reserved),
})
if err != nil {
return fmt.Errorf("querier.StockReserve: %w", err)
}
if rows == 0 {
return model.ErrNotEnoughStocks
}
return nil
}
func (s *stockRepo) StockReserveRemove(ctx context.Context, stock *entity.Stock) error {
querier := s.GetQuerier(ctx)
rows, err := querier.StockReserveRemove(ctx, &StockReserveRemoveParams{
Sku: int64(stock.Item.ID),
Reserved: int64(stock.Reserved),
})
if err != nil {
return fmt.Errorf("querier.StockReserveRemove: %w", err)
}
if rows > 0 {
return nil
}
_, err = querier.StockGetByID(ctx, int64(stock.Item.ID))
switch {
case errors.Is(err, pgx.ErrNoRows):
return model.ErrUnknownStock
case err != nil:
return fmt.Errorf("querier.StockGetByID: %w", err)
default:
return model.ErrNotEnoughStocks
}
}
func (s *stockRepo) StockCancel(ctx context.Context, stock *entity.Stock) error {
querier := s.GetQuerier(ctx)
rows, err := querier.StockCancel(ctx, &StockCancelParams{
Sku: int64(stock.Item.ID),
Reserved: int64(stock.Reserved),
})
if err != nil {
return fmt.Errorf("querier.StockCancel: %w", err)
}
if rows > 0 {
return nil
}
_, err = querier.StockGetByID(ctx, int64(stock.Item.ID))
switch {
case errors.Is(err, pgx.ErrNoRows):
return model.ErrUnknownStock
case err != nil:
return fmt.Errorf("querier.StockGetByID: %w", err)
default:
return model.ErrNotEnoughStocks
}
}
func (s *stockRepo) StockGetByID(ctx context.Context, sku entity.Sku) (*entity.Stock, error) {
querier := s.GetQuerier(ctx)
stock, err := querier.StockGetByID(ctx, int64(sku))
switch {
case errors.Is(err, pgx.ErrNoRows):
return nil, model.ErrUnknownStock
case err != nil:
return nil, fmt.Errorf("querier.StockGetByID: %w", err)
default:
count, castErr := tools.SafeCastInt64ToUInt32(stock.TotalCount)
if castErr != nil {
return nil, castErr
}
reserved, castErr := tools.SafeCastInt64ToUInt32(stock.Reserved)
if castErr != nil {
return nil, castErr
}
return &entity.Stock{
Item: entity.OrderItem{
ID: entity.Sku(stock.Sku),
Count: count,
},
Reserved: reserved,
}, nil
}
}

View File

@@ -1,8 +1,8 @@
// Code generated by http://github.com/gojuno/minimock (v3.4.5). DO NOT EDIT.
// Code generated by http://github.com/gojuno/minimock (v3.4.0). DO NOT EDIT.
package mock
//go:generate minimock -i route256/loms/internal/domain/loms/service.OrderRepository -o order_repository_mock.go -n OrderRepositoryMock -p mock
//go:generate minimock -i route256/loms/internal/domain/service.OrderRepository -o order_repository_mock.go -n OrderRepositoryMock -p mock
import (
"context"

View File

@@ -1,8 +1,8 @@
// Code generated by http://github.com/gojuno/minimock (v3.4.5). DO NOT EDIT.
// Code generated by http://github.com/gojuno/minimock (v3.4.0). DO NOT EDIT.
package mock
//go:generate minimock -i route256/loms/internal/domain/loms/service.StockRepository -o stock_repository_mock.go -n StockRepositoryMock -p mock
//go:generate minimock -i route256/loms/internal/domain/service.StockRepository -o stock_repository_mock.go -n StockRepositoryMock -p mock
import (
"context"

View File

@@ -28,15 +28,24 @@ type StockRepository interface {
StockGetByID(ctx context.Context, sku entity.Sku) (*entity.Stock, error)
}
type LomsService struct {
orders OrderRepository
stocks StockRepository
type txManager interface {
WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error)
ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error)
WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error)
ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error)
}
func NewLomsService(orderRepo OrderRepository, stockRepo StockRepository) *LomsService {
type LomsService struct {
orders OrderRepository
stocks StockRepository
txManager txManager
}
func NewLomsService(orderRepo OrderRepository, stockRepo StockRepository, txManager txManager) *LomsService {
return &LomsService{
orders: orderRepo,
stocks: stockRepo,
orders: orderRepo,
stocks: stockRepo,
txManager: txManager,
}
}
@@ -77,40 +86,44 @@ func (s *LomsService) OrderCreate(ctx context.Context, orderReq *pb.OrderCreateR
return int(a.ID - b.ID)
})
id, err := s.orders.OrderCreate(ctx, order)
if err != nil {
return 0, fmt.Errorf("orders.OrderCreate: %w", err)
}
var (
orderID entity.ID
resErr error
)
order.OrderID = id
commitedStocks := make([]*entity.Stock, 0, len(order.Items))
for _, item := range order.Items {
stock := &entity.Stock{
Item: item,
Reserved: item.Count,
err := s.txManager.WriteWithTransaction(ctx, func(txCtx context.Context) error {
id, err := s.orders.OrderCreate(txCtx, order)
if err != nil {
return err
}
order.OrderID = id
orderID = id
if err := s.stocks.StockReserve(ctx, stock); err != nil {
s.rollbackStocks(ctx, commitedStocks)
committed := make([]*entity.Stock, 0, len(order.Items))
for _, it := range order.Items {
st := &entity.Stock{Item: it, Reserved: it.Count}
if err := s.stocks.StockReserve(txCtx, st); err != nil {
s.rollbackStocks(txCtx, committed)
if statusErr := s.orders.OrderSetStatus(ctx, order.OrderID, pb.OrderStatus_ORDER_STATUS_FAILED.String()); statusErr != nil {
log.Error().Err(statusErr).Msg("failed to update status on stock reserve fail")
_ = s.orders.OrderSetStatus(txCtx, id,
pb.OrderStatus_ORDER_STATUS_FAILED.String())
resErr = fmt.Errorf("stocks.StockReserve: %w", err)
return nil
}
return 0, fmt.Errorf("stocks.StockReserve: %w", err)
committed = append(committed, st)
}
commitedStocks = append(commitedStocks, stock)
}
if err := s.orders.OrderSetStatus(ctx, order.OrderID, pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String()); err != nil {
s.rollbackStocks(ctx, commitedStocks)
return s.orders.OrderSetStatus(txCtx, id,
pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String())
})
if err != nil {
return 0, err
}
return order.OrderID, nil
if resErr != nil {
return 0, resErr
}
return orderID, nil
}
func (s *LomsService) OrderInfo(ctx context.Context, orderID entity.ID) (*entity.Order, error) {
@@ -122,54 +135,66 @@ func (s *LomsService) OrderInfo(ctx context.Context, orderID entity.ID) (*entity
}
func (s *LomsService) OrderPay(ctx context.Context, orderID entity.ID) error {
order, err := s.OrderInfo(ctx, orderID)
if err != nil {
return err
if orderID <= 0 {
return model.ErrInvalidInput
}
switch order.Status {
case pb.OrderStatus_ORDER_STATUS_PAYED.String():
return nil
case pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String():
for _, item := range order.Items {
if err := s.stocks.StockReserveRemove(ctx, &entity.Stock{
Item: item,
Reserved: item.Count,
}); err != nil {
log.Error().Err(err).Msg("failed to free stock reservation")
}
return s.txManager.WriteWithTransaction(ctx, func(txCtx context.Context) error {
order, err := s.orders.OrderGetByID(txCtx, orderID)
if err != nil {
return err
}
return s.orders.OrderSetStatus(ctx, orderID, pb.OrderStatus_ORDER_STATUS_PAYED.String())
default:
return model.ErrOrderInvalidStatus
}
switch order.Status {
case pb.OrderStatus_ORDER_STATUS_PAYED.String():
return nil
case pb.OrderStatus_ORDER_STATUS_AWAITING_PAYMENT.String():
for _, it := range order.Items {
if err := s.stocks.StockReserveRemove(txCtx, &entity.Stock{
Item: it,
Reserved: it.Count,
}); err != nil {
log.Error().Err(err).Msg("failed to free stock reservation")
}
}
return s.orders.OrderSetStatus(txCtx, orderID,
pb.OrderStatus_ORDER_STATUS_PAYED.String())
default:
return model.ErrOrderInvalidStatus
}
})
}
func (s *LomsService) OrderCancel(ctx context.Context, orderID entity.ID) error {
order, err := s.OrderInfo(ctx, orderID)
if err != nil {
return err
if orderID <= 0 {
return model.ErrInvalidInput
}
switch order.Status {
case pb.OrderStatus_ORDER_STATUS_CANCELLED.String():
return nil
case pb.OrderStatus_ORDER_STATUS_FAILED.String(), pb.OrderStatus_ORDER_STATUS_PAYED.String():
return model.ErrOrderInvalidStatus
}
stocks := make([]*entity.Stock, len(order.Items))
for i, item := range order.Items {
stocks[i] = &entity.Stock{
Item: item,
Reserved: item.Count,
return s.txManager.WriteWithTransaction(ctx, func(txCtx context.Context) error {
order, err := s.orders.OrderGetByID(txCtx, orderID)
if err != nil {
return err
}
}
s.rollbackStocks(ctx, stocks)
switch order.Status {
case pb.OrderStatus_ORDER_STATUS_CANCELLED.String():
return nil
case pb.OrderStatus_ORDER_STATUS_FAILED.String(),
pb.OrderStatus_ORDER_STATUS_PAYED.String():
return model.ErrOrderInvalidStatus
}
return s.orders.OrderSetStatus(ctx, orderID, pb.OrderStatus_ORDER_STATUS_CANCELLED.String())
for _, it := range order.Items {
if err := s.stocks.StockCancel(txCtx, &entity.Stock{
Item: it,
Reserved: it.Count,
}); err != nil {
return err
}
}
return s.orders.OrderSetStatus(txCtx, orderID,
pb.OrderStatus_ORDER_STATUS_CANCELLED.String())
})
}
func (s *LomsService) StocksInfo(ctx context.Context, sku entity.Sku) (uint32, error) {
@@ -182,5 +207,5 @@ func (s *LomsService) StocksInfo(ctx context.Context, sku entity.Sku) (uint32, e
return 0, err
}
return stock.Item.Count, nil
return stock.Item.Count - stock.Reserved, nil
}

View File

@@ -21,6 +21,24 @@ const (
testSku = entity.Sku(199)
)
type mockTxManager struct{}
func (t *mockTxManager) WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return fn(ctx)
}
func (t *mockTxManager) ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return fn(ctx)
}
func (t *mockTxManager) WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return fn(ctx)
}
func (t *mockTxManager) ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return fn(ctx)
}
func TestLomsService_OrderCreate(t *testing.T) {
t.Parallel()
@@ -130,8 +148,7 @@ func TestLomsService_OrderCreate(t *testing.T) {
OrderCreateMock.Return(1, nil).
OrderSetStatusMock.Return(errors.New("unexpected error")),
stocks: mock.NewStockRepositoryMock(mc).
StockReserveMock.Return(nil).
StockCancelMock.Return(nil),
StockReserveMock.Return(nil),
},
args: args{
req: goodReq,
@@ -145,7 +162,7 @@ func TestLomsService_OrderCreate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
svc := NewLomsService(tt.fields.orders, tt.fields.stocks)
svc := NewLomsService(tt.fields.orders, tt.fields.stocks, &mockTxManager{})
_, err := svc.OrderCreate(ctx, tt.args.req)
tt.wantErr(t, err)
})
@@ -256,24 +273,46 @@ func TestLomsService_OrderPay(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
svc := NewLomsService(tt.fields.orders, tt.fields.stocks)
svc := NewLomsService(tt.fields.orders, tt.fields.stocks, &mockTxManager{})
err := svc.OrderPay(ctx, tt.args.id)
tt.wantErr(t, err)
})
}
}
func TestLomsService_OrderInfo(t *testing.T) {
func TestLomsService_OrderInfoBadInput(t *testing.T) {
t.Parallel()
svc := NewLomsService(
nil,
nil,
&mockTxManager{},
)
_, err := svc.OrderInfo(context.Background(), 0)
require.ErrorIs(t, err, model.ErrInvalidInput)
}
func TestLomsService_OrderInfoSuccess(t *testing.T) {
t.Parallel()
order := &entity.Order{
OrderID: 123,
Status: "payed",
UserID: 1337,
Items: []entity.OrderItem{},
}
mc := minimock.NewController(t)
svc := NewLomsService(
mock.NewOrderRepositoryMock(mc),
mock.NewStockRepositoryMock(mc),
mock.NewOrderRepositoryMock(mc).OrderGetByIDMock.Return(order, nil),
nil,
&mockTxManager{},
)
err := svc.OrderPay(context.Background(), 0)
require.ErrorIs(t, err, model.ErrInvalidInput)
gotOrder, err := svc.OrderInfo(context.Background(), 123)
require.NoError(t, err)
require.Equal(t, order, gotOrder)
}
func TestLomsService_OrderCancel(t *testing.T) {
@@ -370,7 +409,7 @@ func TestLomsService_OrderCancel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
svc := NewLomsService(tt.fields.orders, tt.fields.stocks)
svc := NewLomsService(tt.fields.orders, tt.fields.stocks, &mockTxManager{})
err := svc.OrderCancel(ctx, tt.args.id)
tt.wantErr(t, err)
})
@@ -437,7 +476,7 @@ func TestLomsService_StocksInfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
svc := NewLomsService(nil, tt.fields.stocks)
svc := NewLomsService(nil, tt.fields.stocks, &mockTxManager{})
got, err := svc.StocksInfo(ctx, tt.args.sku)
tt.wantErr(t, err)
if err == nil {

View File

@@ -0,0 +1,56 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/pkg/errors"
)
// From https://gitlab.ozon.dev/go/classroom-18/students/week-4-workshop/-/blob/master/internal/infra/postgres/postgres.go
func NewPool(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, errors.Wrap(err, "pgxpool.ParseConfig")
}
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, errors.Wrap(err, "pgxpool.NewWithConfig")
}
return pool, nil
}
func NewPools(ctx context.Context, DSNs ...string) ([]*pgxpool.Pool, error) {
pools := make([]*pgxpool.Pool, len(DSNs))
for i, dsn := range DSNs {
cfg, err := pgxpool.ParseConfig(dsn)
if err != nil {
closeOpened(pools[:i])
return nil, errors.Wrap(err, "pgxpool.ParseConfig")
}
pool, err := pgxpool.NewWithConfig(ctx, cfg)
if err != nil {
closeOpened(pools[:i])
return nil, errors.Wrap(err, "pgxpool.NewWithConfig")
}
pools[i] = pool
}
return pools, nil
}
func closeOpened(pools []*pgxpool.Pool) {
for _, p := range pools {
if p != nil {
p.Close()
}
}
}

View File

@@ -0,0 +1,87 @@
package postgres
// From https://gitlab.ozon.dev/go/classroom-18/students/week-4-workshop/-/blob/master/internal/infra/postgres/tx.go
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/opentracing/opentracing-go"
)
// Tx транзакция.
type Tx pgx.Tx
type txKey struct{}
func ctxWithTx(ctx context.Context, tx pgx.Tx) context.Context {
return context.WithValue(ctx, txKey{}, tx)
}
func TxFromCtx(ctx context.Context) (pgx.Tx, bool) {
tx, ok := ctx.Value(txKey{}).(pgx.Tx)
return tx, ok
}
type TxManager struct {
write *pgxpool.Pool
read *pgxpool.Pool
}
func NewTxManager(write, read *pgxpool.Pool) *TxManager {
return &TxManager{
write: write,
read: read,
}
}
// WithTransaction выполняет fn в транзакции с дефолтным уровнем изоляции.
func (m *TxManager) WriteWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return m.withTx(ctx, m.write, pgx.TxOptions{}, fn)
}
func (m *TxManager) ReadWithTransaction(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return m.withTx(ctx, m.read, pgx.TxOptions{}, fn)
}
// WithTransaction выполняет fn в транзакции с уровнем изоляции RepeatableRead.
func (m *TxManager) WriteWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return m.withTx(ctx, m.write, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, fn)
}
func (m *TxManager) ReadWithRepeatableRead(ctx context.Context, fn func(ctx context.Context) error) (err error) {
return m.withTx(ctx, m.read, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, fn)
}
// WithTx выполняет fn в транзакции.
func (m *TxManager) withTx(ctx context.Context, pool *pgxpool.Pool, options pgx.TxOptions, fn func(ctx context.Context) error) (err error) {
var span opentracing.Span
span, ctx = opentracing.StartSpanFromContext(ctx, "Transaction")
defer span.Finish()
tx, err := pool.BeginTx(ctx, options)
if err != nil {
return
}
ctx = ctxWithTx(ctx, tx)
defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
_ = tx.Rollback(ctx)
panic(p)
} else if err != nil {
// something went wrong, rollback
_ = tx.Rollback(ctx)
} else {
// all good, commit
err = tx.Commit(ctx)
}
}()
err = fn(ctx)
return
}

View File

@@ -0,0 +1,19 @@
package tools
import (
"fmt"
"math"
)
func SafeCastInt64ToUInt32(num int64) (uint32, error) {
if num < 0 {
return 0, fmt.Errorf("tried casting signed negative number to unsigned number")
}
if num > math.MaxUint32 {
return 0, fmt.Errorf("tried casting larger number than uint32 can store")
}
// the bounds are checked, and cast should be safe.
return uint32(num), nil // #nosec G115
}