pgtestproxy

package module
v0.1.2 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Oct 21, 2025 License: MIT Imports: 15 Imported by: 0

README

pgtestproxy

pgtestproxy is a Go library for triggering specific PostgreSQL errors in tests (it rewrites or skips messages at the PostgreSQL wire protocol level), which helps test PostgreSQL error handling. It supports the pgx driver, but some low-level features are not supported yet — for example, sending a Flush message in a custom pipeline.

Example Usage

// MIT No Attribution <https://spdx.org/licenses/MIT-0>

package pgtestproxy_test

import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"strconv"
	"sync/atomic"
	"syscall"
	"testing"
	"time"

	"codeberg.org/kerbrek/pgtestproxy"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
)

// TestQueries exercises selectIdWithRetry against a pgtestproxy that injects
// serialization failures (SQLSTATE 40001) into matching queries.
//
// The proxy and its RewriteRule are shared by the subtests, so the subtests
// must run sequentially: each one re-initializes the rule before querying.
func TestQueries(t *testing.T) { //nolint:tparallel
	t.Parallel()

	globalPgProxyRule := new(pgtestproxy.RewriteRule)
	pgProxyOpts := pgtestproxy.Options{
		Rule: globalPgProxyRule,
	}

	// proxy per test
	pool, proxy, err := setup(t.Context(), &pgProxyOpts)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer proxy.Shutdown()
	defer pool.Close()

	serializationFailureResp := pgtestproxy.NewErrorResponse("ERROR", serializationFailure, "serialization failure")

	// subtests must be executed sequentially (not in parallel)

	t.Run("test query failure with max retries", func(t *testing.T) {
		genericRegexpMatch := pgtestproxy.MustCreateRegexpMatch(t, pgtestproxy.CommonRegexp)
		genericMatcher := pgtestproxy.Matcher{MatchFn: genericRegexpMatch, ErrResp: serializationFailureResp}

		// Fail every attempt so the whole retry budget is used up.
		if err := globalPgProxyRule.Init(pgtestproxy.RewriteRuleData{
			Matchers: pgtestproxy.Repeat(genericMatcher, MaxTry),
		}); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		// Any error is not enough: it must specifically be the retry-budget error.
		_, err := selectIdWithRetry(t.Context(), pool, 1)
		if !errors.Is(err, ErrMaxRetriesExceeded) {
			t.Fatalf("expected %v, got %v", ErrMaxRetriesExceeded, err)
		}
	})

	t.Run("test query success with one retry", func(t *testing.T) {
		commentRegexpExpr := `(?is)^-- q#c3b6f009d9c03269`
		commentRegexpMatch := pgtestproxy.MustCreateRegexpMatch(t, commentRegexpExpr)
		commentMatcher := pgtestproxy.Matcher{MatchFn: commentRegexpMatch, ErrResp: serializationFailureResp}

		// Fail only the first attempt; the retry must then succeed.
		if err := globalPgProxyRule.Init(pgtestproxy.RewriteRuleData{
			Matchers: pgtestproxy.Repeat(commentMatcher, 1),
		}); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		if _, err := selectIdWithRetry(t.Context(), pool, 1); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})
}

const (
	// querySelect fetches a single item by id. Its leading comment line
	// ("-- q#c3b6f009d9c03269") serves as a tag: TestQueries' comment
	// matcher anchors on it to target exactly this query.
	querySelect = `
		-- q#c3b6f009d9c03269
		SELECT id FROM items WHERE id = $1
	`
	// serializationFailure is the PostgreSQL SQLSTATE code that
	// shouldRetry treats as retryable ("serialization failure").
	serializationFailure = "40001"
	// MaxTry is the total number of attempts (not extra retries) made by
	// selectIdWithRetry.
	MaxTry               = 2
)

// ErrMaxRetriesExceeded is reported by selectIdWithRetry when every one of
// the MaxTry attempts failed with a retryable error.
var ErrMaxRetriesExceeded = errors.New("max retries exceeded")

