Type checking decorators

Decorators are functions that usually wrap other functions with some extra code.

They do this by replacing a function with a wrapper function. The wrapper calls the original function, but can run extra code before and after that call.

For example, take this decorator that just logs every time you call a specific function:

def log(function):
    """Log every time you call the function"""
    def wrapper():
        print("Calling the function!")
        function()

    return wrapper

@log
def fortytwo():
    print(42)

fortytwo()

The @log decorator wraps the fortytwo function so that every time you call it, it first prints "Calling the function!". And the best part is that you can use the same @log decorator on many functions, and it'll work on all of them.
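
To make the mechanics concrete: the @ syntax is shorthand for calling the decorator and rebinding the function's name to the result. The sketch below spells that out by hand for a second function (hello is a made-up name for illustration, not part of the example above) and shows that the name now points at the wrapper.

```python
def log(function):
    """Log every time you call the function"""
    def wrapper():
        print("Calling the function!")
        function()

    return wrapper

def hello():  # hypothetical second function, for illustration only
    print("hello")

# Equivalent to writing @log on the line above `def hello():`
hello = log(hello)

hello()
# The name `hello` now refers to the inner `wrapper` function
print(hello.__name__)
```

The same log function can be applied to as many functions as you like, because each call to log creates a fresh wrapper around whichever function it was given.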

So, how do we add type hints to this decorator? One way, which we've already seen, is to use a Protocol:

from typing import Protocol

class LogFunction(Protocol):
    def __call__(self) -> None: ...

def log(function: LogFunction) -> LogFunction:
    """Log every time you call the function"""
    def wrapper() -> None:
        print("Calling the function!")
        function()

    return wrapper

@log
def fortytwo() -> None:
    print(42)

fortytwo()

But it quickly becomes clear that this code falls short when you try to decorate a function that takes arguments:

from typing import Protocol

class LogFunction(Protocol):
    def __call__(self) -> None: ...

def log(function: LogFunction) -> LogFunction:
    """Log every time you call the function"""
    def wrapper() -> None:
        print("Calling the function!")
        function()

    return wrapper

@log
def get_greeting(name: str, age: int, location: str) -> str:
    return f"Hi, I am {name}, {age}, from {location}."

greeting = get_greeting("Steve", 27, "London")
print(greeting)

Running the code, or checking it with mypy, shows the problem: we have hard-coded the function being passed in to take no arguments and to return nothing:

    def wrapper() -> None:
        print("Calling the function!")
        function()

What we actually want here is to pass a function that takes three arguments and returns a string.

We could update our protocol to handle this:

class LogFunction(Protocol):
    def __call__(self, name: str, age: int, location: str) -> str: ...
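
For completeness, here is a sketch of how the whole decorator might look with that updated protocol. The wrapper's parameters and return statement are filled in here as an assumption about what the full version would be: it has to accept the same three arguments and pass the result back.

```python
from typing import Protocol

class LogFunction(Protocol):
    def __call__(self, name: str, age: int, location: str) -> str: ...

def log(function: LogFunction) -> LogFunction:
    """Log every time you call the function"""
    def wrapper(name: str, age: int, location: str) -> str:
        print("Calling the function!")
        # Forward the arguments and hand the result back to the caller
        return function(name, age, location)

    return wrapper

@log
def get_greeting(name: str, age: int, location: str) -> str:
    return f"Hi, I am {name}, {age}, from {location}."

greeting = get_greeting("Steve", 27, "London")
print(greeting)
```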

Or we could ditch the protocol completely and use the generic typing.Callable:

from typing import Callable

def log(function: Callable[[str, int, str], str]) -> Callable[[str, int, str], str]:
    def wrapper(name: str, age: int, location: str) -> str:
        print("Calling the function!")
        return function(name, age, location)

    return wrapper

But the main problem still remains: we want to be able to pass functions with any signature, not just one pre-defined signature.

You can do that by replacing function(name, age, location) with function(*args, **kwargs), and then using a Callable type that accepts any number of arguments of any type and returns anything. This is how you do it:

from typing import Any, Callable

def log(function: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print("Calling the function!")
        return function(*args, **kwargs)

    return wrapper


@log
def fortytwo() -> None:
    print(42)

@log
def get_greeting(name: str, age: int, location: str) -> str:
    return f"Hi, I am {name}, {age}, from {location}."


fortytwo()
greeting = get_greeting("Steve", 27, "London")
print(greeting)

The ... used as the first argument of Callable[..., Any] tells mypy that you want the function to be able to accept as many arguments of any type as needed.
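
A small sketch of what that means in practice. The functions add and shout are made up for illustration; the point is only that Callable[..., Any] accepts all of them, while a Callable with a fixed parameter list does not.

```python
from typing import Any, Callable

def add(a: int, b: int) -> int:  # hypothetical example function
    return a + b

def shout(text: str) -> str:  # hypothetical example function
    return text.upper()

# All of these assignments type-check, because `...` places no
# constraint on the parameters.
f: Callable[..., Any] = add
g: Callable[..., Any] = shout
h: Callable[..., Any] = print

# A fixed parameter list only accepts matching signatures.
# mypy would reject: strict: Callable[[int, int], int] = shout
strict: Callable[[int, int], int] = add

print(f(1, 2))
print(g("hi"))
```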

There's one problem with this, though: we're saying the decorated function returns Any. This means that whatever a decorated function returns, even if the original function was fully annotated, its type is now gone. The arguments are no longer checked either, since the decorated function accepts anything. This is a phenomenon called "type erasure", and it can creep into your applications if you start using Any in a lot of places.
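
To see why this matters, here is a sketch of a mistake that mypy can no longer catch. The call to no_such_method is a deliberate bug added for illustration: str has no such method, but because greeting is typed as Any, mypy stays silent and the error only shows up at runtime.

```python
from typing import Any, Callable

def log(function: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print("Calling the function!")
        return function(*args, **kwargs)

    return wrapper

@log
def get_greeting(name: str, age: int, location: str) -> str:
    return f"Hi, I am {name}, {age}, from {location}."

greeting = get_greeting("Steve", 27, "London")

# mypy treats `greeting` as Any here, so it does not flag this call,
# even though the value is really a str and the call fails at runtime.
try:
    greeting.no_such_method()
except AttributeError as error:
    print(f"Runtime error that mypy did not catch: {error}")
```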

Decorators, however, usually don't change the signature of the function they wrap. So we can make the decorator generic, so that it hands back the same callable type it was given:

from typing import Any, Callable, TypeVar

T = TypeVar('T', bound=Callable[..., Any])

def log(function: T) -> T:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print("Calling the function!")
        return function(*args, **kwargs)

    return wrapper  # type: ignore


@log
def fortytwo() -> None:
    print(42)

@log
def get_greeting(name: str, age: int, location: str) -> str:
    return f"Hi, I am {name}, {age}, from {location}."


fortytwo()
greeting = get_greeting("Steve", 27, "London")
print(greeting)

Doing reveal_type(greeting) here should now correctly report the type as str.

T = TypeVar('T', bound=Callable[..., Any]) tells mypy that T can be any type, as long as it is compatible with Callable[..., Any]. Essentially, we are saying that T can stand for any callable type, and that log returns exactly the type it received.

Notice that wrapper() still has to use Any, and because of that we had to suppress a mypy error on line 10 of the snippet, using a # type: ignore comment. If you remove that comment, mypy will complain, in effect saying "we expected to return a T, but this wrapper function can be anything". However, we know that we're not changing the signature, so it is safe to ignore this error.

But with that, our decorator is fully typed. It can take in functions of any kind, and wrap them without erasing any type information.