
Go Context Package for Request Scoping and Cancellation 2025

Master Go's context package with a comprehensive guide covering request scoping, cancellation patterns, timeout handling, value passing, and advanced context management techniques for Go applications in 2025.


Introduction

The context package is Go's standard mechanism for carrying deadlines, cancellation signals, and request-scoped values across API boundaries and between goroutines. It was added to the standard library in Go 1.7 and has been extended since (for example, WithCancelCause in Go 1.20, and WithoutCancel and AfterFunc in Go 1.21). Used consistently, it ties together request lifecycles, graceful shutdowns, and resource cleanup in concurrent applications.

Many developers struggle with context usage, either overusing it (for example, passing optional parameters or configuration through context values) or underusing it (ignoring cancellation signals). Common challenges include knowing when to derive a new context, propagating cancellation correctly, managing timeouts effectively, and avoiding context value anti-patterns.

In this comprehensive guide, we'll explore the context package from fundamental concepts to advanced patterns. You'll learn how to implement proper request scoping, handle cancellation gracefully, manage timeouts and deadlines, and use context values appropriately. These skills are essential for building production-ready Go applications that handle concurrent operations reliably.

Context Fundamentals

Understanding the Context Interface

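The context package revolves around the Context interface, whose four methods are Deadline, Done, Err, and Value. The program below creates root and derived contexts and checks them for cancellation: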

package main

import (
    "context"
    "fmt"
    "time"
)

func demonstrateContextBasics() {
    // Background context - root context for the application
    ctx := context.Background()
    fmt.Printf("Background context: %v\n", ctx)
    
    // TODO context - placeholder when you're not sure what context to use
    todoCtx := context.TODO()
    fmt.Printf("TODO context: %v\n", todoCtx)
    
    // Context with timeout
    timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
    defer cancel() // Always call cancel to release resources
    
    fmt.Printf("Timeout context: %v\n", timeoutCtx)
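    // Deadline returns two values: the deadline and whether one is set
    if deadline, ok := timeoutCtx.Deadline(); ok {
        fmt.Printf("Deadline: %v\n", deadline)
    }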
    
    // Context with cancellation
    cancelCtx, cancelFunc := context.WithCancel(ctx)
    defer cancelFunc()
    
    fmt.Printf("Cancel context: %v\n", cancelCtx)
    
    // Check if context is canceled
    select {
    case <-cancelCtx.Done():
        fmt.Println("Context was canceled")
    default:
        fmt.Println("Context is still active")
    }
    
    // Cancel the context
    cancelFunc()
    
    // Check again
    select {
    case <-cancelCtx.Done():
        fmt.Printf("Context canceled with error: %v\n", cancelCtx.Err())
    default:
        fmt.Println("Context is still active")
    }
}

func demonstrateContextHierarchy() {
    // Create a hierarchy of contexts
    rootCtx := context.Background()
    
    // Parent context with 5-second timeout
    parentCtx, parentCancel := context.WithTimeout(rootCtx, 5*time.Second)
    defer parentCancel()
    
    // Child context with 3-second timeout (will expire first)
    childCtx, childCancel := context.WithTimeout(parentCtx, 3*time.Second)
    defer childCancel()
    
    // Grandchild context that can be canceled manually
    grandchildCtx, grandchildCancel := context.WithCancel(childCtx)
    defer grandchildCancel()
    
    fmt.Println("Context hierarchy created")
    
    // Monitor all contexts
    go func() {
        <-parentCtx.Done()
        fmt.Printf("Parent context done: %v\n", parentCtx.Err())
    }()
    
    go func() {
        <-childCtx.Done()
        fmt.Printf("Child context done: %v\n", childCtx.Err())
    }()
    
    go func() {
        <-grandchildCtx.Done()
        fmt.Printf("Grandchild context done: %v\n", grandchildCtx.Err())
    }()
    
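    // Wait for the child (and therefore the grandchild) to time out at 3s.
    // The parent is only canceled by the deferred parentCancel when this
    // function returns, so its message may not print before main exits.
    time.Sleep(4 * time.Second)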
    fmt.Println("Finished waiting")
}

func main() {
    fmt.Println("Context basics:")
    demonstrateContextBasics()
    
    fmt.Println("\nContext hierarchy:")
    demonstrateContextHierarchy()
}

Context Creation Patterns

Different ways to create and configure contexts:

package main

import (
    "context"
    "fmt"
    "time"
)

func demonstrateContextCreation() {
    // 1. WithCancel - manual cancellation
    cancelCtx, cancel := context.WithCancel(context.Background())
    defer cancel()
    
    // 2. WithTimeout - automatic cancellation after duration
    timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer timeoutCancel()
    
    // 3. WithDeadline - automatic cancellation at specific time
    deadline := time.Now().Add(2 * time.Second)
    deadlineCtx, deadlineCancel := context.WithDeadline(context.Background(), deadline)
    defer deadlineCancel()
    
    // 4. WithValue - context with attached data (a plain string key is used
    // here for brevity; the typed keys shown below avoid collisions)
    valueCtx := context.WithValue(context.Background(), "userID", 12345)
    
    fmt.Println("Created different context types:")
    fmt.Printf("Cancel context: %T\n", cancelCtx)
    fmt.Printf("Timeout context: %T\n", timeoutCtx)
    fmt.Printf("Deadline context: %T\n", deadlineCtx)
    fmt.Printf("Value context: %T\n", valueCtx)
    
    // Extract value from context
    if userID, ok := valueCtx.Value("userID").(int); ok {
        fmt.Printf("User ID from context: %d\n", userID)
    }
}

func demonstrateContextChaining() {
    // Chain multiple context modifications
    baseCtx := context.Background()
    
    // Add timeout
    timeoutCtx, cancel1 := context.WithTimeout(baseCtx, 5*time.Second)
    defer cancel1()
    
    // Add user information
    userCtx := context.WithValue(timeoutCtx, "userID", "user123")
    
    // Add request ID
    requestCtx := context.WithValue(userCtx, "requestID", "req456")
    
    // Add cancellation capability
    finalCtx, cancel2 := context.WithCancel(requestCtx)
    defer cancel2()
    
    // Use the final context
    processRequest(finalCtx)
}

func processRequest(ctx context.Context) {
    fmt.Println("Processing request with context:")
    
    if userID, ok := ctx.Value("userID").(string); ok {
        fmt.Printf("  User ID: %s\n", userID)
    }
    
    if requestID, ok := ctx.Value("requestID").(string); ok {
        fmt.Printf("  Request ID: %s\n", requestID)
    }
    
    if deadline, ok := ctx.Deadline(); ok {
        fmt.Printf("  Deadline: %v\n", deadline)
        fmt.Printf("  Time remaining: %v\n", time.Until(deadline))
    }
    
    // Simulate some work
    select {
    case <-time.After(1 * time.Second):
        fmt.Println("  Request completed successfully")
    case <-ctx.Done():
        fmt.Printf("  Request canceled: %v\n", ctx.Err())
    }
}

// Context key type for type safety
type contextKey string

const (
    UserIDKey    contextKey = "userID"
    RequestIDKey contextKey = "requestID"
    SessionKey   contextKey = "session"
)

func demonstrateTypedContextKeys() {
    ctx := context.Background()
    
    // Add values with typed keys
    ctx = context.WithValue(ctx, UserIDKey, "user789")
    ctx = context.WithValue(ctx, RequestIDKey, "req101")
    ctx = context.WithValue(ctx, SessionKey, map[string]interface{}{
        "authenticated": true,
        "role":          "admin",
    })
    
    // Extract values safely
    if userID, ok := ctx.Value(UserIDKey).(string); ok {
        fmt.Printf("User ID: %s\n", userID)
    }
    
    if requestID, ok := ctx.Value(RequestIDKey).(string); ok {
        fmt.Printf("Request ID: %s\n", requestID)
    }
    
    if session, ok := ctx.Value(SessionKey).(map[string]interface{}); ok {
        fmt.Printf("Session: %v\n", session)
    }
}

func main() {
    fmt.Println("Context creation patterns:")
    demonstrateContextCreation()
    
    fmt.Println("\nContext chaining:")
    demonstrateContextChaining()
    
    fmt.Println("\nTyped context keys:")
    demonstrateTypedContextKeys()
}

Cancellation and Timeout Patterns

Implementing Graceful Cancellation

Proper cancellation handling in Go applications:

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// Worker that respects context cancellation
func cancellableWorker(ctx context.Context, id int, work <-chan int, results chan<- string) {
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("Worker %d: Canceled (%v)\n", id, ctx.Err())
            return
        case job, ok := <-work:
            if !ok {
                fmt.Printf("Worker %d: No more work\n", id)
                return
            }
            
            // Simulate work with cancellation checking
            if err := doWorkWithCancellation(ctx, job); err != nil {
                fmt.Printf("Worker %d: Job %d failed (%v)\n", id, job, err)
                continue
            }
            
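            // Send the result, but don't block forever if the collector has stopped
            select {
            case results <- fmt.Sprintf("Worker %d completed job %d", id, job):
            case <-ctx.Done():
                fmt.Printf("Worker %d: Canceled (%v)\n", id, ctx.Err())
                return
            }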
        }
    }
}

func doWorkWithCancellation(ctx context.Context, job int) error {
    // Simulate work that can be interrupted
    workDuration := time.Duration(job) * 100 * time.Millisecond
    
    select {
    case <-time.After(workDuration):
        return nil // Work completed
    case <-ctx.Done():
        return ctx.Err() // Work was canceled
    }
}

func demonstrateWorkerCancellation() {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    
    work := make(chan int, 10)
    results := make(chan string, 10)
    
    // Start workers
    const numWorkers = 3
    var wg sync.WaitGroup
    
    for i := 1; i <= numWorkers; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            cancellableWorker(ctx, id, work, results)
        }(i)
    }
    
    // Send work
    go func() {
        defer close(work)
        for i := 1; i <= 20; i++ {
            select {
            case work <- i:
                fmt.Printf("Sent job %d\n", i)
            case <-ctx.Done():
                fmt.Println("Stopped sending work due to cancellation")
                return
            }
        }
    }()
    
    // Collect results
    go func() {
        defer close(results)
        wg.Wait()
    }()
    
    // Process results until context is done or all work is finished
    for {
        select {
        case result, ok := <-results:
            if !ok {
                fmt.Println("All workers finished")
                return
            }
            fmt.Printf("Result: %s\n", result)
        case <-ctx.Done():
            fmt.Printf("Context canceled: %v\n", ctx.Err())
            return
        }
    }
}

// HTTP client with context cancellation
func httpClientWithCancellation(ctx context.Context, url string) error {
    // Simulated request: a real client would attach ctx with http.NewRequestWithContext
    fmt.Printf("Starting HTTP request to %s\n", url)
    
    // Simulate network delay
    requestDuration := 1500 * time.Millisecond
    
    select {
    case <-time.After(requestDuration):
        fmt.Printf("HTTP request to %s completed\n", url)
        return nil
    case <-ctx.Done():
        fmt.Printf("HTTP request to %s canceled: %v\n", url, ctx.Err())
        return ctx.Err()
    }
}

func demonstrateHTTPCancellation() {
    // Short timeout to force cancellation
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel()
    
    urls := []string{
        "https://api.example.com/users",
        "https://api.example.com/orders",
        "https://api.example.com/products",
    }
    
    var wg sync.WaitGroup
    for _, url := range urls {
        wg.Add(1)
        go func(u string) {
            defer wg.Done()
            httpClientWithCancellation(ctx, u)
        }(url)
    }
    
    wg.Wait()
    fmt.Println("All HTTP requests completed or canceled")
}

// Database operation with context
func databaseOperationWithContext(ctx context.Context, query string) error {
    fmt.Printf("Executing query: %s\n", query)
    
    // Simulate database operation
    operationTime := 800 * time.Millisecond
    
    ticker := time.NewTicker(100 * time.Millisecond)
    defer ticker.Stop()
    
    timeout := time.After(operationTime)
    
    for {
        select {
        case <-timeout:
            fmt.Printf("Query completed: %s\n", query)
            return nil
        case <-ticker.C:
            // Periodic progress report; cancellation is handled by the ctx.Done case
            fmt.Printf("Query in progress: %s\n", query)
        case <-ctx.Done():
            fmt.Printf("Query canceled: %s (%v)\n", query, ctx.Err())
            return ctx.Err()
        }
    }
}

func demonstrateDatabaseCancellation() {
    ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
    defer cancel()
    
    queries := []string{
        "SELECT * FROM users WHERE active = true",
        "UPDATE orders SET status = 'processed' WHERE date < '2023-01-01'",
        "INSERT INTO logs (message, timestamp) VALUES ('Operation completed', NOW())",
    }
    
    for _, query := range queries {
        if err := databaseOperationWithContext(ctx, query); err != nil {
            fmt.Printf("Query failed: %v\n", err)
            break // Stop executing remaining queries
        }
    }
}

func main() {
    fmt.Println("Worker cancellation:")
    demonstrateWorkerCancellation()
    
    fmt.Println("\nHTTP request cancellation:")
    demonstrateHTTPCancellation()
    
    fmt.Println("\nDatabase operation cancellation:")
    demonstrateDatabaseCancellation()
}

Advanced Timeout Handling

Sophisticated timeout and deadline management:

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// MultiStage operation with different timeouts for each stage
func multiStageOperation(ctx context.Context) error {
    fmt.Println("Starting multi-stage operation")
    
    // Stage 1: Authentication (500ms timeout)
    authCtx, authCancel := context.WithTimeout(ctx, 500*time.Millisecond)
    defer authCancel()
    
    if err := authenticateUser(authCtx); err != nil {
        return fmt.Errorf("authentication failed: %w", err)
    }
    
    // Stage 2: Data fetching (1s timeout)
    fetchCtx, fetchCancel := context.WithTimeout(ctx, 1*time.Second)
    defer fetchCancel()
    
    if err := fetchUserData(fetchCtx); err != nil {
        return fmt.Errorf("data fetch failed: %w", err)
    }
    
    // Stage 3: Processing (2s timeout)
    processCtx, processCancel := context.WithTimeout(ctx, 2*time.Second)
    defer processCancel()
    
    if err := processData(processCtx); err != nil {
        return fmt.Errorf("processing failed: %w", err)
    }
    
    fmt.Println("Multi-stage operation completed successfully")
    return nil
}

func authenticateUser(ctx context.Context) error {
    select {
    case <-time.After(300 * time.Millisecond):
        fmt.Println("  User authenticated")
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

func fetchUserData(ctx context.Context) error {
    select {
    case <-time.After(800 * time.Millisecond):
        fmt.Println("  User data fetched")
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

func processData(ctx context.Context) error {
    select {
    case <-time.After(1500 * time.Millisecond):
        fmt.Println("  Data processed")
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

// Retry mechanism with exponential backoff and context
func retryWithContext(ctx context.Context, operation func() error, maxRetries int) error {
    var lastErr error
    
    for attempt := 0; attempt < maxRetries; attempt++ {
        // Check if context is already canceled
        if err := ctx.Err(); err != nil {
            return err
        }
        
        // Try the operation
        if err := operation(); err != nil {
            lastErr = err
            fmt.Printf("Attempt %d failed: %v\n", attempt+1, err)
            
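            // No point waiting after the final attempt
            if attempt == maxRetries-1 {
                break
            }
            
            // Exponential backoff: 100ms, 200ms, 400ms, ...
            backoff := time.Duration(1<<attempt) * 100 * time.Millisecond
            
            // Wait with context cancellation
            select {
            case <-time.After(backoff):
                continue
            case <-ctx.Done():
                return ctx.Err()
            }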
        }
        
        return nil // Success
    }
    
    return fmt.Errorf("operation failed after %d attempts: %w", maxRetries, lastErr)
}

func unreliableOperation() error {
    // Simulate 70% failure rate
    if time.Now().UnixNano()%10 < 7 {
        return fmt.Errorf("simulated failure")
    }
    return nil
}

func demonstrateRetryWithTimeout() {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    
    err := retryWithContext(ctx, unreliableOperation, 5)
    if err != nil {
        fmt.Printf("Retry operation failed: %v\n", err)
    } else {
        fmt.Println("Retry operation succeeded")
    }
}

// Circuit breaker pattern with context
type CircuitBreaker struct {
    mutex       sync.Mutex
    failures    int
    maxFailures int
    timeout     time.Duration
    nextAttempt time.Time
    state       string // "closed", "open", "half-open"
}

func NewCircuitBreaker(maxFailures int, timeout time.Duration) *CircuitBreaker {
    return &CircuitBreaker{
        maxFailures: maxFailures,
        timeout:     timeout,
        state:       "closed",
    }
}

func (cb *CircuitBreaker) Execute(ctx context.Context, operation func(context.Context) error) error {
    // Decide whether the call may proceed; hold the lock only for the state check
    cb.mutex.Lock()
    if cb.state == "open" {
        if time.Now().Before(cb.nextAttempt) {
            cb.mutex.Unlock()
            return fmt.Errorf("circuit breaker is open")
        }
        cb.state = "half-open" // timeout elapsed: allow a trial call
    }
    cb.mutex.Unlock()
    
    // Run the operation without holding the lock, so a slow call
    // (bounded by ctx) does not block other callers
    err := operation(ctx)
    
    cb.mutex.Lock()
    defer cb.mutex.Unlock()
    
    if err != nil {
        cb.failures++
        // A failed trial call, or too many failures, opens the circuit
        if cb.state == "half-open" || cb.failures >= cb.maxFailures {
            cb.state = "open"
            cb.nextAttempt = time.Now().Add(cb.timeout)
            fmt.Printf("Circuit breaker opened after %d failures\n", cb.failures)
        }
        return err
    }
    
    // Success: close the circuit after a trial call and reset the failure count
    if cb.state == "half-open" {
        cb.state = "closed"
        fmt.Println("Circuit breaker closed")
    }
    cb.failures = 0
    return nil
}

func demonstrateCircuitBreaker() {
    cb := NewCircuitBreaker(3, 1*time.Second)
    
    flakyService := func(ctx context.Context) error {
        select {
        case <-time.After(100 * time.Millisecond):
            // Simulate 80% failure rate
            if time.Now().UnixNano()%10 < 8 {
                return fmt.Errorf("service unavailable")
            }
            return nil
        case <-ctx.Done():
            return ctx.Err()
        }
    }
    
    ctx := context.Background()
    
    // Make several calls to trigger circuit breaker
    for i := 1; i <= 10; i++ {
        callCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
        
        err := cb.Execute(callCtx, flakyService)
        if err != nil {
            fmt.Printf("Call %d failed: %v\n", i, err)
        } else {
            fmt.Printf("Call %d succeeded\n", i)
        }
        
        cancel()
        time.Sleep(200 * time.Millisecond)
    }
}

func main() {
    fmt.Println("Multi-stage operation with timeouts:")
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    
    if err := multiStageOperation(ctx); err != nil {
        fmt.Printf("Operation failed: %v\n", err)
    }
    
    fmt.Println("\nRetry with context:")
    demonstrateRetryWithTimeout()
    
    fmt.Println("\nCircuit breaker with context:")
    demonstrateCircuitBreaker()
}

Context Values and Best Practices

Proper Context Value Usage

Guidelines for using context values appropriately:

package main

import (
    "context"
    "fmt"
    "log"
    "time"
)

// Define typed keys for context values
type ctxKey int

const (
    requestIDKey ctxKey = iota
    userIDKey
    correlationIDKey
    traceIDKey
)

// RequestContext wraps common request-scoped data
type RequestContext struct {
    RequestID     string
    UserID        string
    CorrelationID string
    TraceID       string
    StartTime     time.Time
}

// NewRequestContext creates a context with request metadata
func NewRequestContext(ctx context.Context, requestID, userID string) context.Context {
    requestCtx := &RequestContext{
        RequestID:     requestID,
        UserID:        userID,
        CorrelationID: generateCorrelationID(),
        TraceID:       generateTraceID(),
        StartTime:     time.Now(),
    }
    
    ctx = context.WithValue(ctx, requestIDKey, requestCtx.RequestID)
    ctx = context.WithValue(ctx, userIDKey, requestCtx.UserID)
    ctx = context.WithValue(ctx, correlationIDKey, requestCtx.CorrelationID)
    ctx = context.WithValue(ctx, traceIDKey, requestCtx.TraceID)
    
    return ctx
}

func generateCorrelationID() string {
    return fmt.Sprintf("corr-%d", time.Now().UnixNano())
}

func generateTraceID() string {
    return fmt.Sprintf("trace-%d", time.Now().UnixNano())
}

// Helper functions to extract values safely
func GetRequestID(ctx context.Context) string {
    if id, ok := ctx.Value(requestIDKey).(string); ok {
        return id
    }
    return "unknown"
}

func GetUserID(ctx context.Context) string {
    if id, ok := ctx.Value(userIDKey).(string); ok {
        return id
    }
    return "anonymous"
}

func GetCorrelationID(ctx context.Context) string {
    if id, ok := ctx.Value(correlationIDKey).(string); ok {
        return id
    }
    return "unknown"
}

func GetTraceID(ctx context.Context) string {
    if id, ok := ctx.Value(traceIDKey).(string); ok {
        return id
    }
    return "unknown"
}

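// Structured logger that reads request values from its context. It is
// created per request and never outlives it; in general the context package
// docs advise passing a Context explicitly rather than storing it in a struct.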
type ContextLogger struct {
    ctx context.Context
}

func NewContextLogger(ctx context.Context) *ContextLogger {
    return &ContextLogger{ctx: ctx}
}

func (cl *ContextLogger) Info(message string) {
    log.Printf("[INFO] [req:%s] [user:%s] [trace:%s] %s",
        GetRequestID(cl.ctx),
        GetUserID(cl.ctx),
        GetTraceID(cl.ctx),
        message)
}

func (cl *ContextLogger) Error(message string, err error) {
    log.Printf("[ERROR] [req:%s] [user:%s] [trace:%s] %s: %v",
        GetRequestID(cl.ctx),
        GetUserID(cl.ctx),
        GetTraceID(cl.ctx),
        message, err)
}

func demonstrateContextValues() {
    ctx := context.Background()
    
    // Create request context
    requestCtx := NewRequestContext(ctx, "req-123", "user-456")
    
    // Use throughout the request lifecycle
    processUserRequest(requestCtx)
}

func processUserRequest(ctx context.Context) {
    logger := NewContextLogger(ctx)
    logger.Info("Processing user request")
    
    // Simulate some processing steps
    if err := validateUser(ctx); err != nil {
        logger.Error("User validation failed", err)
        return
    }
    
    if err := doProcessing(ctx); err != nil {
        logger.Error("Data processing failed", err)
        return
    }
    
    logger.Info("Request processed successfully")
}

func validateUser(ctx context.Context) error {
    logger := NewContextLogger(ctx)
    userID := GetUserID(ctx)
    
    logger.Info(fmt.Sprintf("Validating user: %s", userID))
    
    // Simulate validation
    time.Sleep(100 * time.Millisecond)
    
    if userID == "user-456" {
        logger.Info("User validation successful")
        return nil
    }
    
    return fmt.Errorf("invalid user: %s", userID)
}

// Anti-pattern examples (what NOT to do)
func demonstrateAntiPatterns() {
    fmt.Println("\n=== Context Anti-Patterns (DON'T DO THIS) ===")
    
    // DON'T: Store large objects in context
    largeData := make([]byte, 1024*1024) // 1MB
    badCtx := context.WithValue(context.Background(), "largeData", largeData)
    fmt.Printf("Bad: Stored %d bytes in context\n", len(largeData))
    _ = badCtx
    
    // DON'T: Use context for optional parameters
    badConfigCtx := context.WithValue(context.Background(), "config", map[string]interface{}{
        "timeout": 30,
        "retries": 3,
        "debug":   true,
    })
    _ = badConfigCtx
    
    // DON'T: Pass a nil context; calling any method on it panics.
    // Use context.TODO() when no suitable context is available yet.
    
    fmt.Println("See code comments for anti-patterns to avoid")
}

// Good patterns for context usage
type ServiceConfig struct {
    Timeout    time.Duration
    MaxRetries int
    Debug      bool
}

// GOOD: Pass configuration as parameters, not context values
func processWithConfig(ctx context.Context, config ServiceConfig) error {
    logger := NewContextLogger(ctx)
    logger.Info("Processing with configuration")
    
    // Use the config explicitly
    timeoutCtx, cancel := context.WithTimeout(ctx, config.Timeout)
    defer cancel()
    
    return doProcessing(timeoutCtx)
}

func doProcessing(ctx context.Context) error {
    select {
    case <-time.After(50 * time.Millisecond):
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

// GOOD: Use context for request-scoped data only
type AuthInfo struct {
    UserID      string
    Permissions []string
    Token       string
}

const authInfoKey ctxKey = 999

func WithAuth(ctx context.Context, auth AuthInfo) context.Context {
    return context.WithValue(ctx, authInfoKey, auth)
}

func GetAuth(ctx context.Context) (AuthInfo, bool) {
    auth, ok := ctx.Value(authInfoKey).(AuthInfo)
    return auth, ok
}

func demonstrateGoodPatterns() {
    ctx := context.Background()
    
    // Add authentication information
    auth := AuthInfo{
        UserID:      "user-123",
        Permissions: []string{"read", "write"},
        Token:       "bearer-token-xyz",
    }
    authCtx := WithAuth(ctx, auth)
    
    // Add request tracing
    requestCtx := NewRequestContext(authCtx, "req-789", auth.UserID)
    
    // Pass configuration explicitly, not through context
    config := ServiceConfig{
        Timeout:    2 * time.Second,
        MaxRetries: 3,
        Debug:      true,
    }
    
    if err := processWithConfig(requestCtx, config); err != nil {
        logger := NewContextLogger(requestCtx)
        logger.Error("Processing failed", err)
    }
    
    // Demonstrate auth extraction
    if extractedAuth, ok := GetAuth(requestCtx); ok {
        fmt.Printf("User %s has permissions: %v\n", 
            extractedAuth.UserID, extractedAuth.Permissions)
    }
}

func main() {
    fmt.Println("Context values demonstration:")
    demonstrateContextValues()
    
    demonstrateAntiPatterns()
    
    fmt.Println("\nGood context patterns:")
    demonstrateGoodPatterns()
}

Context Middleware and Interceptors

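Building reusable context-aware middleware. This listing reuses NewRequestContext, GetRequestID, GetUserID, and NewContextLogger from the previous listing, so compile both in the same package and keep only one main function: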

package main

import (
    "context"
    "fmt"
    "log"
    "sync"
    "time"
)

// Middleware type definition
type Middleware func(next HandlerFunc) HandlerFunc
type HandlerFunc func(ctx context.Context) error

// Request metrics middleware
func MetricsMiddleware() Middleware {
    return func(next HandlerFunc) HandlerFunc {
        return func(ctx context.Context) error {
            start := time.Now()
            requestID := GetRequestID(ctx)
            
            log.Printf("[METRICS] Request %s started", requestID)
            
            err := next(ctx)
            
            duration := time.Since(start)
            status := "success"
            if err != nil {
                status = "error"
            }
            
            log.Printf("[METRICS] Request %s completed in %v (status: %s)", 
                requestID, duration, status)
            
            return err
        }
    }
}

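// Timeout middleware
func TimeoutMiddleware(timeout time.Duration) Middleware {
    return func(next HandlerFunc) HandlerFunc {
        return func(ctx context.Context) error {
            timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
            defer cancel()
            
            done := make(chan error, 1)
            panicked := make(chan any, 1)
            go func() {
                // recover() in the caller's goroutine cannot catch a panic
                // raised here, so capture it and re-raise it below
                defer func() {
                    if r := recover(); r != nil {
                        panicked <- r
                    }
                }()
                done <- next(timeoutCtx)
            }()
            
            select {
            case err := <-done:
                return err
            case p := <-panicked:
                panic(p) // re-panic in the caller's goroutine so RecoveryMiddleware sees it
            case <-timeoutCtx.Done():
                // The handler goroutine keeps running until it returns; both
                // channels are buffered, so it never blocks on its final send
                return timeoutCtx.Err()
            }
        }
    }
}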

// Rate limiting middleware
type RateLimiter struct {
    requests map[string][]time.Time
    mutex    sync.Mutex
    limit    int
    window   time.Duration
}

func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
    return &RateLimiter{
        requests: make(map[string][]time.Time),
        limit:    limit,
        window:   window,
    }
}

func (rl *RateLimiter) isAllowed(userID string) bool {
    rl.mutex.Lock()
    defer rl.mutex.Unlock()
    
    now := time.Now()
    cutoff := now.Add(-rl.window)
    
    // Clean old requests
    requests := rl.requests[userID]
    var validRequests []time.Time
    for _, reqTime := range requests {
        if reqTime.After(cutoff) {
            validRequests = append(validRequests, reqTime)
        }
    }
    
    // Check if under limit
    if len(validRequests) >= rl.limit {
        rl.requests[userID] = validRequests
        return false
    }
    
    // Add current request
    validRequests = append(validRequests, now)
    rl.requests[userID] = validRequests
    return true
}

func RateLimitMiddleware(limiter *RateLimiter) Middleware {
    return func(next HandlerFunc) HandlerFunc {
        return func(ctx context.Context) error {
            userID := GetUserID(ctx)
            
            if !limiter.isAllowed(userID) {
                return fmt.Errorf("rate limit exceeded for user %s", userID)
            }
            
            return next(ctx)
        }
    }
}

// Authentication middleware
func AuthMiddleware() Middleware {
    return func(next HandlerFunc) HandlerFunc {
        return func(ctx context.Context) error {
            userID := GetUserID(ctx)
            
            if userID == "anonymous" {
                return fmt.Errorf("authentication required")
            }
            
            // Simulate auth validation
            if userID == "invalid-user" {
                return fmt.Errorf("invalid user credentials")
            }
            
            log.Printf("[AUTH] User %s authenticated successfully", userID)
            return next(ctx)
        }
    }
}

// Logging middleware
func LoggingMiddleware() Middleware {
    return func(next HandlerFunc) HandlerFunc {
        return func(ctx context.Context) error {
            logger := NewContextLogger(ctx)
            logger.Info("Request received")
            
            err := next(ctx)
            
            if err != nil {
                logger.Error("Request failed", err)
            } else {
                logger.Info("Request completed successfully")
            }
            
            return err
        }
    }
}

// Recovery middleware
func RecoveryMiddleware() Middleware {
    return func(next HandlerFunc) HandlerFunc {
        return func(ctx context.Context) (err error) {
            defer func() {
                if r := recover(); r != nil {
                    logger := NewContextLogger(ctx)
                    logger.Error("Panic recovered", fmt.Errorf("panic: %v", r))
                    err = fmt.Errorf("internal server error")
                }
            }()
            
            return next(ctx)
        }
    }
}

// Chain multiple middlewares
func Chain(middlewares ...Middleware) Middleware {
    return func(next HandlerFunc) HandlerFunc {
        for i := len(middlewares) - 1; i >= 0; i-- {
            next = middlewares[i](next)
        }
        return next
    }
}

// Sample business logic handlers
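func userProfileHandler(ctx context.Context) error {
    // Simulate processing while still honoring cancellation and timeouts
    select {
    case <-time.After(100 * time.Millisecond):
    case <-ctx.Done():
        return ctx.Err()
    }
    
    userID := GetUserID(ctx)
    fmt.Printf("Fetched profile for user: %s\n", userID)
    return nil
}

func dataProcessingHandler(ctx context.Context) error {
    // Simulate processing while still honoring cancellation and timeouts
    select {
    case <-time.After(200 * time.Millisecond):
    case <-ctx.Done():
        return ctx.Err()
    }
    
    fmt.Println("Data processing completed")
    return nil
}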

func panicHandler(ctx context.Context) error {
    panic("simulated panic for testing recovery")
}

func demonstrateMiddlewareChain() {
    rateLimiter := NewRateLimiter(3, time.Minute)
    
    // Create middleware chain
    middleware := Chain(
        RecoveryMiddleware(),
        MetricsMiddleware(),
        LoggingMiddleware(),
        TimeoutMiddleware(1*time.Second),
        AuthMiddleware(),
        RateLimitMiddleware(rateLimiter),
    )
    
    // Test cases
    testCases := []struct {
        name    string
        userID  string
        handler HandlerFunc
    }{
        {"Valid User Profile", "user-123", userProfileHandler},
        {"Valid Data Processing", "user-456", dataProcessingHandler},
        {"Anonymous User", "anonymous", userProfileHandler},
        {"Invalid User", "invalid-user", userProfileHandler},
        {"Panic Handler", "user-789", panicHandler},
    }
    
    for _, tc := range testCases {
        fmt.Printf("\n=== Testing: %s ===\n", tc.name)
        
        ctx := context.Background()
        ctx = NewRequestContext(ctx, fmt.Sprintf("req-%d", time.Now().UnixNano()), tc.userID)
        
        // Wrap handler with middleware
        wrappedHandler := middleware(tc.handler)
        
        if err := wrappedHandler(ctx); err != nil {
            fmt.Printf("Handler error: %v\n", err)
        }
    }
}

// Demonstrate rate limiting
func demonstrateRateLimiting() {
    fmt.Println("\n=== Rate Limiting Test ===")
    
    rateLimiter := NewRateLimiter(2, 5*time.Second) // 2 requests per 5 seconds
    middleware := RateLimitMiddleware(rateLimiter)
    handler := middleware(userProfileHandler)
    
    userID := "rate-test-user"
    
    // Make multiple requests
    for i := 1; i <= 5; i++ {
        ctx := NewRequestContext(context.Background(), 
            fmt.Sprintf("req-%d", i), userID)
        
        fmt.Printf("Request %d: ", i)
        if err := handler(ctx); err != nil {
            fmt.Printf("Failed - %v\n", err)
        } else {
            fmt.Println("Success")
        }
        
        time.Sleep(500 * time.Millisecond)
    }
}

func main() {
    demonstrateMiddlewareChain()
    demonstrateRateLimiting()
}
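Because Chain wraps from the last middleware to the first, the first argument ends up as the outermost layer: it runs first on the way in and last on the way out. That is why RecoveryMiddleware is listed first in demonstrateMiddlewareChain. The minimal sketch below makes the order visible. It re-declares HandlerFunc, Middleware, and Chain so it compiles on its own (the types are assumed to match the ones used above), and traceMiddleware and runChainOrder are hypothetical helpers for illustration only.

```go
package main

import (
	"context"
	"fmt"
)

// Assumed to mirror the HandlerFunc and Middleware types used in the article.
type HandlerFunc func(ctx context.Context) error
type Middleware func(HandlerFunc) HandlerFunc

// Chain composes middlewares so that middlewares[0] is the outermost layer.
func Chain(middlewares ...Middleware) Middleware {
	return func(next HandlerFunc) HandlerFunc {
		for i := len(middlewares) - 1; i >= 0; i-- {
			next = middlewares[i](next)
		}
		return next
	}
}

// traceMiddleware is a hypothetical middleware that records when it is
// entered and exited, so the nesting order becomes observable.
func traceMiddleware(name string, events *[]string) Middleware {
	return func(next HandlerFunc) HandlerFunc {
		return func(ctx context.Context) error {
			*events = append(*events, "enter "+name)
			err := next(ctx)
			*events = append(*events, "exit "+name)
			return err
		}
	}
}

// runChainOrder wraps a trivial handler in three tracing layers and
// returns the recorded sequence of events.
func runChainOrder() []string {
	var events []string
	handler := func(ctx context.Context) error {
		events = append(events, "handler")
		return nil
	}
	wrapped := Chain(
		traceMiddleware("A", &events),
		traceMiddleware("B", &events),
		traceMiddleware("C", &events),
	)(handler)
	_ = wrapped(context.Background())
	return events
}

func main() {
	for _, e := range runChainOrder() {
		fmt.Println(e)
	}
}
```

The recorded order is enter A, enter B, enter C, handler, exit C, exit B, exit A. One caveat worth keeping in mind: recover only catches panics raised in the goroutine that deferred it, so if a middleware runs next in a separate goroutine (for example, to race it against a timer), a panic inside that goroutine would not reach an outer recovery layer.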

FAQ

Q: When should I use context.Background() vs context.TODO()? A: Use context.Background() as the root context for your application, typically in main(), init(), or tests. Use context.TODO() when it is unclear which context to use, or when the surrounding function has not yet been updated to accept a context parameter; it marks code that still needs proper context plumbing.

Q: Should I pass context as the first parameter to every function? A: Not every function needs a context. Pass one only to functions that perform I/O, may need to be canceled, or require request-scoped data; pure functions and simple utilities don't need it. When a function does take a context, make it the first parameter and name it ctx, which is the convention the standard library follows.

Q: How do I handle context cancellation in long-running operations? A: Check ctx.Done() periodically in loops or long operations. Use select statements to check for cancellation alongside your main work. Consider breaking large operations into smaller chunks.

Q: Is it safe to store database connections in context? A: No, avoid storing connection objects in context. Context values should be request-scoped data like user IDs, trace IDs, or auth tokens. Pass connections as explicit parameters or use dependency injection.

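Q: What happens if I don't call the cancel function returned by WithCancel/WithTimeout? A: The derived context is not released until its parent is canceled or, for WithTimeout, until its timer fires. Until then it stays registered with its parent, so code that creates many such contexts leaks resources; the go vet lostcancel check reports this mistake. Always defer cancel() immediately after creating a context, even if you never cancel it explicitly. Calling cancel more than once is safe.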

Q: Can I modify context values after creation? A: No, context values are immutable. To "modify" values, create a new context with context.WithValue() using the existing context as the parent.

Conclusion

The Go context package is essential for building robust, concurrent applications that handle request lifecycles gracefully. Key insights from this comprehensive guide include:

  • Understanding context creation patterns and proper hierarchy management
  • Implementing graceful cancellation and timeout handling across your application
  • Using context values appropriately for request-scoped data while avoiding anti-patterns
  • Building reusable middleware that leverages context for cross-cutting concerns
  • Following best practices for context propagation and resource cleanup
  • Designing applications that respect cancellation signals and handle timeouts elegantly

Context enables you to build applications that are responsive to cancellation, respect timeouts, and maintain request scope throughout complex operations. By mastering these patterns, you can create Go applications that handle concurrent operations reliably and provide excellent user experiences even under load or failure conditions.

Ready to improve your context usage? Start by auditing your current code for missing context propagation and timeout handling. Implement proper cancellation checking in long-running operations and consider adding middleware for common cross-cutting concerns. Share your experiences and questions in the comments below!
