package sqlc

import (
	"context"
	"errors"
	"sort"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/rs/zerolog/log"

	"route256/comments/internal/domain/entity"
	"route256/comments/internal/domain/model"
)

// commentsRepo stores comments across two database shards.
// Inserts and SKU-scoped reads are routed to a single shard by SKU (see pickShard).
// Lookups and updates by comment ID, and listings by user ID, consult both shards
// because the owning shard cannot be derived from those keys here.
type commentsRepo struct {
	shard1 *pgxpool.Pool
	shard2 *pgxpool.Pool
}

// NewCommentsRepository builds a repository over the two shard pools.
func NewCommentsRepository(shard1, shard2 *pgxpool.Pool) *commentsRepo {
	return &commentsRepo{
		shard1: shard1,
		shard2: shard2,
	}
}

// pickShard routes a SKU to its shard: even SKUs go to shard1, every other
// value (odd, including negative odd, for which sku%2 == -1) goes to shard2.
func (r *commentsRepo) pickShard(sku int64) *pgxpool.Pool {
	if sku%2 == 0 {
		return r.shard1
	}
	return r.shard2
}

// queryShardsByID runs op against shard1 and, only if it reports pgx.ErrNoRows,
// against shard2. It returns the first row found. A miss on both shards is
// reported as model.ErrCommentNotFound; any other error is returned as-is
// without consulting the remaining shard.
// id is used only for the trace log message.
func (r *commentsRepo) queryShardsByID(id int64, op func(shard *pgxpool.Pool) (*Comment, error)) (*Comment, error) {
	c, err := op(r.shard1)
	switch {
	case err == nil:
		return c, nil
	case errors.Is(err, pgx.ErrNoRows):
		log.Trace().Msgf("comment with id %d not found in shard 1", id)
	default:
		return nil, err
	}

	c, err = op(r.shard2)
	switch {
	case err == nil:
		return c, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, model.ErrCommentNotFound
	default:
		return nil, err
	}
}

// GetCommentByID fetches a comment by ID, checking shard1 first and then shard2.
// Returns model.ErrCommentNotFound if neither shard has the row.
func (r *commentsRepo) GetCommentByID(ctx context.Context, id int64) (*Comment, error) {
	return r.queryShardsByID(id, func(shard *pgxpool.Pool) (*Comment, error) {
		return New(shard).GetCommentByID(ctx, id)
	})
}

// InsertComment writes a new comment into the shard selected by its SKU and
// returns the stored row.
func (r *commentsRepo) InsertComment(ctx context.Context, comment *entity.Comment) (*Comment, error) {
	q := New(r.pickShard(comment.SKU))
	return q.InsertComment(ctx, &InsertCommentParams{
		UserID: comment.UserID,
		Sku:    comment.SKU,
		Text:   comment.Text,
	})
}

// ListCommentsBySku returns all comments for a SKU from the shard that owns it.
// The result is a freshly allocated slice that is non-nil even when there are
// no rows (the copy below guarantees this regardless of what the query returns).
func (r *commentsRepo) ListCommentsBySku(ctx context.Context, sku int64) ([]*Comment, error) {
	q := New(r.pickShard(sku))
	list, err := q.ListCommentsBySku(ctx, sku)
	if err != nil {
		return nil, err
	}
	out := make([]*Comment, len(list))
	copy(out, list)
	return out, nil
}

// ListCommentsByUser collects a user's comments from both shards (queried one
// after the other; the first error aborts) and returns them newest first.
// Every row belongs to the same user, so equal timestamps are ordered by SKU
// ascending and then by ID ascending. This makes the order deterministic
// even though sort.Slice is not stable.
// NOTE(review): confirm the SKU/ID tie-break matches the API's ordering spec.
func (r *commentsRepo) ListCommentsByUser(ctx context.Context, userID int64) ([]*Comment, error) {
	l1, err := New(r.shard1).ListCommentsByUser(ctx, userID)
	if err != nil {
		return nil, err
	}
	l2, err := New(r.shard2).ListCommentsByUser(ctx, userID)
	if err != nil {
		return nil, err
	}

	merged := make([]*Comment, 0, len(l1)+len(l2))
	merged = append(merged, l1...)
	merged = append(merged, l2...)

	sort.Slice(merged, func(i, j int) bool {
		a, b := merged[i], merged[j]
		if !a.CreatedAt.Time.Equal(b.CreatedAt.Time) {
			return a.CreatedAt.Time.After(b.CreatedAt.Time)
		}
		// All rows share userID, so comparing UserID (as before) never broke ties.
		if a.Sku != b.Sku {
			return a.Sku < b.Sku
		}
		return a.ID < b.ID
	})
	return merged, nil
}

// UpdateComment replaces the text of the comment with comment.ID, trying shard1
// first and then shard2. Returns model.ErrCommentNotFound if neither shard
// updated a row.
func (r *commentsRepo) UpdateComment(ctx context.Context, comment *entity.Comment) (*Comment, error) {
	req := &UpdateCommentParams{
		ID:   comment.ID,
		Text: comment.Text,
	}
	return r.queryShardsByID(req.ID, func(shard *pgxpool.Pool) (*Comment, error) {
		return New(shard).UpdateComment(ctx, req)
	})
}