
Python Class Decorators vs Function Decorators: When to Use Which

Learn when to use Python class decorators vs function decorators. Compare performance, state management, and complexity with practical examples.

Decorators are one of Python's most powerful features, allowing you to modify or enhance functions and classes without changing their core implementation. While most developers are familiar with function decorators, class decorators offer unique advantages in certain scenarios. This comprehensive guide will explore both approaches, their strengths and weaknesses, and provide clear guidelines on when to use each pattern.
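A note on terminology. In Python's own documentation and PEP 3129, "class decorator" usually means a decorator applied to a class definition, such as the standard library's @dataclass. This article uses the phrase for decorators that are implemented as classes. The minimal sketch below shows a decorator that decorates a class, for contrast. The register name and the REGISTRY dict are illustrative assumptions, not a library API:

```python
# A decorator applied to a class receives the class object itself.
REGISTRY = {}

def register(cls):
    """Hypothetical helper: record the class by name and return it unchanged."""
    REGISTRY[cls.__name__] = cls
    return cls

@register
class JsonExporter:
    def export(self, data):
        return str(data)

print(list(REGISTRY))  # ['JsonExporter']
```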

Understanding the Two Approaches

Function Decorators: The Traditional Approach

Function decorators are functions that take another function as an argument and return a replacement for it. The replacement is usually a wrapper that adds behavior around the original:

from functools import wraps
import time

def timing_decorator(func):
    """Traditional function decorator."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()  # monotonic, high-resolution clock for durations
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start:.4f} seconds")
        return result
    return wrapper

@timing_decorator
def slow_function():
    """A function that takes some time."""
    time.sleep(0.1)
    return "Done"

slow_function()  # slow_function took 0.1001 seconds
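The @ line is shorthand for passing the function to the decorator and rebinding the name. The sketch below writes that step out explicitly. It also shows what functools.wraps copies over, including __wrapped__, which points back at the undecorated function. A trimmed-down timing_decorator is re-declared so the snippet runs on its own:

```python
import time
from functools import wraps

def timing_decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start:.4f} seconds")
        return result
    return wrapper

def slow_function():
    """A function that takes some time."""
    time.sleep(0.01)
    return "Done"

original = slow_function
# Equivalent to putting @timing_decorator above the def
slow_function = timing_decorator(slow_function)

print(slow_function.__name__)                 # slow_function (copied by @wraps)
print(slow_function.__doc__)                  # A function that takes some time.
print(slow_function.__wrapped__ is original)  # True
```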

Class Decorators: The Object-Oriented Approach

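Class decorators, as this article uses the term, are classes whose instances replace the decorated function. The __init__ method receives the function, and implementing __call__ makes each instance callable like a function: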

from functools import wraps
import time

class TimingDecorator:
    """Class-based decorator with state management."""
    
    def __init__(self, func):
        self.func = func
        self.call_count = 0
        self.total_time = 0
        wraps(func)(self)  # Preserve function metadata
    
    def __call__(self, *args, **kwargs):
        self.call_count += 1
        start = time.perf_counter()  # monotonic, high-resolution clock for durations
        result = self.func(*args, **kwargs)
        end = time.perf_counter()
        
        execution_time = end - start
        self.total_time += execution_time
        
        print(f"{self.func.__name__} took {execution_time:.4f}s (call #{self.call_count})")
        return result
    
    def get_stats(self):
        """Get execution statistics."""
        if self.call_count == 0:
            return "No calls made yet"
        avg_time = self.total_time / self.call_count
        return f"Calls: {self.call_count}, Total: {self.total_time:.4f}s, Avg: {avg_time:.4f}s"

@TimingDecorator
def another_slow_function():
    """Another function that takes some time."""
    time.sleep(0.1)
    return "Done"

# Usage
another_slow_function()  # Call 1
another_slow_function()  # Call 2
print(another_slow_function.get_stats())  # Statistics available
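The class-based form has one caveat. An instance with __call__ is not a descriptor, so when TimingDecorator decorates a method, the method's self is never passed in and the call fails with a missing-argument TypeError. A common fix is to implement __get__ and bind the instance with types.MethodType. The sketch below uses a hypothetical CountCalls decorator to show this:

```python
import functools
import types

class CountCalls:
    """Hypothetical class-based decorator that also works on methods."""

    def __init__(self, func):
        functools.update_wrapper(self, func)  # copy __name__, __doc__, __wrapped__, ...
        self.func = func
        self.call_count = 0

    def __call__(self, *args, **kwargs):
        self.call_count += 1
        return self.func(*args, **kwargs)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # looked up on the class: return the decorator itself
        # Looked up on an instance: bind it, as Python does for ordinary methods
        return types.MethodType(self, instance)


class Greeter:
    @CountCalls
    def greet(self, name):
        return f"Hello, {name}"

g = Greeter()
print(g.greet("Ada"))            # Hello, Ada
print(Greeter.greet.call_count)  # 1
```

With __get__ in place, the same decorator still works on plain functions, because __get__ only runs when the decorator is stored as a class attribute.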

When to Use Function Decorators

Function decorators are ideal when the added behavior is simple and needs little or no state beyond what a closure can hold.

Simple Behavior Enhancement

from functools import wraps
import logging

def log_calls(level=logging.INFO):
    """Simple function decorator for logging."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger = logging.getLogger(func.__module__)
            logger.log(level, f"Calling {func.__name__}")
            
            try:
                result = func(*args, **kwargs)
                logger.log(level, f"{func.__name__} completed successfully")
                return result
            except Exception as e:
                logger.error(f"{func.__name__} failed: {e}")
                raise
        return wrapper
    return decorator

@log_calls(logging.DEBUG)
def calculate_sum(a, b):
    """Calculate sum of two numbers."""
    return a + b

@log_calls()
def divide_numbers(a, b):
    """Divide two numbers."""
    return a / b
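Because log_calls is a decorator factory, it has to be written as @log_calls() even when the default level is fine. Writing @log_calls without parentheses passes the function in as level instead. A widely used variation accepts both forms by making the function an optional first argument and the settings keyword-only. The log_calls_flexible name in this sketch is hypothetical:

```python
import functools
import logging

def log_calls_flexible(func=None, *, level=logging.INFO):
    """Hypothetical log_calls variant: usable as @log_calls_flexible or @log_calls_flexible(level=...)."""
    if func is None:
        # Called with settings only: return a decorator that remembers them
        return functools.partial(log_calls_flexible, level=level)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logging.getLogger(func.__module__).log(level, "Calling %s", func.__name__)
        return func(*args, **kwargs)
    return wrapper

@log_calls_flexible                       # no parentheses needed
def calculate_sum(a, b):
    return a + b

@log_calls_flexible(level=logging.DEBUG)  # settings passed by keyword
def divide_numbers(a, b):
    return a / b
```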

Validation and Input Processing

from functools import wraps
import inspect

def validate_types(**expected_types):
    """Function decorator for type validation."""
    def decorator(func):
        # Read the signature once, when the function is decorated,
        # instead of on every call
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Map this call's arguments onto parameter names
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            
            # Validate types
            for param_name, expected_type in expected_types.items():
                if param_name in bound_args.arguments:
                    value = bound_args.arguments[param_name]
                    if not isinstance(value, expected_type):
                        raise TypeError(
                            f"{param_name} must be {expected_type.__name__}, "
                            f"got {type(value).__name__}"
                        )
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

@validate_types(name=str, age=int, salary=float)
def create_employee(name, age, salary=0.0):
    """Create an employee record."""
    return {"name": name, "age": age, "salary": salary}

# Valid call
employee = create_employee("Alice", 30, 50000.0)
print(employee)

# This would raise TypeError
# create_employee("Bob", "thirty", 45000.0)
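The explicit name=str, age=int arguments repeat information that type hints can already carry. The sketch below shows a hypothetical validate_annotations variant that reads the expected classes from the function's own annotations instead. Like the version above, it computes the signature once, when the function is decorated. It assumes Python 3.9+ and deliberately checks only plain classes:

```python
import functools
import inspect
import types

def validate_annotations(func):
    """Hypothetical variant: check arguments against the function's own annotations."""
    sig = inspect.signature(func)  # computed once, when the function is decorated

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            param = sig.parameters[name]
            expected = param.annotation
            # Check only plain classes such as str or int; skip unannotated
            # parameters, *args/**kwargs, string annotations and generics like list[int]
            if (expected is inspect.Parameter.empty
                    or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
                    or not isinstance(expected, type)
                    or isinstance(expected, types.GenericAlias)):
                continue
            if not isinstance(value, expected):
                raise TypeError(
                    f"{name} must be {expected.__name__}, got {type(value).__name__}"
                )
        return func(*args, **kwargs)

    return wrapper

@validate_annotations
def create_employee(name: str, age: int, salary: float = 0.0):
    """Create an employee record."""
    return {"name": name, "age": age, "salary": salary}

print(create_employee("Alice", 30, 50000.0))
```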

Caching and Memoization

from functools import wraps
import time

def simple_cache(func):
    """Simple caching decorator: results live in a closure dict; helpers are attached as function attributes."""
    cache = {}
    
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Create cache key
        key = str(args) + str(sorted(kwargs.items()))
        
        if key in cache:
            print(f"Cache hit for {func.__name__}")
            return cache[key]
        
        print(f"Computing {func.__name__}")
        result = func(*args, **kwargs)
        cache[key] = result
        return result
    
    # Attach cache management
    wrapper.cache_clear = lambda: cache.clear()
    wrapper.cache_info = lambda: {"size": len(cache), "keys": list(cache.keys())}
    
    return wrapper

@simple_cache
def fibonacci(n):
    """Calculate fibonacci number with caching."""
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

print(fibonacci(10))  # Computed
print(fibonacci(10))  # Cache hit
print(fibonacci.cache_info())
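For most memoization, the standard library already provides this pattern. functools.lru_cache (and functools.cache, its unbounded shorthand since Python 3.9) requires hashable arguments, keeps hit and miss counters, and exposes cache_info() and cache_clear(). A minimal comparison:

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded; pass an int to cap the size with LRU eviction
def fibonacci(n):
    """Calculate fibonacci number with caching."""
    if n <= 1:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

print(fibonacci(10))           # 55
print(fibonacci.cache_info())  # CacheInfo(hits=8, misses=11, maxsize=None, currsize=11)
```

Calling fibonacci.cache_clear() empties the cache and resets the counters. A hand-written decorator like simple_cache is still useful when you need keys for unhashable arguments or custom reporting.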

When to Use Class Decorators

Class decorators excel when you need state management, complex configuration, or multiple related methods.
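State by itself does not require a class. A function decorator can keep counters in its closure or, as in this sketch with a hypothetical count_calls, as attributes on the wrapper. The class form starts to pay off when several pieces of state, and several methods that act on them, have to stay together:

```python
import functools

def count_calls(func):
    """Hypothetical function decorator that keeps state without a class."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.call_count += 1  # state lives on the wrapper function object
        return func(*args, **kwargs)
    wrapper.call_count = 0
    wrapper.reset = lambda: setattr(wrapper, "call_count", 0)
    return wrapper

@count_calls
def greet(name):
    return f"Hello, {name}"

greet("Ada")
greet("Grace")
print(greet.call_count)  # 2
```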

State Management and Statistics

from functools import wraps
import time
from collections import defaultdict

class PerformanceMonitor:
    """Class decorator for comprehensive performance monitoring."""
    
    def __init__(self, track_memory=False):
        self.track_memory = track_memory
        self.stats = defaultdict(list)
        self.call_counts = defaultdict(int)
        
        if track_memory:
            import os
            import psutil  # third-party package: pip install psutil
            self.process = psutil.Process(os.getpid())
    
    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_name = func.__qualname__
            self.call_counts[func_name] += 1
            
            # Memory before (if tracking)
            memory_before = None
            if self.track_memory:
                memory_before = self.process.memory_info().rss / 1024 / 1024  # MB
            
            # Time execution
            start_time = time.perf_counter()  # monotonic clock for durations
            result = func(*args, **kwargs)
            end_time = time.perf_counter()
            
            execution_time = end_time - start_time
            self.stats[func_name].append(execution_time)
            
            # Memory after (if tracking)
            if self.track_memory and memory_before is not None:
                memory_after = self.process.memory_info().rss / 1024 / 1024  # MB
                memory_delta = memory_after - memory_before
                print(f"{func_name}: {execution_time:.4f}s, Memory: {memory_delta:+.2f}MB")
            else:
                print(f"{func_name}: {execution_time:.4f}s")
            
            return result
        
        # Store reference to decorator instance
        wrapper._monitor = self
        return wrapper
    
    def get_summary(self):
        """Get comprehensive performance summary."""
        summary = {}
        for func_name, times in self.stats.items():
            summary[func_name] = {
                'calls': self.call_counts[func_name],
                'total_time': sum(times),
                'avg_time': sum(times) / len(times),
                'min_time': min(times),
                'max_time': max(times)
            }
        return summary
    
    def reset_stats(self):
        """Reset all statistics."""
        self.stats.clear()
        self.call_counts.clear()

# Create monitor instance
monitor = PerformanceMonitor(track_memory=True)

@monitor
def memory_intensive_task(size):
    """Task that uses memory."""
    data = list(range(size))
    return len(data)

@monitor
def cpu_intensive_task(iterations):
    """Task that uses CPU."""
    total = 0
    for i in range(iterations):
        total += i ** 2
    return total

# Usage
memory_intensive_task(1000000)
cpu_intensive_task(100000)
memory_intensive_task(500000)

print("\nPerformance Summary:")
for func, stats in monitor.get_summary().items():
    print(f"{func}: {stats}")
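If you want to avoid the third-party psutil dependency, the standard library's tracemalloc can replace it. tracemalloc traces allocations made through Python's memory allocators, not the whole process RSS, and tracing slows execution noticeably while it is active. A minimal sketch, assuming Python 3.9+ for tracemalloc.reset_peak(). The track_allocations name and the last_peak_kib attribute are hypothetical:

```python
import functools
import tracemalloc

def track_allocations(func):
    """Hypothetical decorator: record the peak memory Python allocated during a call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started_here = not tracemalloc.is_tracing()
        if started_here:
            tracemalloc.start()
        tracemalloc.reset_peak()  # peak now equals the current traced size
        baseline, _ = tracemalloc.get_traced_memory()
        try:
            return func(*args, **kwargs)
        finally:
            _, peak = tracemalloc.get_traced_memory()
            wrapper.last_peak_kib = (peak - baseline) / 1024
            if started_here:
                tracemalloc.stop()  # leave tracing as we found it
    wrapper.last_peak_kib = None
    return wrapper

@track_allocations
def memory_intensive_task(size):
    data = list(range(size))
    return len(data)

memory_intensive_task(100_000)
print(f"peak: {memory_intensive_task.last_peak_kib:.1f} KiB")
```

Measuring the peak, rather than the before/after difference, also captures memory that the function frees before it returns. The list in this example is exactly that kind of allocation.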

Complex Configuration Management

from functools import wraps
import json
import time
from datetime import datetime, timedelta

class ConfigurableCache:
    """Highly configurable caching decorator."""
    
    def __init__(self, max_size=100, ttl_seconds=300, 
                 serialize_keys=False, cache_exceptions=False):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.serialize_keys = serialize_keys
        self.cache_exceptions = cache_exceptions
        
        self.cache = {}
        self.access_times = {}
        self.hit_count = 0
        self.miss_count = 0
    
    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Generate cache key; include the function name so one decorator
            # instance can safely wrap several functions
            if self.serialize_keys:
                key = json.dumps({'func': func.__qualname__, 'args': args, 'kwargs': kwargs}, 
                               sort_keys=True, default=str)
            else:
                key = func.__qualname__ + str(args) + str(sorted(kwargs.items()))
            
            current_time = datetime.now()
            
            # Check cache
            if key in self.cache:
                cached_value, cached_time, is_exception = self.cache[key]
                if current_time - cached_time < timedelta(seconds=self.ttl_seconds):
                    self.hit_count += 1
                    self.access_times[key] = current_time
                    print(f"Cache hit for {func.__name__}")
                    if is_exception:
                        raise cached_value  # replay the cached failure
                    return cached_value
                else:
                    # Expired
                    del self.cache[key]
                    del self.access_times[key]
            
            # Cache miss
            self.miss_count += 1
            print(f"Cache miss for {func.__name__}")
            
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                if self.cache_exceptions:
                    # Cache the exception so later calls re-raise it
                    self._evict_if_needed()
                    self.cache[key] = (e, current_time, True)
                    self.access_times[key] = current_time
                raise
            
            # Store in cache
            self._evict_if_needed()
            self.cache[key] = (result, current_time, False)
            self.access_times[key] = current_time
            return result
        
        # Attach management methods
        wrapper.cache_info = self.get_cache_info
        wrapper.cache_clear = self.clear_cache
        wrapper.cache_stats = self.get_stats
        
        return wrapper
    
    def _evict_if_needed(self):
        """Evict least recently used items if cache is full."""
        if len(self.cache) >= self.max_size:
            # Find least recently used key
            lru_key = min(self.access_times.keys(), 
                         key=lambda k: self.access_times[k])
            del self.cache[lru_key]
            del self.access_times[lru_key]
    
    def get_cache_info(self):
        """Get cache information."""
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'hit_rate': self.hit_count / (self.hit_count + self.miss_count) * 100 
                       if (self.hit_count + self.miss_count) > 0 else 0,
            'hits': self.hit_count,
            'misses': self.miss_count
        }
    
    def clear_cache(self):
        """Clear the cache."""
        self.cache.clear()
        self.access_times.clear()
    
    def get_stats(self):
        """Get detailed statistics."""
        return {
            'config': {
                'max_size': self.max_size,
                'ttl_seconds': self.ttl_seconds,
                'serialize_keys': self.serialize_keys,
                'cache_exceptions': self.cache_exceptions
            },
            'cache_info': self.get_cache_info()
        }

# Create different cache configurations
fast_cache = ConfigurableCache(max_size=50, ttl_seconds=60)
persistent_cache = ConfigurableCache(max_size=200, ttl_seconds=3600, 
                                   serialize_keys=True)

@fast_cache
def quick_calculation(x, y):
    """Fast calculation with short-term caching."""
    time.sleep(0.1)  # Simulate work
    return x * y + x / y

@persistent_cache
def expensive_calculation(data):
    """Expensive calculation with long-term caching."""
    time.sleep(0.5)  # Simulate expensive work
    return sum(data) / len(data)

# Usage
print(quick_calculation(10, 5))     # Cache miss
print(quick_calculation(10, 5))     # Cache hit
print("Fast cache stats:", quick_calculation.cache_stats())

print(expensive_calculation([1, 2, 3, 4, 5]))  # Cache miss
print(expensive_calculation([1, 2, 3, 4, 5]))  # Cache hit
print("Persistent cache stats:", expensive_calculation.cache_stats())
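Once the cache is full, ConfigurableCache's _evict_if_needed() scans every key with min() on each insertion, which is O(n). A collections.OrderedDict keeps keys in recency order, so move_to_end() and popitem(last=False) make both bookkeeping steps O(1). The sketch below uses a hypothetical TTLLRUCache class. It also uses time.monotonic() for expiry and gives each decorated function its own store:

```python
import functools
import time
from collections import OrderedDict

class TTLLRUCache:
    """Hypothetical sketch: O(1) LRU eviction plus a time-to-live."""

    def __init__(self, max_size=100, ttl_seconds=300.0):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds

    def __call__(self, func):
        # One store per decorated function, ordered least to most recently used
        entries = OrderedDict()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))  # arguments must be hashable
            now = time.monotonic()  # unaffected by wall-clock adjustments
            entry = entries.get(key)
            if entry is not None and entry[0] > now:
                entries.move_to_end(key)  # mark as most recently used
                return entry[1]
            value = func(*args, **kwargs)
            entries[key] = (now + self.ttl_seconds, value)
            entries.move_to_end(key)  # a refreshed expired key also moves to the end
            if len(entries) > self.max_size:
                entries.popitem(last=False)  # evict the least recently used entry
            return value

        wrapper.cache_clear = entries.clear
        return wrapper

calls = []

@TTLLRUCache(max_size=2, ttl_seconds=60)
def square(x):
    calls.append(x)
    return x * x

square(1)
square(2)
square(1)     # hit: refreshes key 1
square(3)     # cache full: evicts 2, the least recently used
square(2)     # recomputed
print(calls)  # [1, 2, 3, 2]
```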

Multiple Related Functionalities

from functools import wraps
import time
import logging
from contextlib import contextmanager

class ComprehensiveDecorator:
    """Class decorator with multiple related functionalities."""
    
    def __init__(self, log_calls=True, time_execution=True, 
                 retry_on_failure=False, max_retries=3):
        self.log_calls = log_calls
        self.time_execution = time_execution
        self.retry_on_failure = retry_on_failure
        self.max_retries = max_retries
        
        self.logger = logging.getLogger(__name__)
        self.execution_times = []
        self.call_count = 0
        self.failure_count = 0
    
    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            self.call_count += 1
            
            if self.log_calls:
                self.logger.info(f"Calling {func.__name__} (call #{self.call_count})")
            
            # Retry logic
            for attempt in range(self.max_retries + 1):
                try:
                    with self._timing_context() as timer:
                        result = func(*args, **kwargs)
                    
                    if self.time_execution and timer.elapsed is not None:
                        self.execution_times.append(timer.elapsed)
                        if self.log_calls:
                            self.logger.info(f"{func.__name__} completed in {timer.elapsed:.4f}s")
                    
                    return result
                    
                except Exception as e:
                    if not self.retry_on_failure or attempt >= self.max_retries:
                        # Count a failure only once the call has finally failed,
                        # so success_rate stays between 0 and 100
                        self.failure_count += 1
                        if self.log_calls:
                            self.logger.error(f"{func.__name__} failed: {e}")
                        raise
                    
                    if self.log_calls:
                        self.logger.warning(
                            f"{func.__name__} failed (attempt {attempt + 1}): {e}. Retrying..."
                        )
                    time.sleep(0.1 * 2 ** attempt)  # Exponential backoff: 0.1s, 0.2s, 0.4s, ...
        
        # Attach utility methods
        wrapper.get_stats = self.get_stats
        wrapper.reset_stats = self.reset_stats
        wrapper.configure = self.configure
        
        return wrapper
    
    @contextmanager
    def _timing_context(self):
        """Context manager for timing execution."""
        class Timer:
            def __init__(self):
                self.elapsed = None
        
        timer = Timer()
        if self.time_execution:
            start = time.time()
            try:
                yield timer
            finally:
                timer.elapsed = time.time() - start
        else:
            yield timer
    
    def get_stats(self):
        """Get comprehensive statistics."""
        stats = {
            'call_count': self.call_count,
            'failure_count': self.failure_count,
            'success_rate': ((self.call_count - self.failure_count) / self.call_count * 100) 
                           if self.call_count > 0 else 0
        }
        
        if self.execution_times:
            stats['timing'] = {
                'total_time': sum(self.execution_times),
                'avg_time': sum(self.execution_times) / len(self.execution_times),
                'min_time': min(self.execution_times),
                'max_time': max(self.execution_times)
            }
        
        return stats
    
    def reset_stats(self):
        """Reset all statistics."""
        self.execution_times.clear()
        self.call_count = 0
        self.failure_count = 0
    
    def configure(self, **kwargs):
        """Dynamically reconfigure the decorator."""
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)

# Configure logging
logging.basicConfig(level=logging.INFO)

@ComprehensiveDecorator(log_calls=True, time_execution=True, 
                       retry_on_failure=True, max_retries=2)
def unreliable_function(fail_rate=0.3):
    """Function that might fail randomly."""
    import random
    if random.random() < fail_rate:
        raise Exception("Random failure")
    time.sleep(0.1)
    return "Success"

# Usage
try:
    for i in range(5):
        result = unreliable_function(0.4)
        print(f"Call {i+1}: {result}")
except Exception as e:
    print(f"Final failure: {e}")

print("\nStatistics:")
print(unreliable_function.get_stats())

# Reconfigure the decorator
unreliable_function.configure(retry_on_failure=False)
print("Retry disabled for future calls")
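When you only need the retry behavior, without shared statistics or runtime reconfiguration, a plain function decorator factory is enough. This sketch uses hypothetical names (retry, base_delay, exceptions) and doubles the delay after each failed attempt:

```python
import functools
import time

def retry(max_retries=3, base_delay=0.1, exceptions=(Exception,)):
    """Hypothetical retry decorator factory with exponential backoff."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == max_retries:
                        raise  # out of retries: re-raise the last error
                    time.sleep(base_delay * 2 ** attempt)  # 0.1s, 0.2s, 0.4s, ...
        return wrapper
    return decorator

attempts = []

@retry(max_retries=3, base_delay=0.01)
def flaky():
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("transient failure")
    return "ok"

print(flaky(), len(attempts))  # ok 3
```

Narrowing exceptions to the errors that are actually transient (for example ConnectionError) avoids retrying bugs that will never succeed.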

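Performance Comparison

Both styles add a small cost to every call. The benchmark below wraps the same trivial function each way and times 100,000 calls. The gap depends on the Python version and hardware, so measure on your own interpreter before letting it drive a design decision.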

import time
from functools import wraps

# Function decorator
def function_timer(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        return result
    return wrapper

# Class decorator
class ClassTimer:
    def __init__(self, func):
        self.func = func
        wraps(func)(self)
    
    def __call__(self, *args, **kwargs):
        start = time.time()
        result = self.func(*args, **kwargs)
        end = time.time()
        return result

# Test functions
@function_timer
def func_with_function_decorator():
    return sum(range(1000))

@ClassTimer
def func_with_class_decorator():
    return sum(range(1000))

# Performance test (perf_counter is the clock intended for timing short intervals)
def test_performance(func, iterations=100000):
    start = time.perf_counter()
    for _ in range(iterations):
        func()
    end = time.perf_counter()
    return end - start

print("Performance Comparison:")
func_time = test_performance(func_with_function_decorator)
class_time = test_performance(func_with_class_decorator)

print(f"Function decorator: {func_time:.4f} seconds")
print(f"Class decorator: {class_time:.4f} seconds")
print(f"Class decorator overhead: {((class_time - func_time) / func_time * 100):.1f}%")
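Hand-written timing loops like test_performance are sensitive to noise from other processes. The standard library's timeit module runs repeated rounds and turns off garbage collection during each one. The usual way to read its results is to take the minimum of several rounds. A minimal sketch of the same comparison, using pass-through wrappers:

```python
import functools
import timeit

def function_wrapper(func):
    """Pass-through function decorator: adds only the wrapping overhead."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

class ClassWrapper:
    """Pass-through class-based decorator."""
    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

def work():
    return sum(range(1000))

by_function = function_wrapper(work)
by_class = ClassWrapper(work)

for label, wrapped in [("function", by_function), ("class", by_class)]:
    # Five rounds of 5,000 calls; the fastest round is the least disturbed by noise
    best = min(timeit.repeat(wrapped, number=5_000, repeat=5))
    print(f"{label} decorator: {best:.4f}s per 5,000 calls")
```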

Decision Guidelines

Use Function Decorators When:

  1. Simple, stateless modifications are needed
  2. Per-call overhead matters (calling a closure is usually slightly cheaper than calling an instance's __call__)
  3. Quick prototyping or one-off decorators
  4. Functional programming style is preferred
  5. The decorator logic is straightforward

from functools import wraps

# Good use cases for function decorators
def simple_logger(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

def validate_positive(func):
    @wraps(func)
    def wrapper(x):
        if x <= 0:
            raise ValueError("Value must be positive")
        return func(x)
    return wrapper

Use Class Decorators When:

  1. State management is required
  2. Complex configuration is needed
  3. Multiple related methods should be available
  4. Statistics or monitoring functionality is desired
  5. The decorator needs to be customizable at runtime

# Good use cases for class decorators (outlines only: each still needs a
# __call__ method that wraps the target function, as in the examples above)
class RateLimiter:
    def __init__(self, calls_per_minute=60):
        self.calls_per_minute = calls_per_minute
        # ... state management code

class CacheWithEviction:
    def __init__(self, max_size=100, strategy='lru'):
        self.max_size = max_size
        self.strategy = strategy
        # ... complex caching logic

class PerformanceProfiler:
    def __init__(self, track_memory=True, track_cpu=True):
        # ... comprehensive monitoring setup
        pass
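Here is one way to fill in the RateLimiter outline. The sketch allows at most calls_per_minute calls in any sliding 60-second window and raises once that budget is spent. Two details are assumptions made to keep it short and testable: it raises instead of sleeping until a slot frees up, and it takes an injectable clock parameter:

```python
import functools
import time
from collections import deque

class RateLimiter:
    """Hypothetical sketch: at most `calls_per_minute` calls in any sliding 60-second window."""

    def __init__(self, calls_per_minute=60, clock=time.monotonic):
        self.calls_per_minute = calls_per_minute
        self.clock = clock          # injectable so tests can use a fake clock
        self.timestamps = deque()   # start times of recent calls, oldest first

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            now = self.clock()
            # Forget calls that have left the 60-second window
            while self.timestamps and now - self.timestamps[0] >= 60:
                self.timestamps.popleft()
            if len(self.timestamps) >= self.calls_per_minute:
                raise RuntimeError(f"rate limit exceeded for {func.__name__}")
            self.timestamps.append(now)
            return func(*args, **kwargs)

        wrapper.limiter = self  # expose the shared state for inspection
        return wrapper

api_limit = RateLimiter(calls_per_minute=30)

@api_limit
def fetch_user(user_id):
    return {"id": user_id}

print(fetch_user(1))  # {'id': 1}
```

Every function decorated by the same instance shares one budget, which is often what you want when several functions call the same external API.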

Hybrid Approach: Decorator Factories

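A decorator factory can also hide the choice behind a single interface. It returns a class-based decorator when callers need per-function configuration they can update later, and a plain closure otherwise: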

from functools import wraps

def create_decorator(use_class=False, **config):
    """Factory function that creates either a function or a class-based decorator."""
    
    if use_class:
        class ConfigurableDecorator:
            def __init__(self, func):
                self.func = func
                self.config = dict(config)  # per-function copy, so update_config() doesn't leak
                wraps(func)(self)
            
            def __call__(self, *args, **kwargs):
                # Implementation using self.config
                return self.func(*args, **kwargs)
            
            def update_config(self, **new_config):
                self.config.update(new_config)
        
        return ConfigurableDecorator
    
    else:
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Implementation using config
                return func(*args, **kwargs)
            wrapper.config = dict(config)  # copy, so decorated functions don't share one dict
            return wrapper
        return decorator

# Usage
@create_decorator(use_class=True, log_level='INFO')
def stateful_function():
    pass

@create_decorator(use_class=False, timeout=30)
def simple_function():
    pass

Conclusion

The choice between class and function decorators depends on your specific needs:

Function Decorators are ideal for:

  • Simple behavior modifications
  • Hot code paths where per-call overhead matters
  • Stateless transformations
  • Quick implementations

Class Decorators excel at:

  • State management and statistics
  • Complex configuration requirements
  • Multiple related functionalities
  • Runtime customization needs

Key considerations:

  • Complexity: Start with function decorators, move to class decorators when complexity grows
  • Performance: Function decorators have slightly less overhead
  • Maintainability: Class decorators provide better organization for complex logic
  • Reusability: Class decorators offer more flexible configuration options

By understanding these trade-offs, you can choose the right decorator pattern for each situation, leading to more maintainable and efficient Python code.