// shouldRetry reports whether err is a *pgconn.PgError (possibly wrapped)
// whose SQLSTATE is serializationFailure — the only error class this example
// retries. A nil error is never retried.
func shouldRetry(err error) bool {
	var pgErr *pgconn.PgError
	switch {
	case err == nil:
		return false
	case !errors.As(err, &pgErr):
		return false
	default:
		return pgErr.Code == serializationFailure
	}
}

// _selectId runs querySelect with id as its only argument and scans the
// single returned column into returnedId. Any error from the query or from
// Scan is returned unchanged.
func _selectId(ctx context.Context, pgxpool *pgxpool.Pool, id int) (returnedId int, err error) {
	row := pgxpool.QueryRow(ctx, querySelect, id)
	if err = row.Scan(&returnedId); err != nil {
		return returnedId, err
	}
	return returnedId, nil
}

// selectIdWithRetry runs _selectId up to MaxTry times, retrying only while
// shouldRetry accepts the error (serialization failures).
//
// The first non-retryable outcome (a result or an error) is returned as-is.
// If every attempt fails with a retryable error, the returned error matches
// ErrMaxRetriesExceeded via errors.Is and also wraps the last attempt's
// error, so callers can still inspect the underlying *pgconn.PgError.
func selectIdWithRetry(ctx context.Context, pgxpool *pgxpool.Pool, id int) (returnedId int, err error) { //nolint:unparam
	for range MaxTry {
		returnedId, err = _selectId(ctx, pgxpool, id)
		if !shouldRetry(err) {
			return returnedId, err
		}
	}

	// Keep the cause instead of discarding it; the id is meaningless on error.
	return 0, fmt.Errorf("%w: %w", ErrMaxRetriesExceeded, err)
}

// Credentials and base database name for the test PostgreSQL server;
// setup appends a unique numeric suffix to db for every call.
const db, user, password = "example", "user", "password"

// testIdx numbers the per-test databases created by setup. It is atomic
// because tests calling setup may run in parallel.
var testIdx atomic.Int32 //nolint:gochecknoglobals

// setup prepares an isolated database for one test and returns a pool
// connected to it.
//
// It clones template_db into a new database named db plus a unique numeric
// suffix. When pgProxyOpts is non-nil it also starts a pgtestproxy in front of
// PostgreSQL and points the pool at the proxy; otherwise proxy is nil and the
// pool connects to PostgreSQL directly.
//
// On success the caller owns (and must close / shut down) the pool and proxy.
// On failure, any pool or proxy created here is released before returning.
// The created database itself is not dropped on failure.
func setup(
	ctx context.Context, pgProxyOpts *pgtestproxy.Options,
) (pool *pgxpool.Pool, proxy *pgtestproxy.Server, err error) {
	pgPort, _, err := getPostgresPortsFromEnv()
	if err != nil {
		return nil, nil, err
	}

	// A unique database per call keeps parallel tests from interfering.
	dbName := db + strconv.Itoa(int(testIdx.Add(1)))

	err = createPostgresDbFromTemplate(ctx, dbName, pgPort)
	if err != nil {
		return nil, nil, err
	}

	cfg, err := pgxpool.ParseConfig(dsn(dbName, user, password, pgPort))
	if err != nil {
		return nil, nil, err
	}

	var pgProxy *pgtestproxy.Server
	if pgProxyOpts != nil {
		var pgProxyPort int
		pgProxy, pgProxyPort, err = startPostgresProxy(pgProxyOpts)
		if err != nil {
			return nil, nil, err
		}
		cfg.ConnConfig.Port = uint16(pgProxyPort) //nolint:gosec
	}

	// releaseProxy stops the proxy started above when a later step fails,
	// so a failed setup does not leave a listening proxy behind.
	releaseProxy := func() {
		if pgProxy != nil {
			pgProxy.Shutdown()
		}
	}

	dbpool, err := pgxpool.NewWithConfig(ctx, cfg)
	if err != nil {
		releaseProxy()
		return nil, nil, err
	}

	err = ping(ctx, dbpool, 3*time.Second)
	if err != nil {
		dbpool.Close()
		releaseProxy()
		return nil, nil, err
	}

	return dbpool, pgProxy, nil
}

