Table of Contents
- Introduction
- Context Fundamentals
- Cancellation and Timeout Patterns
- Context Values and Best Practices
- FAQ
- Conclusion
Introduction
The context package is one of Go's most important tools for managing request-scoped data, cancellation signals, and timeouts across API boundaries. Added to the standard library in Go 1.7 and extended in later releases, it enables clean handling of request lifecycles, graceful shutdowns, and resource cleanup in concurrent applications.
Many developers struggle with context usage, either overusing it by passing context everywhere or underusing it by ignoring cancellation signals. Common challenges include understanding when to create new contexts, how to properly propagate cancellation, managing timeouts effectively, and avoiding context value anti-patterns.
In this comprehensive guide, we'll explore the context package from fundamental concepts to advanced patterns. You'll learn how to implement proper request scoping, handle cancellation gracefully, manage timeouts and deadlines, and use context values appropriately. These skills are essential for building production-ready Go applications that handle concurrent operations reliably.
Context Fundamentals
Understanding Context Interface
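The context package revolves around the Context interface, whose four methods are Deadline, Done, Err, and Value. The following program creates the two root contexts, Background and TODO, and derives cancelable and timed contexts from them: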
package main
import (
"context"
"fmt"
"time"
)
func demonstrateContextBasics() {
// Background context - root context for the application
ctx := context.Background()
fmt.Printf("Background context: %v\n", ctx)
// TODO context - placeholder when you're not sure what context to use
todoCtx := context.TODO()
fmt.Printf("TODO context: %v\n", todoCtx)
// Context with timeout
timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel() // Always call cancel to release resources
fmt.Printf("Timeout context: %v\n", timeoutCtx)
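// Deadline returns two values (time.Time, bool), so unpack them before printing
deadline, hasDeadline := timeoutCtx.Deadline()
fmt.Printf("Deadline: %v (set: %v)\n", deadline, hasDeadline)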
// Context with cancellation
cancelCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
fmt.Printf("Cancel context: %v\n", cancelCtx)
// Check if context is canceled
select {
case <-cancelCtx.Done():
fmt.Println("Context was canceled")
default:
fmt.Println("Context is still active")
}
// Cancel the context
cancelFunc()
// Check again
select {
case <-cancelCtx.Done():
fmt.Printf("Context canceled with error: %v\n", cancelCtx.Err())
default:
fmt.Println("Context is still active")
}
}
func demonstrateContextHierarchy() {
// Create a hierarchy of contexts
rootCtx := context.Background()
// Parent context with 5-second timeout
parentCtx, parentCancel := context.WithTimeout(rootCtx, 5*time.Second)
defer parentCancel()
// Child context with 3-second timeout (will expire first)
childCtx, childCancel := context.WithTimeout(parentCtx, 3*time.Second)
defer childCancel()
// Grandchild context that can be canceled manually
grandchildCtx, grandchildCancel := context.WithCancel(childCtx)
defer grandchildCancel()
fmt.Println("Context hierarchy created")
// Monitor all contexts
go func() {
<-parentCtx.Done()
fmt.Printf("Parent context done: %v\n", parentCtx.Err())
}()
go func() {
<-childCtx.Done()
fmt.Printf("Child context done: %v\n", childCtx.Err())
}()
go func() {
<-grandchildCtx.Done()
fmt.Printf("Grandchild context done: %v\n", grandchildCtx.Err())
}()
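// Wait past the child's 3-second timeout, which also ends the grandchild;
// the parent is canceled only afterwards, by the deferred parentCancel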
time.Sleep(4 * time.Second)
fmt.Println("Finished waiting")
}
func main() {
fmt.Println("Context basics:")
demonstrateContextBasics()
fmt.Println("\nContext hierarchy:")
demonstrateContextHierarchy()
}
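The hierarchy demo relies on timing and sleeps to show propagation. The underlying rules can also be checked deterministically. The following sketch (hierarchyFacts is an illustrative name, not part of the original) verifies three properties of the standard library's context package: canceling a parent cancels its children, canceling a child leaves its parent untouched, and a child's deadline can never be later than its parent's.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// hierarchyFacts builds three small context trees and reports how
// cancellation and deadlines flow between parents and children.
func hierarchyFacts() (childCanceledByParent, parentCanceledByChild, deadlineCapped bool) {
	// 1. Canceling a parent cancels every descendant.
	parent, cancelParent := context.WithCancel(context.Background())
	child, cancelChild := context.WithCancel(parent)
	defer cancelChild()
	cancelParent()
	<-child.Done() // closed because the parent was canceled
	childCanceledByParent = errors.Is(child.Err(), context.Canceled)

	// 2. Canceling a child leaves its parent untouched.
	parent2, cancelParent2 := context.WithCancel(context.Background())
	defer cancelParent2()
	child2, cancelChild2 := context.WithCancel(parent2)
	cancelChild2()
	<-child2.Done()
	parentCanceledByChild = parent2.Err() != nil

	// 3. Requesting 1s under a 100ms parent leaves the child with the
	//    parent's earlier deadline: a child cannot extend it.
	short, cancelShort := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancelShort()
	long, cancelLong := context.WithTimeout(short, time.Second)
	defer cancelLong()
	parentDeadline, _ := short.Deadline()
	childDeadline, _ := long.Deadline()
	deadlineCapped = childDeadline.Equal(parentDeadline)

	return childCanceledByParent, parentCanceledByChild, deadlineCapped
}

func main() {
	a, b, c := hierarchyFacts()
	fmt.Println("child canceled by parent:", a)        // true
	fmt.Println("parent canceled by child:", b)        // false
	fmt.Println("child deadline capped by parent:", c) // true
}
```

The third property is why the per-stage timeouts in the later multi-stage example can only shorten, never lengthen, the overall deadline.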
Context Creation Patterns
Different ways to create and configure contexts:
package main
import (
"context"
"fmt"
"time"
)
func demonstrateContextCreation() {
// 1. WithCancel - manual cancellation
cancelCtx, cancel := context.WithCancel(context.Background())
defer cancel()
// 2. WithTimeout - automatic cancellation after duration
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer timeoutCancel()
// 3. WithDeadline - automatic cancellation at specific time
deadline := time.Now().Add(2 * time.Second)
deadlineCtx, deadlineCancel := context.WithDeadline(context.Background(), deadline)
defer deadlineCancel()
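// 4. WithValue - context with attached data (a bare string key keeps this
// demo short; the typed keys defined later in this program avoid collisions)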
valueCtx := context.WithValue(context.Background(), "userID", 12345)
fmt.Println("Created different context types:")
fmt.Printf("Cancel context: %T\n", cancelCtx)
fmt.Printf("Timeout context: %T\n", timeoutCtx)
fmt.Printf("Deadline context: %T\n", deadlineCtx)
fmt.Printf("Value context: %T\n", valueCtx)
// Extract value from context
if userID, ok := valueCtx.Value("userID").(int); ok {
fmt.Printf("User ID from context: %d\n", userID)
}
}
func demonstrateContextChaining() {
// Chain multiple context modifications
baseCtx := context.Background()
// Add timeout
timeoutCtx, cancel1 := context.WithTimeout(baseCtx, 5*time.Second)
defer cancel1()
// Add user information
userCtx := context.WithValue(timeoutCtx, "userID", "user123")
// Add request ID
requestCtx := context.WithValue(userCtx, "requestID", "req456")
// Add cancellation capability
finalCtx, cancel2 := context.WithCancel(requestCtx)
defer cancel2()
// Use the final context
processRequest(finalCtx)
}
func processRequest(ctx context.Context) {
fmt.Println("Processing request with context:")
if userID, ok := ctx.Value("userID").(string); ok {
fmt.Printf(" User ID: %s\n", userID)
}
if requestID, ok := ctx.Value("requestID").(string); ok {
fmt.Printf(" Request ID: %s\n", requestID)
}
if deadline, ok := ctx.Deadline(); ok {
fmt.Printf(" Deadline: %v\n", deadline)
fmt.Printf(" Time remaining: %v\n", time.Until(deadline))
}
// Simulate some work
select {
case <-time.After(1 * time.Second):
fmt.Println(" Request completed successfully")
case <-ctx.Done():
fmt.Printf(" Request canceled: %v\n", ctx.Err())
}
}
// Context key type for type safety
type contextKey string
const (
UserIDKey contextKey = "userID"
RequestIDKey contextKey = "requestID"
SessionKey contextKey = "session"
)
func demonstrateTypedContextKeys() {
ctx := context.Background()
// Add values with typed keys
ctx = context.WithValue(ctx, UserIDKey, "user789")
ctx = context.WithValue(ctx, RequestIDKey, "req101")
ctx = context.WithValue(ctx, SessionKey, map[string]interface{}{
"authenticated": true,
"role": "admin",
})
// Extract values safely
if userID, ok := ctx.Value(UserIDKey).(string); ok {
fmt.Printf("User ID: %s\n", userID)
}
if requestID, ok := ctx.Value(RequestIDKey).(string); ok {
fmt.Printf("Request ID: %s\n", requestID)
}
if session, ok := ctx.Value(SessionKey).(map[string]interface{}); ok {
fmt.Printf("Session: %v\n", session)
}
}
func main() {
fmt.Println("Context creation patterns:")
demonstrateContextCreation()
fmt.Println("\nContext chaining:")
demonstrateContextChaining()
fmt.Println("\nTyped context keys:")
demonstrateTypedContextKeys()
}
Cancellation and Timeout Patterns
Implementing Graceful Cancellation
Proper cancellation handling in Go applications:
package main
import (
"context"
"fmt"
"sync"
"time"
)
// Worker that respects context cancellation
func cancellableWorker(ctx context.Context, id int, work <-chan int, results chan<- string) {
for {
select {
case <-ctx.Done():
fmt.Printf("Worker %d: Canceled (%v)\n", id, ctx.Err())
return
case job, ok := <-work:
if !ok {
fmt.Printf("Worker %d: No more work\n", id)
return
}
// Simulate work with cancellation checking
if err := doWorkWithCancellation(ctx, job); err != nil {
fmt.Printf("Worker %d: Job %d failed (%v)\n", id, job, err)
continue
}
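// Give up on delivering the result if the context ends first, so the
// worker never blocks forever on a full results channel
select {
case results <- fmt.Sprintf("Worker %d completed job %d", id, job):
case <-ctx.Done():
fmt.Printf("Worker %d: Canceled (%v)\n", id, ctx.Err())
return
}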
}
}
}
func doWorkWithCancellation(ctx context.Context, job int) error {
// Simulate work that can be interrupted
workDuration := time.Duration(job) * 100 * time.Millisecond
select {
case <-time.After(workDuration):
return nil // Work completed
case <-ctx.Done():
return ctx.Err() // Work was canceled
}
}
func demonstrateWorkerCancellation() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
work := make(chan int, 10)
results := make(chan string, 10)
// Start workers
const numWorkers = 3
var wg sync.WaitGroup
for i := 1; i <= numWorkers; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
cancellableWorker(ctx, id, work, results)
}(i)
}
// Send work
go func() {
defer close(work)
for i := 1; i <= 20; i++ {
select {
case work <- i:
fmt.Printf("Sent job %d\n", i)
case <-ctx.Done():
fmt.Println("Stopped sending work due to cancellation")
return
}
}
}()
// Collect results
go func() {
defer close(results)
wg.Wait()
}()
// Drain results until every worker has exited. Workers stop on
// cancellation, so this loop ends shortly after the timeout and no
// worker goroutine outlives this function.
for result := range results {
fmt.Printf("Result: %s\n", result)
}
if err := ctx.Err(); err != nil {
fmt.Printf("Context canceled: %v\n", err)
} else {
fmt.Println("All workers finished")
}
}
// HTTP client with context cancellation
func httpClientWithCancellation(ctx context.Context, url string) error {
// Simulated HTTP call; a real client would attach ctx to the request with http.NewRequestWithContext
fmt.Printf("Starting HTTP request to %s\n", url)
// Simulate network delay
requestDuration := 1500 * time.Millisecond
select {
case <-time.After(requestDuration):
fmt.Printf("HTTP request to %s completed\n", url)
return nil
case <-ctx.Done():
fmt.Printf("HTTP request to %s canceled: %v\n", url, ctx.Err())
return ctx.Err()
}
}
func demonstrateHTTPCancellation() {
// Short timeout to force cancellation
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
urls := []string{
"https://api.example.com/users",
"https://api.example.com/orders",
"https://api.example.com/products",
}
var wg sync.WaitGroup
for _, url := range urls {
wg.Add(1)
go func(u string) {
defer wg.Done()
httpClientWithCancellation(ctx, u)
}(url)
}
wg.Wait()
fmt.Println("All HTTP requests completed or canceled")
}
// Database operation with context
func databaseOperationWithContext(ctx context.Context, query string) error {
fmt.Printf("Executing query: %s\n", query)
// Simulate database operation
operationTime := 800 * time.Millisecond
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
timeout := time.After(operationTime)
for {
select {
case <-timeout:
fmt.Printf("Query completed: %s\n", query)
return nil
case <-ticker.C:
// Progress report only: the ctx.Done case in this same select already
// detects cancellation on every loop iteration
fmt.Printf("Query in progress: %s\n", query)
case <-ctx.Done():
fmt.Printf("Query canceled: %s (%v)\n", query, ctx.Err())
return ctx.Err()
}
}
}
func demonstrateDatabaseCancellation() {
ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
defer cancel()
queries := []string{
"SELECT * FROM users WHERE active = true",
"UPDATE orders SET status = 'processed' WHERE date < '2023-01-01'",
"INSERT INTO logs (message, timestamp) VALUES ('Operation completed', NOW())",
}
for _, query := range queries {
if err := databaseOperationWithContext(ctx, query); err != nil {
fmt.Printf("Query failed: %v\n", err)
break // Stop executing remaining queries
}
}
}
func main() {
fmt.Println("Worker cancellation:")
demonstrateWorkerCancellation()
fmt.Println("\nHTTP request cancellation:")
demonstrateHTTPCancellation()
fmt.Println("\nDatabase operation cancellation:")
demonstrateDatabaseCancellation()
}
Advanced Timeout Handling
Sophisticated timeout and deadline management:
package main
import (
"context"
"fmt"
"sync"
"time"
)
// Multi-stage operation with a separate timeout per stage; each stage's
// timeout is also capped by any earlier deadline already set on ctx
func multiStageOperation(ctx context.Context) error {
fmt.Println("Starting multi-stage operation")
// Stage 1: Authentication (500ms timeout)
authCtx, authCancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer authCancel()
if err := authenticateUser(authCtx); err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
// Stage 2: Data fetching (1s timeout)
fetchCtx, fetchCancel := context.WithTimeout(ctx, 1*time.Second)
defer fetchCancel()
if err := fetchUserData(fetchCtx); err != nil {
return fmt.Errorf("data fetch failed: %w", err)
}
// Stage 3: Processing (2s timeout)
processCtx, processCancel := context.WithTimeout(ctx, 2*time.Second)
defer processCancel()
if err := processData(processCtx); err != nil {
return fmt.Errorf("processing failed: %w", err)
}
fmt.Println("Multi-stage operation completed successfully")
return nil
}
func authenticateUser(ctx context.Context) error {
select {
case <-time.After(300 * time.Millisecond):
fmt.Println(" User authenticated")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func fetchUserData(ctx context.Context) error {
select {
case <-time.After(800 * time.Millisecond):
fmt.Println(" User data fetched")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func processData(ctx context.Context) error {
select {
case <-time.After(1500 * time.Millisecond):
fmt.Println(" Data processed")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Retry mechanism with exponential backoff and context
func retryWithContext(ctx context.Context, operation func() error, maxRetries int) error {
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
// Check if context is already canceled
if err := ctx.Err(); err != nil {
return err
}
// Try the operation
if err := operation(); err != nil {
lastErr = err
fmt.Printf("Attempt %d failed: %v\n", attempt+1, err)
// Exponential backoff: 100ms, 200ms, 400ms, ... (doubling per attempt)
backoff := time.Duration(1<<attempt) * 100 * time.Millisecond
// Wait with context cancellation
select {
case <-time.After(backoff):
continue
case <-ctx.Done():
return ctx.Err()
}
}
return nil // Success
}
return fmt.Errorf("operation failed after %d attempts: %w", maxRetries, lastErr)
}
func unreliableOperation() error {
// Simulate 70% failure rate
if time.Now().UnixNano()%10 < 7 {
return fmt.Errorf("simulated failure")
}
return nil
}
func demonstrateRetryWithTimeout() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := retryWithContext(ctx, unreliableOperation, 5)
if err != nil {
fmt.Printf("Retry operation failed: %v\n", err)
} else {
fmt.Println("Retry operation succeeded")
}
}
// Circuit breaker pattern with context
type CircuitBreaker struct {
mutex sync.Mutex
failures int
maxFailures int
timeout time.Duration
nextAttempt time.Time
state string // "closed", "open", "half-open"
}
func NewCircuitBreaker(maxFailures int, timeout time.Duration) *CircuitBreaker {
return &CircuitBreaker{
maxFailures: maxFailures,
timeout: timeout,
state: "closed",
}
}
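func (cb *CircuitBreaker) Execute(ctx context.Context, operation func(context.Context) error) error {
// Simplification: the mutex stays locked while operation runs, so concurrent
// callers are serialized; a production breaker would release it during the call
cb.mutex.Lock()
defer cb.mutex.Unlock()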
switch cb.state {
case "open":
if time.Now().Before(cb.nextAttempt) {
return fmt.Errorf("circuit breaker is open")
}
cb.state = "half-open"
fallthrough
case "half-open", "closed":
err := operation(ctx)
if err != nil {
cb.failures++
if cb.failures >= cb.maxFailures {
cb.state = "open"
cb.nextAttempt = time.Now().Add(cb.timeout)
fmt.Printf("Circuit breaker opened after %d failures\n", cb.failures)
}
return err
}
// Success - reset or keep closed
if cb.state == "half-open" {
cb.state = "closed"
fmt.Println("Circuit breaker closed")
}
cb.failures = 0
return nil
}
return fmt.Errorf("unknown circuit breaker state: %s", cb.state)
}
func demonstrateCircuitBreaker() {
cb := NewCircuitBreaker(3, 1*time.Second)
flakyService := func(ctx context.Context) error {
select {
case <-time.After(100 * time.Millisecond):
// Simulate 80% failure rate
if time.Now().UnixNano()%10 < 8 {
return fmt.Errorf("service unavailable")
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
ctx := context.Background()
// Make several calls to trigger circuit breaker
for i := 1; i <= 10; i++ {
callCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
err := cb.Execute(callCtx, flakyService)
if err != nil {
fmt.Printf("Call %d failed: %v\n", i, err)
} else {
fmt.Printf("Call %d succeeded\n", i)
}
cancel()
time.Sleep(200 * time.Millisecond)
}
}
func main() {
fmt.Println("Multi-stage operation with timeouts:")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := multiStageOperation(ctx); err != nil {
fmt.Printf("Operation failed: %v\n", err)
}
fmt.Println("\nRetry with context:")
demonstrateRetryWithTimeout()
fmt.Println("\nCircuit breaker with context:")
demonstrateCircuitBreaker()
}
Context Values and Best Practices
Proper Context Value Usage
Guidelines for using context values appropriately:
package main
import (
"context"
"fmt"
"log"
"time"
)
// Define typed keys for context values
type ctxKey int
const (
requestIDKey ctxKey = iota
userIDKey
correlationIDKey
traceIDKey
)
// RequestContext wraps common request-scoped data
type RequestContext struct {
RequestID string
UserID string
CorrelationID string
TraceID string
StartTime time.Time
}
// NewRequestContext creates a context with request metadata
func NewRequestContext(ctx context.Context, requestID, userID string) context.Context {
requestCtx := &RequestContext{
RequestID: requestID,
UserID: userID,
CorrelationID: generateCorrelationID(),
TraceID: generateTraceID(),
StartTime: time.Now(),
}
ctx = context.WithValue(ctx, requestIDKey, requestCtx.RequestID)
ctx = context.WithValue(ctx, userIDKey, requestCtx.UserID)
ctx = context.WithValue(ctx, correlationIDKey, requestCtx.CorrelationID)
ctx = context.WithValue(ctx, traceIDKey, requestCtx.TraceID)
return ctx
}
func generateCorrelationID() string {
return fmt.Sprintf("corr-%d", time.Now().UnixNano())
}
func generateTraceID() string {
return fmt.Sprintf("trace-%d", time.Now().UnixNano())
}
// Helper functions to extract values safely
func GetRequestID(ctx context.Context) string {
if id, ok := ctx.Value(requestIDKey).(string); ok {
return id
}
return "unknown"
}
func GetUserID(ctx context.Context) string {
if id, ok := ctx.Value(userIDKey).(string); ok {
return id
}
return "anonymous"
}
func GetCorrelationID(ctx context.Context) string {
if id, ok := ctx.Value(correlationIDKey).(string); ok {
return id
}
return "unknown"
}
func GetTraceID(ctx context.Context) string {
if id, ok := ctx.Value(traceIDKey).(string); ok {
return id
}
return "unknown"
}
// Structured logger that uses context values
type ContextLogger struct {
ctx context.Context
}
func NewContextLogger(ctx context.Context) *ContextLogger {
return &ContextLogger{ctx: ctx}
}
func (cl *ContextLogger) Info(message string) {
log.Printf("[INFO] [req:%s] [user:%s] [trace:%s] %s",
GetRequestID(cl.ctx),
GetUserID(cl.ctx),
GetTraceID(cl.ctx),
message)
}
func (cl *ContextLogger) Error(message string, err error) {
log.Printf("[ERROR] [req:%s] [user:%s] [trace:%s] %s: %v",
GetRequestID(cl.ctx),
GetUserID(cl.ctx),
GetTraceID(cl.ctx),
message, err)
}
func demonstrateContextValues() {
ctx := context.Background()
// Create request context
requestCtx := NewRequestContext(ctx, "req-123", "user-456")
// Use throughout the request lifecycle
processUserRequest(requestCtx)
}
func processUserRequest(ctx context.Context) {
logger := NewContextLogger(ctx)
logger.Info("Processing user request")
// Simulate some processing steps
if err := validateUser(ctx); err != nil {
logger.Error("User validation failed", err)
return
}
if err := doProcessing(ctx); err != nil { // doProcessing is defined later in this program
logger.Error("Data processing failed", err)
return
}
logger.Info("Request processed successfully")
}
func validateUser(ctx context.Context) error {
logger := NewContextLogger(ctx)
userID := GetUserID(ctx)
logger.Info(fmt.Sprintf("Validating user: %s", userID))
// Simulate validation
time.Sleep(100 * time.Millisecond)
if userID == "user-456" {
logger.Info("User validation successful")
return nil
}
return fmt.Errorf("invalid user: %s", userID)
}
// Anti-pattern examples (what NOT to do)
func demonstrateAntiPatterns() {
fmt.Println("\n=== Context Anti-Patterns (DON'T DO THIS) ===")
// DON'T: Store large objects in context
largeData := make([]byte, 1024*1024) // 1MB
badCtx := context.WithValue(context.Background(), "largeData", largeData)
fmt.Printf("Bad: Stored %d bytes in context\n", len(largeData))
_ = badCtx
// DON'T: Use context for optional parameters
badConfigCtx := context.WithValue(context.Background(), "config", map[string]interface{}{
"timeout": 30,
"retries": 3,
"debug": true,
})
_ = badConfigCtx
// DON'T: Pass a nil context. Calling any method on it panics; use
// context.TODO() when no real context is available yet
fmt.Println("See code comments for anti-patterns to avoid")
}
// Good patterns for context usage
type ServiceConfig struct {
Timeout time.Duration
MaxRetries int
Debug bool
}
// GOOD: Pass configuration as parameters, not context values
func processWithConfig(ctx context.Context, config ServiceConfig) error {
logger := NewContextLogger(ctx)
logger.Info("Processing with configuration")
// Use the config explicitly
timeoutCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
return doProcessing(timeoutCtx)
}
func doProcessing(ctx context.Context) error {
select {
case <-time.After(50 * time.Millisecond):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// GOOD: Use context for request-scoped data only
type AuthInfo struct {
UserID string
Permissions []string
Token string
}
const authInfoKey ctxKey = 999
func WithAuth(ctx context.Context, auth AuthInfo) context.Context {
return context.WithValue(ctx, authInfoKey, auth)
}
func GetAuth(ctx context.Context) (AuthInfo, bool) {
auth, ok := ctx.Value(authInfoKey).(AuthInfo)
return auth, ok
}
func demonstrateGoodPatterns() {
ctx := context.Background()
// Add authentication information
auth := AuthInfo{
UserID: "user-123",
Permissions: []string{"read", "write"},
Token: "bearer-token-xyz",
}
authCtx := WithAuth(ctx, auth)
// Add request tracing
requestCtx := NewRequestContext(authCtx, "req-789", auth.UserID)
// Pass configuration explicitly, not through context
config := ServiceConfig{
Timeout: 2 * time.Second,
MaxRetries: 3,
Debug: true,
}
if err := processWithConfig(requestCtx, config); err != nil {
logger := NewContextLogger(requestCtx)
logger.Error("Processing failed", err)
}
// Demonstrate auth extraction
if extractedAuth, ok := GetAuth(requestCtx); ok {
fmt.Printf("User %s has permissions: %v\n",
extractedAuth.UserID, extractedAuth.Permissions)
}
}
func main() {
fmt.Println("Context values demonstration:")
demonstrateContextValues()
demonstrateAntiPatterns()
fmt.Println("\nGood context patterns:")
demonstrateGoodPatterns()
}
Context Middleware and Interceptors
Building reusable context-aware middleware:
package main
import (
"context"
"fmt"
"log"
"sync"
"time"
)
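// NewRequestContext, GetRequestID, GetUserID, NewContextLogger and the ctxKey
// constants they use come from the previous example; copy them into this file to run it
// Middleware type definition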
type Middleware func(next HandlerFunc) HandlerFunc
type HandlerFunc func(ctx context.Context) error
// Request metrics middleware
func MetricsMiddleware() Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(ctx context.Context) error {
start := time.Now()
requestID := GetRequestID(ctx)
log.Printf("[METRICS] Request %s started", requestID)
err := next(ctx)
duration := time.Since(start)
status := "success"
if err != nil {
status = "error"
}
log.Printf("[METRICS] Request %s completed in %v (status: %s)",
requestID, duration, status)
return err
}
}
}
// Timeout middleware
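func TimeoutMiddleware(timeout time.Duration) Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(ctx context.Context) error {
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Buffered channels let the handler goroutine finish even after we have
// returned on timeout (the handler keeps running until it notices that
// timeoutCtx is done)
done := make(chan error, 1)
panics := make(chan interface{}, 1)
go func() {
// A panic in this goroutine cannot be caught by RecoveryMiddleware,
// which runs on the caller's goroutine, so forward it instead
defer func() {
if r := recover(); r != nil {
panics <- r
}
}()
done <- next(timeoutCtx)
}()
select {
case err := <-done:
return err
case p := <-panics:
panic(p) // re-raise on the caller's goroutine
case <-timeoutCtx.Done():
return timeoutCtx.Err()
}
}
}
}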
// Rate limiting middleware
type RateLimiter struct {
requests map[string][]time.Time
mutex sync.Mutex
limit int
window time.Duration
}
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
return &RateLimiter{
requests: make(map[string][]time.Time),
limit: limit,
window: window,
}
}
func (rl *RateLimiter) isAllowed(userID string) bool {
rl.mutex.Lock()
defer rl.mutex.Unlock()
now := time.Now()
cutoff := now.Add(-rl.window)
// Clean old requests
requests := rl.requests[userID]
var validRequests []time.Time
for _, reqTime := range requests {
if reqTime.After(cutoff) {
validRequests = append(validRequests, reqTime)
}
}
// Check if under limit
if len(validRequests) >= rl.limit {
rl.requests[userID] = validRequests
return false
}
// Add current request
validRequests = append(validRequests, now)
rl.requests[userID] = validRequests
return true
}
func RateLimitMiddleware(limiter *RateLimiter) Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(ctx context.Context) error {
userID := GetUserID(ctx)
if !limiter.isAllowed(userID) {
return fmt.Errorf("rate limit exceeded for user %s", userID)
}
return next(ctx)
}
}
}
// Authentication middleware
func AuthMiddleware() Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(ctx context.Context) error {
userID := GetUserID(ctx)
if userID == "anonymous" {
return fmt.Errorf("authentication required")
}
// Simulate auth validation
if userID == "invalid-user" {
return fmt.Errorf("invalid user credentials")
}
log.Printf("[AUTH] User %s authenticated successfully", userID)
return next(ctx)
}
}
}
// Logging middleware
func LoggingMiddleware() Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(ctx context.Context) error {
logger := NewContextLogger(ctx)
logger.Info("Request received")
err := next(ctx)
if err != nil {
logger.Error("Request failed", err)
} else {
logger.Info("Request completed successfully")
}
return err
}
}
}
// Recovery middleware
func RecoveryMiddleware() Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
logger := NewContextLogger(ctx)
logger.Error("Panic recovered", fmt.Errorf("panic: %v", r))
err = fmt.Errorf("internal server error")
}
}()
return next(ctx)
}
}
}
// Chain multiple middlewares
func Chain(middlewares ...Middleware) Middleware {
return func(next HandlerFunc) HandlerFunc {
for i := len(middlewares) - 1; i >= 0; i-- {
next = middlewares[i](next)
}
return next
}
}
// Sample business logic handlers
func userProfileHandler(ctx context.Context) error {
time.Sleep(100 * time.Millisecond) // Simulate processing
userID := GetUserID(ctx)
fmt.Printf("Fetched profile for user: %s\n", userID)
return nil
}
func dataProcessingHandler(ctx context.Context) error {
time.Sleep(200 * time.Millisecond) // Simulate processing
fmt.Println("Data processing completed")
return nil
}
func panicHandler(ctx context.Context) error {
panic("simulated panic for testing recovery")
}
func demonstrateMiddlewareChain() {
rateLimiter := NewRateLimiter(3, time.Minute)
// Create middleware chain
middleware := Chain(
RecoveryMiddleware(),
MetricsMiddleware(),
LoggingMiddleware(),
TimeoutMiddleware(1*time.Second),
AuthMiddleware(),
RateLimitMiddleware(rateLimiter),
)
// Test cases
testCases := []struct {
name string
userID string
handler HandlerFunc
}{
{"Valid User Profile", "user-123", userProfileHandler},
{"Valid Data Processing", "user-456", dataProcessingHandler},
{"Anonymous User", "anonymous", userProfileHandler},
{"Invalid User", "invalid-user", userProfileHandler},
{"Panic Handler", "user-789", panicHandler},
}
for _, tc := range testCases {
fmt.Printf("\n=== Testing: %s ===\n", tc.name)
ctx := context.Background()
ctx = NewRequestContext(ctx, fmt.Sprintf("req-%d", time.Now().UnixNano()), tc.userID)
// Wrap handler with middleware
wrappedHandler := middleware(tc.handler)
if err := wrappedHandler(ctx); err != nil {
fmt.Printf("Handler error: %v\n", err)
}
}
}
// Demonstrate rate limiting
func demonstrateRateLimiting() {
fmt.Println("\n=== Rate Limiting Test ===")
rateLimiter := NewRateLimiter(2, 5*time.Second) // 2 requests per 5 seconds
middleware := RateLimitMiddleware(rateLimiter)
handler := middleware(userProfileHandler)
userID := "rate-test-user"
// Make multiple requests
for i := 1; i <= 5; i++ {
ctx := NewRequestContext(context.Background(),
fmt.Sprintf("req-%d", i), userID)
fmt.Printf("Request %d: ", i)
if err := handler(ctx); err != nil {
fmt.Printf("Failed - %v\n", err)
} else {
fmt.Println("Success")
}
time.Sleep(500 * time.Millisecond)
}
}
func main() {
demonstrateMiddlewareChain()
demonstrateRateLimiting()
}
FAQ
Q: When should I use context.Background() vs context.TODO()? A: Use context.Background() as the root context for your application, typically in main(), init(), or tests. Use context.TODO() when you're unsure what context to use or as a placeholder during development.
Q: Should I pass context as the first parameter to every function? A: Not every function needs context. Only pass context to functions that perform I/O operations, may need cancellation, or require request-scoped data. Don't pass context to pure functions or simple utilities. When a function does accept a context, make it the first parameter and name it ctx, following the standard library's convention.
Q: How do I handle context cancellation in long-running operations? A: Check ctx.Done() periodically in loops or long operations. Use select statements to check for cancellation alongside your main work. Consider breaking large operations into smaller chunks.
Q: Is it safe to store database connections in context? A: No, avoid storing connection objects in context. Context values should be request-scoped data like user IDs, trace IDs, or auth tokens. Pass connections as explicit parameters or use dependency injection.
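Q: What happens if I don't call the cancel function returned by WithCancel/WithTimeout? A: The context's resources leak until its parent is canceled or its deadline passes: it stays registered with the parent, a WithTimeout or WithDeadline context keeps its timer, and in some cases a goroutine stays alive as well. go vet's lostcancel check reports this mistake. Always defer cancel() immediately after creating the context, even if you never cancel it explicitly; calling cancel after the work has finished is harmless.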
Q: Can I modify context values after creation? A: No, context values are immutable. To "modify" values, create a new context with context.WithValue() using the existing context as the parent.
Conclusion
The Go context package is essential for building robust, concurrent applications that handle request lifecycles gracefully. Key insights from this comprehensive guide include:
- Understanding context creation patterns and proper hierarchy management
- Implementing graceful cancellation and timeout handling across your application
- Using context values appropriately for request-scoped data while avoiding anti-patterns
- Building reusable middleware that leverages context for cross-cutting concerns
- Following best practices for context propagation and resource cleanup
- Designing applications that respect cancellation signals and handle timeouts elegantly
Context enables you to build applications that are responsive to cancellation, respect timeouts, and maintain request scope throughout complex operations. By mastering these patterns, you can create Go applications that handle concurrent operations reliably and provide excellent user experiences even under load or failure conditions.
Ready to improve your context usage? Start by auditing your current code for missing context propagation and timeout handling. Implement proper cancellation checking in long-running operations and consider adding middleware for common cross-cutting concerns. Share your experiences and questions in the comments below!