// createPostgresDbFromTemplate creates database db as a copy of template_db.
// It connects through the "postgres" database on 127.0.0.1:port with the
// package-level user/password, waiting up to 3 seconds for it to respond.
func createPostgresDbFromTemplate(ctx context.Context, db string, port int) error {
	pdb, err := pgxpool.New(ctx, dsn("postgres", user, password, port))
	if err != nil {
		return fmt.Errorf("connecting to database %q: %w", "postgres", err)
	}
	defer pdb.Close()

	err = ping(ctx, pdb, 3*time.Second)
	if err != nil {
		return fmt.Errorf("pinging database %q: %w", "postgres", err)
	}

	// CREATE DATABASE cannot take bind parameters, so the name is interpolated.
	// Quoting it as an identifier keeps it verbatim (no lower-case folding),
	// matching the dbname the caller later passes to dsn.
	createDbQuery := fmt.Sprintf(`CREATE DATABASE %s TEMPLATE template_db`, quoteIdentifier(db))
	if _, err = pdb.Exec(ctx, createDbQuery); err != nil {
		return fmt.Errorf("creating database %q: %w", db, err)
	}

	return nil
}

// quoteIdentifier returns s as a double-quoted PostgreSQL identifier,
// doubling any embedded double quotes.
func quoteIdentifier(s string) string {
	quoted := make([]byte, 0, len(s)+2)
	quoted = append(quoted, '"')
	for i := 0; i < len(s); i++ {
		if s[i] == '"' {
			quoted = append(quoted, '"')
		}
		quoted = append(quoted, s[i])
	}
	return string(append(quoted, '"'))
}

// dsn builds a keyword/value connection string for the PostgreSQL server on
// 127.0.0.1:port with TLS disabled.
//
// The database, user and password values are single-quoted, and embedded
// quotes and backslashes are backslash-escaped. Values containing spaces,
// quotes or backslashes (or empty values) therefore reach the server intact
// instead of corrupting the connection string.
func dsn(db, user, password string, port int) string {
	return fmt.Sprintf(
		"dbname=%s user=%s password=%s host=127.0.0.1 port=%d sslmode=disable",
		quoteDSNValue(db), quoteDSNValue(user), quoteDSNValue(password), port,
	)
}

// quoteDSNValue single-quotes v for a keyword/value connection string,
// escaping ' as \' and \ as \\.
func quoteDSNValue(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\\' || c == '\'' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	return string(append(quoted, '\''))
}

// ping waits until pool can reach the database. It pings immediately and then
// retries every 10ms.
//
// The whole wait, including each individual Ping round trip, is bounded by
// timeout. It returns:
//   - nil on the first successful ping;
//   - ctx.Err() if the parent context is cancelled;
//   - otherwise, an error wrapping the last ping error once timeout elapses.
//
// It never reports success merely because the timeout expired.
func ping(ctx context.Context, pool *pgxpool.Pool, timeout time.Duration) error {
	// Bound every Ping with the overall deadline so a hanging connection
	// attempt cannot outlive the timeout.
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()

	for {
		err := pool.Ping(pingCtx)
		if err == nil { // if NO error
			return nil
		}

		select {
		case <-pingCtx.Done():
			// Distinguish cancellation by the caller from our own timeout.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return fmt.Errorf("database not reachable within %s: %w", timeout, err)
		case <-ticker.C:
		}
	}
}

// pgProxyPortDelta hands out proxy port offsets above the low port. It is a
// package-wide atomic counter, so proxies started by parallel tests get
// distinct ports.
var pgProxyPortDelta atomic.Int32 //nolint:gochecknoglobals

// startPostgresProxy starts a pgtestproxy.Server in front of the local
// PostgreSQL server (port from POSTGRES_PORT). The proxy listens on the next
// free port above TEST_POSTGRES_PROXY_LOW_PORT.
//
// It overwrites opts.PostgresAddr and opts.ProxyAddr, and returns the started
// proxy together with the port it listens on. A Start error matching
// syscall.EADDRINUSE makes it try the next port. It gives up with an error
// once the candidate port would exceed 65535.
func startPostgresProxy(opts *pgtestproxy.Options) (pgProxy *pgtestproxy.Server, pgProxyPort int, err error) {
	const maxPort = 65535

	pgPort, pgProxyLowPort, err := getPostgresPortsFromEnv()
	if err != nil {
		return nil, 0, err
	}

	localhost := "127.0.0.1"
	opts.PostgresAddr = net.JoinHostPort(localhost, strconv.Itoa(pgPort))

	for {
		pgProxyPort = pgProxyLowPort + int(pgProxyPortDelta.Add(1))
		// Without this bound the loop could run past the valid port range,
		// and the caller narrows the port to uint16.
		if pgProxyPort > maxPort {
			return nil, 0, fmt.Errorf("no free proxy port in range (%d, %d]", pgProxyLowPort, maxPort)
		}
		opts.ProxyAddr = net.JoinHostPort(localhost, strconv.Itoa(pgProxyPort))

		pgProxy, err = pgtestproxy.Start(*opts)
		if err == nil { // if NO error
			return pgProxy, pgProxyPort, nil
		}

		if !errors.Is(err, syscall.EADDRINUSE) {
			return nil, 0, err
		}
		// address already in use: try the next port
	}
}

// getPostgresPortsFromEnv reads two ports from the environment:
//   - POSTGRES_PORT: the PostgreSQL server port;
//   - TEST_POSTGRES_PROXY_LOW_PORT: the base port above which test proxies listen.
//
// Both must parse as integers in the TCP port range 1-65535. On any error,
// both returned ports are 0.
func getPostgresPortsFromEnv() (pgPort, pgProxyLowPort int, err error) {
	pgPort, err = portFromEnv("POSTGRES_PORT")
	if err != nil {
		return 0, 0, err
	}

	pgProxyLowPort, err = portFromEnv("TEST_POSTGRES_PROXY_LOW_PORT")
	if err != nil {
		return 0, 0, err
	}

	return pgPort, pgProxyLowPort, nil
}

// portFromEnv parses the environment variable key as a TCP port number (1-65535).
func portFromEnv(key string) (int, error) {
	port, err := strconv.Atoi(os.Getenv(key))
	if err != nil {
		return 0, fmt.Errorf("error parsing %s: %w", key, err)
	}
	// Ports are later narrowed to uint16; reject values that would wrap.
	if port < 1 || port > 65535 {
		return 0, fmt.Errorf("error parsing %s: port %d out of range [1, 65535]", key, port)
	}
	return port, nil
}

Commands

Inspect the Makefile contents before running these commands.

  • List all commands

    make help

  • Setup a working environment

    make _setup

  • Run linters

    make lint

  • Run tests

    make test

  • Run tests with coverage report

    make coverage

License

MIT

Documentation

Overview

Package pgtestproxy provides functionality for triggering specific PostgreSQL errors in tests.

Index

Constants

View Source
// CommonRegexp matches queries that start with SELECT, INSERT, UPDATE, DELETE,
// MERGE or WITH. Any run of leading "--" line comments, "/* */" block comments
// and whitespace before the keyword is skipped. The (?is) flags make it
// case-insensitive and let "." match newlines.
const CommonRegexp = `(?is)^(?:--[^\n]*\n|/\*.*?\*/|\s*)*(select|insert|update|delete|merge|with)`

Variables

This section is empty.

Functions

func Repeat

func Repeat[V any](val V, n int) []V

Types

type ErrorResponse

type ErrorResponse struct {
	// contains filtered or unexported fields
}

func NewErrorResponse

func NewErrorResponse(severity, code, msg string) *ErrorResponse

NewErrorResponse creates a custom error response. See the PostgreSQL documentation on Error Message Fields and Error Codes for more information.

Example:

NewErrorResponse("ERROR", "40001", "serialization failure")

func (*ErrorResponse) Reader

func (e *ErrorResponse) Reader() io.Reader

func (*ErrorResponse) String

func (e *ErrorResponse) String() string

type MatchFunc

type MatchFunc func(actualSql string) bool

MatchFunc reports whether the custom error response should be triggered, based on the actual SQL query string (with all leading and trailing white space removed).

func CreateEqualMatch

func CreateEqualMatch(expectedSql string) MatchFunc

CreateEqualMatch creates a MatchFunc that checks whether the actual SQL query is equal to the expected SQL query. Beware: equality on a query with placeholders does not work with the Simple Protocol, because the client replaces all placeholders with actual values before sending the query.

func MustCreateFingerprintMatch

func MustCreateFingerprintMatch(t *testing.T, expectedSql string) MatchFunc

MustCreateFingerprintMatch creates a MatchFunc that checks whether the fingerprints (see pgquery.Fingerprint) of the actual SQL query and the expected SQL query match.

func MustCreateNormalizedRegexpMatch

func MustCreateNormalizedRegexpMatch(t *testing.T, expr string) MatchFunc

MustCreateNormalizedRegexpMatch creates a MatchFunc that matches the normalized (see pgquery.Normalize) actual SQL query against the given regular expression.

func MustCreateRegexpMatch

func MustCreateRegexpMatch(t *testing.T, expr string) MatchFunc

MustCreateRegexpMatch creates a MatchFunc that matches the actual SQL query against the given regular expression.

type Matcher

// Matcher pairs a MatchFunc with the custom error response that is triggered
// when that MatchFunc matches an actual SQL query.
type Matcher struct {
	MatchFn MatchFunc
	ErrResp *ErrorResponse
}

type Options

// Options configures a proxy started with Start.
type Options struct {
	// PostgresAddr is the host:port of the upstream PostgreSQL server.
	PostgresAddr string
	// ProxyAddr is the host:port the proxy listens on; clients connect here.
	ProxyAddr    string
	// Rule holds the rewrite/skip rule; (re)initialize it with Init before each subtest.
	Rule         *RewriteRule
	// NOTE(review): presumably queries matching any of these are never
	// rewritten; confirm against the package source.
	Exclude      []MatchFunc       // optional
	// Log is an optional logging hook.
	Log          func(args ...any) // optional
}

type RewriteRule

type RewriteRule struct {
	// contains filtered or unexported fields
}

RewriteRule is a per-proxy internal store for the user-defined rule used to rewrite or skip PostgreSQL wire protocol messages.

func (*RewriteRule) Init

func (r *RewriteRule) Init(data RewriteRuleData) error

Init (re)initializes the RewriteRule with user-defined rule data. It must be called before each subsequent subtest (parallel subtests are not supported).

type RewriteRuleData

// RewriteRuleData lists which SQL queries receive a custom error response.
type RewriteRuleData struct {
	// Matchers are consumed as they match: a Matcher is removed from the
	// slice once its custom error has been triggered.
	Matchers    []Matcher
	// IgnoreOrder, when true, checks each query against every remaining
	// Matcher; when false, only the first remaining Matcher is checked.
	IgnoreOrder bool
}

RewriteRuleData is a rule for rewriting or skipping PostgreSQL wire protocol messages. When a Matcher matches an actual SQL query, its custom error is triggered and the Matcher is removed from the slice. By default (IgnoreOrder == false), each SQL query is checked only against the first Matcher in the slice. With IgnoreOrder == true, each SQL query is checked against every Matcher in the slice and matches if any of them does.

type Server

type Server struct {
	// contains filtered or unexported fields
}

Server represents a test proxy instance.

func Start

func Start(opts Options) (*Server, error)

Start creates a proxy server and starts it.

func (*Server) Shutdown

func (s *Server) Shutdown()

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL