tel: +48 728 438 076
email: piotr.hyzy@eviden.com

Start

Round - microphone check

A few rules

  1. In case of problems => chat, then SMS and phone, NOT email
  2. Training materials
  3. All questions are OK
  4. Vegas rule
  5. Headphones
  6. Cameras
  7. Chat
  8. Report any absences at the start of the given day, including emergencies; everything goes on the chat
  9. Mute by default during lectures
  10. Breaks (working hours 08:30 - 16:30)
    • block 09:00 - 10:00
    • coffee break 10:00 - 10:15 (15')
    • block 10:15 - 11:45
    • coffee break 11:45 - 12:00
    • block 12:00 - 13:30
    • lunch 13:30 - 14:00
    • block 14:00 - 15:00
    • coffee break 15:00 - 15:10
    • block 15:10 - 16:00
  11. All times are plus/minus 10'
  12. How to ask a question? 1) interrupt 2) ask on the chat 3) raise the virtual hand
  13. IDE => any
  14. Each exercise in a separate file/notebook
  15. We do not invite other people
  16. We start on time
  17. Exercises in pairs, rotations, ask for help

Pytest

pytest Fundamentals

Pytest is a powerful and easy-to-use testing framework for Python. This module will introduce the basics of using Pytest, including setting up test cases, running them, and interpreting the results. By the end of this section, you'll understand the foundational concepts of Pytest and be able to write your first test cases.

! pip install pytest
Requirement already satisfied: pytest in ./.venv/lib/python3.12/site-packages (8.2.2)
Requirement already satisfied: iniconfig in ./.venv/lib/python3.12/site-packages (from pytest) (2.0.0)
Requirement already satisfied: packaging in ./.venv/lib/python3.12/site-packages (from pytest) (24.1)
Requirement already satisfied: pluggy<2.0,>=1.5 in ./.venv/lib/python3.12/site-packages (from pytest) (1.5.0)

System Under Test (SUT): factorial(num)

The System Under Test (SUT) refers to the specific function or module being tested. In this case, it's the factorial(num) function, which calculates the factorial of a number.

# %%writefile mathutils.py
def factorial(num: int):
    if not isinstance(num, int):
        raise TypeError('Argument must be int')

    if num == 0:
        return 1
    else:
        return factorial(num-1) * num
        # n! = (n-1)! * n
        # 3! = 2! * 3 = 1! * 2 * 3 = 0! * 1 * 2 * 3 = 1 * 1 * 2 * 3 = 6
factorial(0)
1
factorial(1)
1
factorial(3)
6

Understanding the SUT is critical because test cases are built to validate its behavior against expected outcomes.
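One behavior of the SUT worth knowing before writing tests: the recursion above never reaches its base case for a negative integer, so a call such as factorial(-1) recurses until Python raises RecursionError. The following is a minimal sketch of one possible guard. The name factorial_checked and the choice of ValueError are assumptions made for illustration, not part of the original module.

```python
def factorial_checked(num: int) -> int:
    """Hypothetical variant of factorial() that rejects negative input."""
    if not isinstance(num, int):
        raise TypeError('Argument must be int')
    if num < 0:
        # Assumption: ValueError is the error type chosen for negative input
        raise ValueError('Argument must be non-negative')
    if num == 0:
        return 1
    return factorial_checked(num - 1) * num
```

With a guard like this, a test can state the expected behavior for negative input explicitly instead of relying on the recursion limit.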

Example Project Structure

Organizing your project correctly is essential for writing maintainable and scalable tests. A typical Pytest-compatible project structure might look like this:

mathutils/
    __init__.py              # Package initialization
    factorial_utils.py       # Factorial-related utilities
    optimization.py          # Optimization-related code
domain/
    __init__.py              # Another package
    vector.py                # Vector model
    base.py                  # Base code
tests/
    __init__.py              # Package initialization for tests
    test_factorial_utils.py  # Tests for factorial utilities
    test_optimization.py     # Tests for optimization
    requirements.txt         # Test dependency management
main.py                      # Main app file
setup.py                     # Project metadata and installation setup
requirements.txt             # Dependency management

Key Points:

  • Code files are stored in a mathutils/ directory.
  • Test files are in the tests/ directory and follow the test_*.py naming convention.
  • setup.py and requirements.txt help manage project dependencies and installation.

Writing Assertions

An assertion is a condition that must evaluate to True for the test to pass:

cond = True
assert cond, 'cond should be True'
# This is equivalent to:

if not cond:
    raise AssertionError('cond should be True')
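The message after the comma is only used when the condition is false; it becomes the text of the raised AssertionError. A small sketch of that behavior follows (the function name check_positive is hypothetical):

```python
def check_positive(value):
    # The message is attached to the AssertionError raised when value <= 0
    assert value > 0, f'expected a positive value, got {value}'
    return value


try:
    check_positive(-1)
except AssertionError as exc:
    message = str(exc)

print(message)  # expected a positive value, got -1
```

Writing the message so that it describes the failure (rather than the passing case) makes the test report easier to read.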

Basic Tests

Tests in Pytest are written as functions. Assertions are used to verify that the code behaves as expected.

%%writefile pytest_fundamentals.py

import pytest
from mathutils import factorial

def test_factorial_of_one():
    assert factorial(1) == 1

def test_factorial_of_three():
    got = factorial(3)
    expected = 7 # intentionally wrong value to show how test failing, the right value is 6
    assert expected == got, 'my error message'

def test_raises_typeerror_for_float():
    with pytest.raises(TypeError):
        factorial(3.5) # Expecting a TypeError for non-integer input
Overwriting pytest_fundamentals.py

Explanation:

  1. test_factorial_of_one: Verifies the factorial of 1.
  2. test_factorial_of_three: Contains a deliberate failure (expected is incorrect) to demonstrate test results.
  3. test_raises_typeerror_for_float: Validates that the function raises TypeError for a non-integer argument.
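The exception test can also check the error message, not only the exception type. The sketch below assumes the same factorial logic as the SUT (repeated here so the block is self-contained) and uses the match argument of pytest.raises together with the returned exception info object. The test name is chosen for illustration.

```python
import pytest


def factorial(num: int):
    # Same logic as the SUT above, repeated so the sketch is self-contained
    if not isinstance(num, int):
        raise TypeError('Argument must be int')
    if num == 0:
        return 1
    return factorial(num - 1) * num


def test_raises_typeerror_with_message_for_float():
    # match= is a regular expression searched in the string form of the exception
    with pytest.raises(TypeError, match='must be int') as excinfo:
        factorial(3.5)
    # excinfo.value is the exception instance that was raised
    assert str(excinfo.value) == 'Argument must be int'
```

Checking the message guards against the case where a TypeError is raised for an unrelated reason.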

Launching Tests

Pytest provides a simple command-line interface to execute tests. To run tests in a specific file:

! pytest pytest_fundamentals.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

pytest_fundamentals.py .F.                                               [100%]

=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
========================= 1 failed, 2 passed in 0.57s ==========================

Launching Tests with -q

The -q flag provides a more concise output:

! pytest pytest_fundamentals.py -q
.F.                                                                      [100%]
=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
1 failed, 2 passed in 0.46s

Useful Switches

Pytest includes several useful command-line options to enhance debugging:

Stop after the first failure (-x):

! pytest pytest_fundamentals.py -x

This stops the test run as soon as a failure is encountered.

! pytest pytest_fundamentals.py -x
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

pytest_fundamentals.py .F

=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!
========================= 1 failed, 1 passed in 0.35s ==========================

Print local variables for failing tests (-l):

! pytest pytest_fundamentals.py -l

This displays the values of local variables to help debug.

! pytest pytest_fundamentals.py -l
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

pytest_fundamentals.py .F.                                               [100%]

=================================== FAILURES ===================================
___________________________ test_factorial_of_three ____________________________

    def test_factorial_of_three():
        got = factorial(3)
        expected = 7 # intentionally wrong value to show how test failing, the right value is 6
>       assert expected == got, 'my error message'
E       AssertionError: my error message
E       assert 7 == 6

expected   = 7
got        = 6

pytest_fundamentals.py:11: AssertionError
=========================== short test summary info ============================
FAILED pytest_fundamentals.py::test_factorial_of_three - AssertionError: my error message
========================= 1 failed, 2 passed in 0.35s ==========================

Autodiscover All Tests

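Pytest can automatically discover and execute all tests in a project. By default it collects files matching test_*.py or *_test.py. The files written so far in this module are named pytest_*.py, so they do not match, and the run below collects 0 items.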

To autodiscover and run tests:

! pytest
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 0 items                                                              

============================ no tests ran in 0.02s =============================

Tips:

  • Ensure all test files follow the test_*.py naming pattern.
  • Use directories like tests/ to organize your test files.
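To see why a given file is or is not picked up, the default file patterns can be modeled with the standard library. This is a simplified sketch, not pytest's actual collection code: it only checks file names against the two default python_files patterns, and the helper name and the file div_test.py are hypothetical.

```python
from fnmatch import fnmatch

# Default value of pytest's python_files option
DEFAULT_PATTERNS = ('test_*.py', '*_test.py')


def matches_default_patterns(filename, patterns=DEFAULT_PATTERNS):
    """Simplified model: does the file name match any collection pattern?"""
    return any(fnmatch(filename, pattern) for pattern in patterns)


for name in ('pytest_fundamentals.py', 'test_capsys.py', 'div_test.py'):
    print(name, matches_default_patterns(name))
# pytest_fundamentals.py False
# test_capsys.py True
# div_test.py True
```

Files that do not match can still be run by passing their path explicitly, as done earlier in this module, or by extending the python_files option in the project's pytest configuration.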

Exercise: 🏠 Dividing

Write as many tests as possible for the function below.

def div(a, b):
    return a / b

Solution

%%writefile pytest_div.py
"""
### Naming Test Functions

When creating test functions in Pytest, it's a good practice to name the function in a way that clearly communicates:
1. The **action** being tested (e.g., a function call, operation, or scenario).
2. The **expected output** or result of that action.

#### Naming Convention Example
- **Test function name**: `test_dividing_two_integers_should_give_a_float`
    - **Action**: "dividing two integers" (describes the operation being tested)
    - **Expected Output**: "should give a float" (clarifies the expected result)

This naming approach improves readability, making it easier for others to understand what the test is verifying.

#### Test Name Structure
The typical structure for a test function name is:
`test_<action>_<expected_output>`

#### Examples:
- `test_dividing_4_by_2_gives_2`: Testing that dividing 4 by 2 results in 2.
- `test_dividing_two_integers_should_give_a_float`: Testing that dividing two integers always returns a float.
- `test_raises_error_on_division_by_zero`: Testing that dividing by zero raises an error.
"""

import math
import pytest

# Define a simple division function to be tested
def div(a, b):
    return a / b


# Test that dividing 4 by 2 returns 2
def test_dividing_4_by_2_gives_2():
    assert div(4, 2) == 2


# Test that dividing two integers can produce a fractional float result
def test_dividing_two_integers_should_give_a_float():
    # given / arrange - set up any necessary context
    a, b = 3, 2

    # when / act - perform the operation
    got = div(a, b)

    # then / assert - check that the result matches the expectation
    expected = 1.5
    assert got == expected

# Test that dividing two integers always returns a float, even if the result is an integer
def test_dividing_two_integers_returns_a_float_even_when_result_is_integer():
    assert isinstance(div(4, 2), float)


# Test that dividing by zero raises a ZeroDivisionError
def test_raises_error_on_division_by_zero():
    with pytest.raises(ZeroDivisionError):
        div(2, 0)


# Test that dividing two negative numbers gives a positive result
def test_dividing_negative_numbers():
    assert div(-3, -2) == 1.5


# Test that dividing infinity by infinity results in NaN (not a number)
def test_dividing_infinities_gives_nan():
    # Check using math.isnan since direct comparison with NaN doesn't work
    assert math.isnan(div(float('inf'), float('inf')))


# Test that dividing infinity by a finite number still results in infinity
def test_dividing_infinity_gives_infinity():
    inf = float('inf')
    assert div(inf, 2) == inf


# Test that dividing infinity by zero raises a ZeroDivisionError
def test_dividing_infinity_by_zero_raises_an_error():
    inf = float('inf')
    with pytest.raises(ZeroDivisionError):
        div(inf, 0)


# Test that dividing a boolean value works, with True treated as 1 and False as 0
def test_dividing_boolean_works():
    assert div(True, 2) == 0.5


# Test that halving the smallest positive float (5e-324, a subnormal) underflows to 0.0
def test_dividing_epsilon_by_two_gives_zero():
    assert div(5e-324, 2) == 0.0


# Test that dividing a huge number by a number smaller than 1 overflows to infinity
def test_dividing_huge_numbers_results_in_infinity():
    assert div(1e308, 0.5) == float('inf')


# Test that attempting to divide two lists raises a TypeError
def test_dividing_two_lists_raises_an_error():
    with pytest.raises(TypeError):
        div([3, 2, 5], [1, 2, 3])


# Test that the division operator can be overridden in custom classes
def test_you_can_override_division_operator():
    class Dividable:
        def __truediv__(self, other):
            return 42  # Custom division logic

    d = Dividable()
    assert div(d, d) == 42  # Confirm overridden behavior


# Test that dividing two instances of a custom object without division logic raises a TypeError
def test_dividing_custom_objects_raises_TypeError():
    class MyClass:
        pass

    m = MyClass()
    with pytest.raises(TypeError):
        div(m, m)
Overwriting pytest_div.py
! pytest pytest_div.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 14 items                                                             

pytest_div.py ..............                                             [100%]

============================== 14 passed in 0.17s ==============================
from contextlib import contextmanager
import io, sys

@contextmanager
def suppress_output():
    stdout = sys.stdout
    sys.stdout = None
    try:
        yield stdout
    finally:
        sys.stdout = stdout


print("A")
with suppress_output() as s:
    print("B")
    s.write("asdf\n")
print("C")
A
asdf
C

Exercise: 🏠 @suppress_output

Implementation:

# %%writefile suppress_output_module.py
from contextlib import contextmanager
import io, sys


@contextmanager
def suppress_output():
    stdout = sys.stdout
    sys.stdout = None
    try:
        yield stdout
    finally:
        sys.stdout = stdout

Example Usage

print("A")
with suppress_output() as s:
    print("B")
    s.write("asdf\n")
print("C")

What to Test?

When testing the functionality involving sys.stdout and error handling, consider the following scenarios:

  1. Changing sys.stdout in Point B

    • Verify that sys.stdout is replaced inside the with block (B).
  2. Error Handling in Point B

    • Ensure that if an error is raised in B, suppress_output does not swallow or alter the exception.
  3. Restoring sys.stdout in Point C

    • Confirm that sys.stdout is restored to the original object once the with block has exited (C).
  4. Restoring sys.stdout in Point C even in case of an error in B

    • Test that sys.stdout is restored at C even if an error is raised while B is executing.
  5. Ensuring s Matches the Original sys.stdout

    • Validate that the variable s holds a reference to the original sys.stdout.

Solution

%%writefile pytest_suppressoutput.py
import sys

import pytest

from suppress_output_module import suppress_output


def test_stdout_should_change_inside_suppress_output():
    before = sys.stdout  # Save the original sys.stdout for comparison
    with suppress_output():
        # assert sys.stdout is None  # This may give a false positive (FP) because the assertion is too strong.
        assert sys.stdout is not before  # This may give a false negative (FN) because the assertion is too weak

def test_stdout_should_be_restored_after_suppress_output():
    before = sys.stdout
    with suppress_output():
        pass
    assert sys.stdout is before

def test_suppress_output_propagates_exceptions():
    class ExampleException(Exception):
        pass

    with pytest.raises(ExampleException):
        with suppress_output():
            raise ExampleException

def test_suppress_output_restores_original_stdout_even_in_case_of_an_exception():
    class ExampleException(Exception):
        pass

    before = sys.stdout
    try:
        with suppress_output():
            raise ExampleException
    except ExampleException:
        pass
    assert sys.stdout is before

def test_suppress_output_returns_original_stdout():
    before = sys.stdout
    with suppress_output() as s:
        assert s is before

Here's a detailed explanation of the comments in the test function:

Test Function Code

def test_stdout_should_change_inside_suppress_output():
    before = sys.stdout  # Save the original sys.stdout for comparison
    with suppress_output():
        # assert sys.stdout is None  # This may give a false positive (FP) because the assertion is too strong.
        assert sys.stdout is not before  # This may give a false negative (FN) because the assertion is too weak

Explanation of Comments

Comment 1:

# assert sys.stdout is None  # This may give a false positive (FP) because the assertion is too strong.
  • Reason: The code inside suppress_output() sets sys.stdout to None. However, directly asserting that sys.stdout is None may not always be reliable in broader contexts:
    • If the implementation changes slightly (e.g., sys.stdout is replaced with a dummy io.StringIO object instead of None), this assertion would fail, even though the behavior of suppressing output remains correct.
    • This makes the assertion too strict and prone to failing unnecessarily in valid cases.

Comment 2:

assert sys.stdout is not before  # This may give a false negative (FN) because the assertion is too weak
  • Reason: This assertion checks only that sys.stdout has changed from its original state (before). While this ensures sys.stdout is modified, it doesn't confirm how it has been changed or whether it matches the intended behavior of the suppress_output() function:
    • If sys.stdout is set to a different value (e.g., a mock object), this assertion would still pass, even if the behavior of suppressing output is incorrect.
    • This makes the assertion too lenient and may miss issues (false negatives).

Suggested Improvement

To balance the strictness and flexibility of the assertions, you could test for specific behavior rather than relying on the exact value of sys.stdout:

import io

def test_stdout_should_change_inside_suppress_output():
    before = sys.stdout
    with suppress_output() as captured_stdout:
        assert sys.stdout is not before  # Ensure sys.stdout changes
        assert sys.stdout is None or isinstance(sys.stdout, io.TextIOBase)  # Allow None or any replacement text stream
        assert captured_stdout is before  # Confirm the original stdout is yielded

Key Points:

  • Testing specific behavior (e.g., sys.stdout changes and the original is captured) is often more robust than asserting exact values.
  • The balance between strict and lenient assertions is critical for avoiding both false positives and false negatives.
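One way to test behavior rather than the value of sys.stdout is to check what actually gets written. The sketch below is an illustration under assumptions: it repeats the suppress_output implementation from the exercise so that it is self-contained, and it uses contextlib.redirect_stdout with an in-memory buffer as the "original" stream. The test name is hypothetical.

```python
import io
import sys
from contextlib import contextmanager, redirect_stdout


@contextmanager
def suppress_output():
    # Same implementation as in the exercise, repeated for self-containment
    stdout = sys.stdout
    sys.stdout = None
    try:
        yield stdout
    finally:
        sys.stdout = stdout


def test_print_inside_suppress_output_writes_nothing():
    buffer = io.StringIO()
    with redirect_stdout(buffer):  # sys.stdout is now the buffer
        print('A')
        with suppress_output():
            print('B')  # should be swallowed
        print('C')
    # Only the text printed outside suppress_output reaches the buffer
    assert buffer.getvalue() == 'A\nC\n'
```

This test keeps passing if the implementation later swaps None for a dummy stream, and it fails if output leaks through, which is the property the context manager is meant to guarantee.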

Fixtures

Fixtures are a key feature in Pytest that allow you to set up and tear down resources required for your tests. They provide a way to share setup code across multiple tests, making test suites more efficient and maintainable. Compared with the traditional setUp and tearDown methods of frameworks such as unittest, Pytest fixtures are more flexible and reusable.

What Are Fixtures?

  • Setup Code: Fixtures are functions that prepare some state or resources required for your tests.
  • Teardown Code: Fixtures can also define cleanup steps, which Pytest runs after the test, even if the test fails.
  • Reusability: Fixtures can be shared across multiple test functions, classes, or modules.
  • Scope: Fixtures can have different scopes (function, class, module, package, session), determining their lifespan.

Unique Temporary Directory

Pytest provides a built-in fixture called tmpdir that creates a unique temporary directory for each test function. This is useful for testing file-related operations.

%%writefile pytest_tmpdir.py
import pytest

def test_needsfiles(tmpdir):
    print(tmpdir)
    print(type(tmpdir))
    assert False
Overwriting pytest_tmpdir.py
! pytest pytest_tmpdir.py -q
F                                                                        [100%]
=================================== FAILURES ===================================
_______________________________ test_needsfiles ________________________________

tmpdir = local('/private/var/folders/w9/9hmtpfzj64v0x0j841ny9y280000gn/T/pytest-of-a563420/pytest-5/test_needsfiles0')

    def test_needsfiles(tmpdir):
        print(tmpdir)
        print(type(tmpdir))
>       assert False
E       assert False

pytest_tmpdir.py:6: AssertionError
----------------------------- Captured stdout call -----------------------------
/private/var/folders/w9/9hmtpfzj64v0x0j841ny9y280000gn/T/pytest-of-a563420/pytest-5/test_needsfiles0
<class '_pytest._py.path.LocalPath'>
=========================== short test summary info ============================
FAILED pytest_tmpdir.py::test_needsfiles - assert False
1 failed in 0.42s

Output:

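  • The tmpdir fixture provides a temporary directory as a legacy py.path.local object (reported above as _pytest._py.path.LocalPath).
  • The directory is unique for each test function invocation. Pytest does not delete it immediately; by default it keeps the directories from the last three runs and removes older ones.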

Key Points:

  • tmpdir is isolated for each test, preventing side effects between tests.
  • It simplifies testing code that interacts with the file system.
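The same idea works with tmp_path, the pathlib-based counterpart of tmpdir that also appears in the --fixtures listing. The sketch below shows a typical file-system test; save_report is a hypothetical function under test, introduced only for this example.

```python
from pathlib import Path


def save_report(directory: Path, text: str) -> Path:
    """Hypothetical function under test: writes text to report.txt."""
    target = directory / 'report.txt'
    target.write_text(text, encoding='utf-8')
    return target


def test_save_report_creates_file(tmp_path):
    # tmp_path is a pathlib.Path pointing at a fresh, empty directory
    target = save_report(tmp_path, 'hello')
    assert target.read_text(encoding='utf-8') == 'hello'
    assert [p.name for p in tmp_path.iterdir()] == ['report.txt']
```

Because every test receives its own directory, the test does not need to clean up the file or worry about leftovers from other tests.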

Listing All Fixtures

Pytest allows you to view all available fixtures using the --fixtures option.

! pytest --fixtures
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 0 items                                                              
cache -- .venv/lib/python3.12/site-packages/_pytest/cacheprovider.py:560
    Return a cache object that can persist state between testing sessions.

capsysbinary -- .venv/lib/python3.12/site-packages/_pytest/capture.py:1003
    Enable bytes capturing of writes to ``sys.stdout`` and ``sys.stderr``.

capfd -- .venv/lib/python3.12/site-packages/_pytest/capture.py:1030
    Enable text capturing of writes to file descriptors ``1`` and ``2``.

capfdbinary -- .venv/lib/python3.12/site-packages/_pytest/capture.py:1057
    Enable bytes capturing of writes to file descriptors ``1`` and ``2``.

capsys -- .venv/lib/python3.12/site-packages/_pytest/capture.py:976
    Enable text capturing of writes to ``sys.stdout`` and ``sys.stderr``.

doctest_namespace [session scope] -- .venv/lib/python3.12/site-packages/_pytest/doctest.py:738
    Fixture that returns a :py:class:`dict` that will be injected into the
    namespace of doctests.

pytestconfig [session scope] -- .venv/lib/python3.12/site-packages/_pytest/fixtures.py:1338
    Session-scoped fixture that returns the session's :class:`pytest.Config`
    object.

record_property -- .venv/lib/python3.12/site-packages/_pytest/junitxml.py:284
    Add extra properties to the calling test.

record_xml_attribute -- .venv/lib/python3.12/site-packages/_pytest/junitxml.py:307
    Add extra xml attributes to the tag for the calling test.

record_testsuite_property [session scope] -- .venv/lib/python3.12/site-packages/_pytest/junitxml.py:345
    Record a new ``<property>`` tag as child of the root ``<testsuite>``.

tmpdir_factory [session scope] -- .venv/lib/python3.12/site-packages/_pytest/legacypath.py:303
    Return a :class:`pytest.TempdirFactory` instance for the test session.

tmpdir -- .venv/lib/python3.12/site-packages/_pytest/legacypath.py:310
    Return a temporary directory path object which is unique to each test
    function invocation, created as a sub directory of the base temporary
    directory.

caplog -- .venv/lib/python3.12/site-packages/_pytest/logging.py:602
    Access and control log capturing.

monkeypatch -- .venv/lib/python3.12/site-packages/_pytest/monkeypatch.py:33
    A convenient fixture for monkey-patching.

recwarn -- .venv/lib/python3.12/site-packages/_pytest/recwarn.py:32
    Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.

tmp_path_factory [session scope] -- .venv/lib/python3.12/site-packages/_pytest/tmpdir.py:242
    Return a :class:`pytest.TempPathFactory` instance for the test session.

tmp_path -- .venv/lib/python3.12/site-packages/_pytest/tmpdir.py:257
    Return a temporary directory path object which is unique to each test
    function invocation, created as a sub directory of the base temporary
    directory.


------------------ fixtures defined from anyio.pytest_plugin -------------------
anyio_backend [module scope] -- .venv/lib/python3.12/site-packages/anyio/pytest_plugin.py:132
    no docstring available

anyio_backend_name -- .venv/lib/python3.12/site-packages/anyio/pytest_plugin.py:137
    no docstring available

anyio_backend_options -- .venv/lib/python3.12/site-packages/anyio/pytest_plugin.py:145
    no docstring available


------------------- fixtures defined from pytest_cov.plugin --------------------
no_cover -- .venv/lib/python3.12/site-packages/pytest_cov/plugin.py:429
    A pytest fixture to disable coverage.

cov -- .venv/lib/python3.12/site-packages/pytest_cov/plugin.py:434
    A pytest fixture to provide access to the underlying coverage object.


============================ no tests ran in 0.07s =============================

Common Built-in Fixtures:

  • capsys: Captures output written to sys.stdout and sys.stderr.
  • monkeypatch: Allows you to modify or replace code for testing.
  • tmpdir: Provides a unique temporary directory for each test.
  • pytestconfig: Gives access to the Pytest configuration object.

Benefits:

  • Helps discover useful fixtures provided by Pytest and third-party plugins.
  • Saves time by reusing existing functionality.
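As an illustration of one fixture from the list, the sketch below uses monkeypatch to control an environment variable for the duration of a test. The function get_api_url, the variable name API_URL, and both URLs are hypothetical and exist only for this example.

```python
import os


def get_api_url():
    """Hypothetical function under test: reads its setting from the environment."""
    return os.environ.get('API_URL', 'http://localhost')


def test_get_api_url_reads_environment(monkeypatch):
    # setenv is undone automatically when the test finishes
    monkeypatch.setenv('API_URL', 'http://example.test')
    assert get_api_url() == 'http://example.test'


def test_get_api_url_falls_back_to_default(monkeypatch):
    # raising=False: do not fail if the variable is not set to begin with
    monkeypatch.delenv('API_URL', raising=False)
    assert get_api_url() == 'http://localhost'
```

Since monkeypatch restores the previous state after each test, the two tests do not influence each other or the rest of the session.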

Using capsys

The capsys fixture allows you to capture text written to sys.stdout and sys.stderr during a test.

%%writefile test_capsys.py
def test_using_capsys(capsys):
    print('asdf')
    out, err = capsys.readouterr()
    print('out', out)
    assert out == 'asdf\n'
Overwriting test_capsys.py
! pytest test_capsys.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

test_capsys.py .                                                         [100%]

============================== 1 passed in 0.07s ===============================

Key Points:

  • Use capsys.readouterr() to capture and read the output.
  • Separate captured stdout (out) and stderr (err).
  • Useful for testing CLI tools or functions that print output.

Implementing Fixtures

You can define your own fixtures using the @pytest.fixture decorator. Fixtures encapsulate setup logic, allowing you to create reusable components for test preparation. Pytest will automatically execute fixtures before the test functions that request them.

Steps to Create Your Own Fixture

  1. Import pytest:
    • Ensure you have pytest imported in your test file.
  2. Define a Function with the Setup Logic:
    • Create a function that contains the necessary setup steps for your tests.
  3. Annotate with @pytest.fixture:
    • Use the @pytest.fixture decorator to mark the function as a fixture.
  4. Return the Required Object:
    • The function should return the object or resource that will be passed to the test.
%%writefile pytest_implementing_fixtures.py
from time import sleep
import pytest

@pytest.fixture
def empty_list():
    print('preparing database')
    sleep(2)
    return []

def test_a(empty_list):
    print(empty_list)
    empty_list.append(2)
    print(empty_list)

def test_b(empty_list):
    print(empty_list)
    empty_list.append(2)
    print(empty_list)
Overwriting pytest_implementing_fixtures.py
! pytest pytest_implementing_fixtures.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

pytest_implementing_fixtures.py::test_a preparing database
[]
[2]
PASSED
pytest_implementing_fixtures.py::test_b preparing database
[]
[2]
PASSED

============================== 2 passed in 4.07s ===============================

Key Points:

  • Fixtures are instantiated before the test runs.
  • The return value of the fixture is passed to the test function as an argument.

Sharing Fixture Instances

You can control how often a fixture is created by setting its scope. The scope determines the lifespan of a fixture and how many times it is invoked during a test session. Pytest provides five built-in scopes (function, class, module, package, and session); the ones used most often are described below:

Fixture Scopes

  1. function (Default)
    • Definition: A new instance of the fixture is created for each test function that uses it.
    • Use Case: Ideal for tests that require an isolated, fresh setup for every test.
    • Example:
%%writefile pytest_fixture_scope_function.py

import pytest

@pytest.fixture(scope='function')
def resource():
    print("Setting up resource for each test")
    return "function_resource"

def test_a(resource):
    assert resource == "function_resource"
    resource = 'something else'  # rebinds the local name only

def test_b(resource):
    assert resource == "function_resource"
    resource = 'something else'  # rebinds the local name only
Overwriting pytest_fixture_scope_function.py
! pytest pytest_fixture_scope_function.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

pytest_fixture_scope_function.py::test_a Setting up resource for each test
PASSED
pytest_fixture_scope_function.py::test_b Setting up resource for each test
PASSED

============================== 2 passed in 0.08s ===============================
  2. class
    • Definition: A single instance of the fixture is created and shared among all tests in a class.
    • Use Case: Useful for tests within a class that share the same setup but need isolation from other classes.
    • Example:
%%writefile pytest_fixture_scope_class.py

import pytest

@pytest.fixture(scope='class')
def resource():
    print("Setting up resource for the class")
    return 'class_resource'

class TestExample:
    def test_a(self, resource):
        print(id(resource))
        assert resource == 'class_resource'
        resource = 'dddddd'  # rebinds only the local name; the shared fixture value is unchanged


    def test_b(self, resource):
        assert resource == 'class_resource'
Overwriting pytest_fixture_scope_class.py
! pytest pytest_fixture_scope_class.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

pytest_fixture_scope_class.py::TestExample::test_a Setting up resource for the class
4365607152
PASSED
pytest_fixture_scope_class.py::TestExample::test_b PASSED

============================== 2 passed in 0.07s ===============================
%%writefile pytest_fixture_scope_class1.py

import pytest

@pytest.fixture(scope='class')
def resource():
    print("Setting up resource for the class")
    return {"class_resource": "class_resource"}

class TestExample1:
    def test_a(self, resource):
        assert resource == {"class_resource": "class_resource"}
        # resource.update(new_key="new_value")

    def test_b(self, resource):
        assert resource == {"class_resource": "class_resource"}
        resource.update(new_key="new_value")

class TestExample2:
    def test_a(self, resource):
        assert resource == {"class_resource": "class_resource"}
Overwriting pytest_fixture_scope_class1.py
! pytest pytest_fixture_scope_class1.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

pytest_fixture_scope_class1.py::TestExample1::test_a Setting up resource for the class
PASSED
pytest_fixture_scope_class1.py::TestExample1::test_b PASSED
pytest_fixture_scope_class1.py::TestExample2::test_a Setting up resource for the class
PASSED

============================== 3 passed in 0.09s ===============================
  3. module
    • Definition: A single instance of the fixture is created and shared across all tests in a module.
    • Use Case: Suitable for module-level resources that are expensive to set up but can be shared safely among tests.
    • Example:
%%writefile pytest_fixture_scope_module.py

import pytest

# the fixture could also live in a separate file such as conftest.py
@pytest.fixture(scope='module')
def resource():
    print("Setting up resource for the module")
    return "module_resource"

def test_a(resource):
    assert resource == "module_resource"

def test_b(resource):
    assert resource == "module_resource"
Overwriting pytest_fixture_scope_module.py
! pytest pytest_fixture_scope_module.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

pytest_fixture_scope_module.py::test_a Setting up resource for the module
PASSED
pytest_fixture_scope_module.py::test_b PASSED

============================== 2 passed in 0.08s ===============================
  4. session
    • Definition: A single instance of the fixture is created and shared across the entire test session, spanning multiple modules.
    • Use Case: Best for global resources that need to be initialized once and reused across all tests (e.g., database connections, external services).
    • Example:
%%writefile conftest.py

import pytest
@pytest.fixture(scope='session')
def resource():
    print("Setting up resource for the session")
    return "session_resource"
Overwriting conftest.py
%%writefile test_fixture_scope_session1.py
import pytest

def test_a(resource):
    assert resource == "session_resource"

def test_b(resource):
    assert resource == "session_resource"
Overwriting test_fixture_scope_session1.py
%%writefile test_fixture_scope_session2.py

import pytest

def test_c(resource):
    assert resource == "session_resource"

def test_d(resource):
    assert resource == "session_resource"
Overwriting test_fixture_scope_session2.py
! pytest ./ -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 5 items                                                              

test_capsys.py::test_using_capsys out asdf

PASSED
test_fixture_scope_session1.py::test_a Setting up resource for the session
PASSED
test_fixture_scope_session1.py::test_b PASSED
test_fixture_scope_session2.py::test_c PASSED
test_fixture_scope_session2.py::test_d PASSED

============================== 5 passed in 0.09s ===============================
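
As a rough mental model of the four scopes shown above, the number of times a fixture's setup runs equals the number of distinct function, class, module, or session "slots" among the tests that request it. The sketch below is a plain-Python illustration, not pytest's implementation. `count_setups` and the test tuples are hypothetical names, and the model assumes every test lives in a class and that tests run grouped by module and class.

```python
def count_setups(scope, tests):
    """Count how often a fixture of the given scope would be set up.

    tests is a list of (module, class_name, test_name) tuples.
    """
    # Each scope keeps one fixture instance per distinct key.
    key_for_scope = {
        "function": lambda module, cls, test: (module, cls, test),
        "class": lambda module, cls, test: (module, cls),
        "module": lambda module, cls, test: (module,),
        "session": lambda module, cls, test: (),
    }
    make_key = key_for_scope[scope]
    return len({make_key(*test) for test in tests})


# Hypothetical test layout: two modules, three classes, four tests.
tests = [
    ("test_orders.py", "TestCreate", "test_a"),
    ("test_orders.py", "TestCreate", "test_b"),
    ("test_orders.py", "TestCancel", "test_a"),
    ("test_users.py", "TestLogin", "test_a"),
]

for scope in ("function", "class", "module", "session"):
    print(scope, count_setups(scope, tests))
```

The wider the scope, the fewer setups run, which is why expensive resources are usually given a wider scope and cheap, mutable ones a narrower scope.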

Practical Example: Sharing Fixture Instances

Exercise: 🏠 DB

Tip

A fixture can use another fixture.

### Too slow!!!
create_database()
test_a()

create_database()
test_b()

create_database()
test_c()
### Optimal
create_database()
reset_database()
test_a()

reset_database()
test_b()

reset_database()
test_c()

Code:

# %%writefile pytest_db.py
import pytest

from time import sleep


# Do not modify this code!
def create_database():
    print("create_database")
    sleep(2)
    return []


def reset_database(db):
    print("reset_database")
    del db[:]


# Your fixtures

....

# Your three tests: test_one, test_two, and test_three
def test_one(db):
    print('test_one')
    db.append(2)
    assert db == [2]

def test_two(db):
    print('test_two')
    assert db == []

def test_three(db):
    print('test_three')
    db.append(5)
    assert db == [5]
%%writefile pytest_db.py
import pytest

from time import sleep


# Do not modify this code!
def create_database():
    print("create_database")
    sleep(2)
    return []


def reset_database(db):
    print("reset_database")
    del db[:]


# Your fixtures
@pytest.fixture(scope='session')
def shared_db():
    return create_database()

@pytest.fixture
def db(shared_db):
    reset_database(shared_db)
    return shared_db

# Your three tests: test_one, test_two, and test_three
def test_one(db):
    print('test_one')
    db.append(2)
    assert db == [2]

def test_two(db):
    print('test_two')
    assert db == []

def test_three(db):
    print('test_three')
    db.append(5)
    assert db == [5]
Writing pytest_db.py
! pytest pytest_db.py -qsvv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

pytest_db.py::test_one create_database
reset_database
test_one
PASSED
pytest_db.py::test_two reset_database
test_two
PASSED
pytest_db.py::test_three reset_database
test_three
PASSED

============================== 3 passed in 2.04s ===============================

Fixture Finalization

%%writefile pytest_fixture_finalization.py

import pytest

@pytest.fixture
def csv_file_stream():
    stream = open("people.csv")
    print('stream opened')
    yield stream
    # no need for try-finally
    stream.close()
    print('stream closed')

def test_people(csv_file_stream):
    print('running test')
    assert False

# def test_people1(csv_file_stream):
#     print('running test')
#     assert False
Overwriting pytest_fixture_finalization.py
! pytest pytest_fixture_finalization.py -q
F                                                                        [100%]
=================================== FAILURES ===================================
_________________________________ test_people __________________________________

csv_file_stream = <_io.TextIOWrapper name='people.csv' mode='r' encoding='UTF-8'>

    def test_people(csv_file_stream):
        print('running test')
>       assert False
E       assert False

pytest_fixture_finalization.py:15: AssertionError
---------------------------- Captured stdout setup -----------------------------
stream opened
----------------------------- Captured stdout call -----------------------------
running test
--------------------------- Captured stdout teardown ---------------------------
stream closed
=========================== short test summary info ============================
FAILED pytest_fixture_finalization.py::test_people - assert False
1 failed in 0.38s
%%writefile pytest_fixture_finalization_2.py

import pytest

@pytest.fixture
def a(b):
    print('setup a')
    yield 42
    print('teardown a')

@pytest.fixture
def b():
    print('setup b')
    yield 22
    print('teardown b')

@pytest.fixture
def c():
    print('setup c')
    yield 22
    print('teardown c')

def test_people(a):
    print('running test')
    assert False
Overwriting pytest_fixture_finalization_2.py
! pytest pytest_fixture_finalization_2.py -q
F                                                                        [100%]
=================================== FAILURES ===================================
_________________________________ test_people __________________________________

a = 42

    def test_people(a):
        print('running test')
>       assert False
E       assert False

pytest_fixture_finalization_2.py:24: AssertionError
---------------------------- Captured stdout setup -----------------------------
setup b
setup a
----------------------------- Captured stdout call -----------------------------
running test
--------------------------- Captured stdout teardown ---------------------------
teardown a
teardown b
=========================== short test summary info ============================
FAILED pytest_fixture_finalization_2.py::test_people - assert False
1 failed in 0.50s
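
The output above shows teardown running in reverse order of setup: `b` is set up first because `a` depends on it, and it is torn down last. Fixture `c` never runs because nothing requests it. The same last-in, first-out ordering can be imitated in plain Python with `contextlib.ExitStack`. This is only an analogy for the ordering, not how pytest is implemented, and the `events` list is a hypothetical helper for recording what happened.

```python
from contextlib import ExitStack, contextmanager

events = []  # records the order of setup, test, and teardown steps


@contextmanager
def b():
    events.append("setup b")
    yield 22
    events.append("teardown b")


@contextmanager
def a(b_value):
    events.append("setup a")
    yield 42
    events.append("teardown a")


with ExitStack() as stack:
    b_value = stack.enter_context(b())         # the dependency is entered first
    a_value = stack.enter_context(a(b_value))  # then the fixture that needs it
    events.append("running test")
# Leaving the block unwinds the stack: a is closed before b.

print(events)
```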

Exercise: 🏠 Advanced Fixture Creation and Verification

You are building a test suite for an application that processes data using a shared resource. Your task is to:

  1. Implement a custom fixture that:

    • Sets up a ResourceManager object before tests run.
    • The ResourceManager should:
      • Keep track of how many times it has been initialized.
      • Provide an increment method to increase an internal counter.
      • Provide a value method to return the current counter value.
    • Implement teardown logic that resets the counter to zero after tests are complete.
  2. Write test cases to:

    • Verify that the fixture is created only once when using session scope.
    • Confirm that the internal counter is shared across tests when using session scope.
    • Ensure that the counter value is as expected after each test.
    • Test that the teardown logic properly resets the counter after all tests have run.
  3. Challenge:

    • Modify the fixture to use function scope and update the tests accordingly to reflect the change in behavior.
    • Ensure that with function scope, the fixture is created anew for each test, and the counter does not retain its value between tests.

Initial Code

# %%writefile test_resource_manager.py
import pytest

class ResourceManager:
    initialization_count = 0

    def __init__(self):
        type(self).initialization_count += 1
        self.counter = 0

    def increment(self):
        self.counter += 1

    def value(self):
        return self.counter

    def reset(self):
        self.counter = 0

# TODO: Implement the fixture
# @pytest.fixture(scope='session')
# def resource_manager():
#     pass

# TODO: Write the tests
# def test_increment_1(resource_manager):
#     pass

# def test_increment_2(resource_manager):
#     pass

# def test_initialization_count():
#     pass

# def test_teardown(resource_manager):
#     pass

Solution

%%writefile test_resource_manager.py
import pytest

class ResourceManager:
    initialization_count = 0

    def __init__(self):
        type(self).initialization_count += 1
        self.counter = 0

    def increment(self):
        self.counter += 1

    def value(self):
        return self.counter

    def reset(self):
        self.counter = 0

# Implement the fixture with 'session' scope
@pytest.fixture(scope='session')
def resource_manager():
    """
    Fixture that provides a ResourceManager instance.
    """
    rm = ResourceManager()
    yield rm
    # Teardown logic
    rm.reset()

def test_increment_1(resource_manager):
    """
    Test that increments the counter and checks its value.
    """
    resource_manager.increment()
    assert resource_manager.value() == 1, "Counter should be 1 after first increment"

def test_increment_2(resource_manager):
    """
    Another test that increments the counter and checks its value.
    """
    resource_manager.increment()
    assert resource_manager.value() == 2, "Counter should be 2 after second increment"

def test_initialization_count():
    """
    Test that the ResourceManager was initialized only once.
    """
    assert ResourceManager.initialization_count == 1, f"Initialization count should be 1, got {ResourceManager.initialization_count}"

def test_teardown(resource_manager):
    """
    Test that the counter still holds 2 before the session teardown resets it.
    """
    # Teardown occurs after the last test using the fixture
    # Since we're in the last test, the counter should still be 2
    assert resource_manager.value() == 2, "Counter should be 2 before teardown"
    # After the test, the teardown will reset the counter
Writing test_resource_manager.py
! pytest test_resource_manager.py -qsvv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items                                                              

test_resource_manager.py::test_increment_1 PASSED
test_resource_manager.py::test_increment_2 PASSED
test_resource_manager.py::test_initialization_count PASSED
test_resource_manager.py::test_teardown PASSED

============================== 4 passed in 0.10s ===============================
%%writefile test_resource_manager.py
import pytest

class ResourceManager:
    initialization_count = 0

    def __init__(self):
        type(self).initialization_count += 1
        self.counter = 0

    def increment(self):
        self.counter += 1

    def value(self):
        return self.counter

    def reset(self):
        self.counter = 0

# Challenge: Modify the fixture to 'function' scope
@pytest.fixture(scope='function')
def resource_manager_function():
    """
    Fixture that provides a new ResourceManager instance for each test.
    """
    rm = ResourceManager()
    yield rm
    # Teardown logic
    rm.reset()

def test_increment_function_1(resource_manager_function):
    """
    Test that increments the counter and checks its value with function scope.
    """
    resource_manager_function.increment()
    assert resource_manager_function.value() == 1, "Counter should be 1 after increment in function-scoped fixture"

def test_increment_function_2(resource_manager_function):
    """
    Another test that increments the counter and checks its value with function scope.
    """
    resource_manager_function.increment()
    assert resource_manager_function.value() == 1, "Counter should be 1 in a new instance of function-scoped fixture"

def test_initialization_count_function():
    """
    Test that the ResourceManager was initialized multiple times.
    """
    # Expect 2 initializations: one new instance for each of the two function-scoped tests above
    assert ResourceManager.initialization_count == 2, f"Initialization count should be 2, got {ResourceManager.initialization_count}"

def test_teardown_function(resource_manager_function):
    """
    Test that verifies the teardown logic reset the counter in function scope.
    """
    resource_manager_function.increment()
    assert resource_manager_function.value() == 1, "Counter should be 1 before teardown in function-scoped fixture"
    # After the test, the teardown will reset the counter
Overwriting test_resource_manager.py
! pytest test_resource_manager.py -qsvv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items                                                              

test_resource_manager.py::test_increment_1 PASSED
test_resource_manager.py::test_increment_2 PASSED
test_resource_manager.py::test_initialization_count PASSED
test_resource_manager.py::test_teardown PASSED

============================== 4 passed in 0.15s ===============================
%%writefile pytest_fixture_finalization.py

import pytest

@pytest.fixture
def csv_file_stream():
#     stream = open("people.csv")
#     print('stream opened')
#     yield stream
#     # no need for try-finally
#     stream.close()
#     print('stream closed')

    # or simpler:
    with open("people.csv") as stream:
        print("stream opened")
        yield stream
        # implicit stream.close()
    print("stream closed")

def test_people(csv_file_stream):
    print('running test')
    assert False
Overwriting pytest_fixture_finalization.py
! pytest pytest_fixture_finalization.py -q
F                                                                        [100%]
=================================== FAILURES ===================================
_________________________________ test_people __________________________________

csv_file_stream = <_io.TextIOWrapper name='people.csv' mode='r' encoding='UTF-8'>

    def test_people(csv_file_stream):
        print('running test')
>       assert False
E       assert False

pytest_fixture_finalization.py:22: AssertionError
---------------------------- Captured stdout setup -----------------------------
stream opened
----------------------------- Captured stdout call -----------------------------
running test
--------------------------- Captured stdout teardown ---------------------------
stream closed
=========================== short test summary info ============================
FAILED pytest_fixture_finalization.py::test_people - assert False
1 failed in 0.53s

Parametrizing Fixtures

Parametrizing fixtures in Pytest allows you to run the same test function multiple times with different inputs. This is especially useful when you want to test the same functionality with various configurations or data sets without writing separate tests for each case.

  • @pytest.fixture(params=[...]): The params argument in the @pytest.fixture decorator specifies a list of parameter values that the fixture will be called with.
  • request.param: Within the fixture function, Pytest provides a special request object, which has a param attribute. This attribute holds the current parameter value for each invocation of the fixture.

For each parameter value specified in params, Pytest:

  1. Calls the fixture function with request.param set to that value.
  2. Executes all tests that use this fixture once for each parameter value.
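
The two steps above can be sketched in plain Python. This is a simplified model of the behavior, not pytest's actual machinery: `run_parametrized`, `number_fixture`, and `check_positive` are hypothetical names, and `SimpleNamespace` stands in for the real `request` object.

```python
from types import SimpleNamespace

events = []  # records what happened, in order


def number_fixture(request):
    # Generator-style fixture: setup, yield the value, then teardown.
    events.append(f"setup {request.param}")
    yield request.param * 10
    events.append(f"teardown {request.param}")


def check_positive(value):
    events.append(f"test got {value}")
    assert value > 0


def run_parametrized(fixture_func, params, test_func):
    for param in params:
        request = SimpleNamespace(param=param)  # stand-in for pytest's request
        gen = fixture_func(request)
        value = next(gen)                       # run setup up to the yield
        try:
            test_func(value)                    # one test run per parameter
        finally:
            next(gen, None)                     # resume past the yield: teardown


run_parametrized(number_fixture, [1, 2], check_positive)
print(events)
```

Each parameter gets its own setup, test run, and teardown, which matches the two separate runs reported for the CSV example in this section.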
%%writefile people1.csv
first_name,last_name
John,Smith
Alice,Wilson
Overwriting people1.csv
%%writefile people2.csv
first_name,last_name
Overwriting people2.csv
%%writefile pytest_parametrized_fixtures.py
import pytest

@pytest.fixture(params=['people1.csv', 'people2.csv'])
def people_csv_stream(request):
    print(f'opening {request.param}')
    with open(request.param) as stream:
        yield stream

def test_people(people_csv_stream):
    print('test_people')
    print(people_csv_stream.read())
Overwriting pytest_parametrized_fixtures.py
! pytest pytest_parametrized_fixtures.py -qs
opening people1.csv
test_people
first_name,last_name
John,Smith
Alice,Wilson

.opening people2.csv
test_people
first_name,last_name

.
2 passed in 0.07s

Factory Fixtures

In the previous section, we explored how to use parametrized fixtures to run the same test function multiple times with different inputs. While parametrized fixtures are powerful, they have limitations when you need to generate test data dynamically or when the inputs cannot be predetermined. This is where factory fixtures come into play.

Factory fixtures allow you to pass arguments to a fixture at test time, giving you greater flexibility in setting up your test data. They are particularly useful when you need to create multiple test data instances with varying attributes within a single test function.

  • Factory Fixture: A fixture that returns a function (the factory) which can accept arguments to create test data dynamically.
  • Why Use Factory Fixtures?
    • Dynamic Data Creation: When test data cannot be predefined and needs to be generated on the fly.
    • Flexibility: Allows tests to specify exactly what data they need.
    • Reusability: Centralizes the data creation logic, making it reusable across different tests.
%%writefile pytest_factory_fixture.py
import pytest

class Customer:
    def __init__(self, first_name, last_name, email, **kwargs):
        self.first_name = first_name
        self.last_name = last_name
        self.email = email
        for key, value in kwargs.items():
            setattr(self, key, value)

# Factory fixture that creates Customer instances
@pytest.fixture
def make_customer():
    def _make_customer(
        first_name="John",
        last_name="Doe",
        email="john.doe@example.com",
        **extra_attrs
    ):
        customer = Customer(
            first_name=first_name,
            last_name=last_name,
            email=email,
            **extra_attrs
        )
        return customer
    return _make_customer

def test_create_default_customer(make_customer):
    customer = make_customer()
    assert customer.first_name == "John"
    assert customer.last_name == "Doe"
    assert customer.email == "john.doe@example.com"

def test_create_custom_customer(make_customer):
    customer = make_customer(first_name="Alice", last_name="Smith", email="alice.smith@example.com")
    assert customer.first_name == "Alice"
    assert customer.last_name == "Smith"
    assert customer.email == "alice.smith@example.com"

def test_create_customer_with_extra_attrs(make_customer):
    customer = make_customer(age=30, country="USA")
    assert customer.age == 30
    assert customer.country == "USA"
Overwriting pytest_factory_fixture.py
! pytest pytest_factory_fixture.py -qs
...
3 passed in 0.07s
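
Because the factory function is created inside the fixture, the fixture can also remember what the factory produced and clean it up afterwards. The sketch below shows that idea with `contextlib.contextmanager` as a stand-in so that it runs without pytest; under pytest, the body of `customer_factory` would be written as a `yield` fixture instead. The `active` flag and the names `customer_factory` and `created` are assumptions made for illustration.

```python
from contextlib import contextmanager


class Customer:
    def __init__(self, first_name, email):
        self.first_name = first_name
        self.email = email
        self.active = True


@contextmanager
def customer_factory():
    created = []  # every object the factory hands out

    def _make_customer(first_name="John", email="john.doe@example.com"):
        customer = Customer(first_name, email)
        created.append(customer)
        return customer

    try:
        yield _make_customer
    finally:
        # Teardown: deactivate everything created during the "test".
        for customer in created:
            customer.active = False


with customer_factory() as make_customer:
    john = make_customer()
    alice = make_customer(first_name="Alice", email="alice@example.com")
    active_during_test = (john.active, alice.active)

print(active_during_test, (john.active, alice.active))
```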

Composing Fixtures

In Pytest, fixtures are a powerful way to manage test setup and teardown. One of the key strengths of fixtures is their ability to depend on other fixtures. This allows you to compose complex test scenarios by building upon simpler, reusable components. By composing fixtures, you can avoid redundant code and create more maintainable and scalable test suites.
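
One way to picture this dependency mechanism is as a lookup by parameter name: before calling a fixture, each of its parameters is resolved first, and each result is reused once created. The sketch below models that idea in plain Python with `inspect.signature`. It is an illustration only, not pytest's implementation, and `registry`, `resolve`, and the three sample functions are hypothetical.

```python
import inspect

setup_order = []  # records which "fixture" ran when


def customer():
    setup_order.append("customer")
    return {"name": "John"}


def sale(customer):
    setup_order.append("sale")
    return {"amount": 100.0, "customer": customer}


def transaction(sale, customer):
    setup_order.append("transaction")
    return {"sale": sale, "customer": customer}


registry = {func.__name__: func for func in (customer, sale, transaction)}


def resolve(name, cache):
    """Build the named value, resolving its parameters by name first."""
    if name not in cache:
        func = registry[name]
        params = inspect.signature(func).parameters
        cache[name] = func(*(resolve(param, cache) for param in params))
    return cache[name]


result = resolve("transaction", {})
print(setup_order)
```

Here `customer` is built once and shared by both `sale` and `transaction`, even though two functions ask for it.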

Example: Composing Fixtures in an E-commerce Application

Suppose you're testing an e-commerce application with classes Customer, Sale, and Transaction. These classes build on one another:

  • A Transaction involves a Sale and a Customer.
  • A Sale is made by a Customer.

Instead of creating a large fixture that sets up everything at once, you can create individual fixtures for each component and compose them.

%%writefile pytest_composing_fixture.py
import pytest

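class Customer:
    def __init__(self, first_name, last_name, email, **kwargs):
        self.first_name = first_name
        self.last_name = last_name
        self.email = email
        # Accept the extra attributes that make_customer forwards
        for key, value in kwargs.items():
            setattr(self, key, value)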

class Sale:
    def __init__(self, amount, sku, customer):
        self.amount = amount
        self.sku = sku
        self.customer = customer

class Transaction:
    def __init__(self, transaction_id, sale, customer):
        self.transaction_id = transaction_id
        self.sale = sale
        self.customer = customer

@pytest.fixture
def make_customer():
    def _make_customer(
        first_name="John",
        last_name="Doe",
        email="john.doe@example.com",
        **extra_attrs
    ):
        customer = Customer(
            first_name=first_name,
            last_name=last_name,
            email=email,
            **extra_attrs
        )
        return customer
    return _make_customer

@pytest.fixture
def make_sale(make_customer):
    def _make_sale(amount=100.0, sku="ABC123", customer=None):
        if customer is None:
            customer = make_customer()
        sale = Sale(amount=amount, sku=sku, customer=customer)
        return sale
    return _make_sale

@pytest.fixture
def make_transaction(make_sale, make_customer):
    def _make_transaction(transaction_id, sale=None, customer=None):
        if sale is None:
            sale = make_sale(customer=customer)
        if customer is None:
            customer = sale.customer
        transaction = Transaction(
            transaction_id=transaction_id,
            sale=sale,
            customer=customer,
        )
        return transaction
    return _make_transaction

def test_transaction_creation(make_transaction):
    transaction = make_transaction(transaction_id="TXN1001")
    assert transaction.transaction_id == "TXN1001"
    assert transaction.sale.amount == 100.0
    assert transaction.customer.first_name == "John"

def test_transaction_with_custom_customer(make_transaction, make_customer):
    custom_customer = make_customer(first_name="Alice", last_name="Smith")
    transaction = make_transaction(transaction_id="TXN1002", customer=custom_customer)
    assert transaction.customer.first_name == "Alice"
    assert transaction.customer.last_name == "Smith"

Exercise: 🏠 Advanced Factory and Composed Fixtures

Description

You are tasked with testing a complex Library Management System. The system consists of several interconnected classes:

  • Book: Represents a book with attributes such as title, author, isbn, and availability status.
  • Member: Represents a library member with attributes like name, member_id, and a list of borrowed_books.
  • Library: Manages collections of books and members, and provides methods for borrowing and returning books.

Your objectives are:

  1. Create factory fixtures for Book and Member that allow dynamic creation of instances with customizable attributes.
  2. Compose fixtures to create a Library instance that depends on the Book and Member fixtures.
  3. Write tests that:
    • Verify that a member can borrow a book successfully.
    • Ensure that a book's availability status updates correctly when borrowed and returned.
    • Confirm that a member cannot borrow more books than the allowed limit.
  4. Implement fixture scopes appropriately to optimize test performance and ensure proper isolation between tests.
  5. Challenge: Modify the fixtures to handle parameterization, allowing tests to run with different configurations (e.g., varying the maximum number of books a member can borrow).

Initial Code

# %%writefile test_library_system.py
import pytest

class Book:
    def __init__(self, title, author, isbn):
        self.title = title
        self.author = author
        self.isbn = isbn
        self.is_available = True

class Member:
    def __init__(self, name, member_id, max_books=3):
        self.name = name
        self.member_id = member_id
        self.borrowed_books = []
        self.max_books = max_books

class Library:
    def __init__(self):
        self.books = []
        self.members = []

    def add_book(self, book):
        # Add book to the library collection
        pass  # TODO: Implement this method

    def register_member(self, member):
        # Register a new library member
        pass  # TODO: Implement this method

    def borrow_book(self, member_id, isbn):
        # Member borrows a book
        pass  # TODO: Implement this method

    def return_book(self, member_id, isbn):
        # Member returns a book
        pass  # TODO: Implement this method

# TODO: Implement factory fixtures for Book and Member
# @pytest.fixture
# def make_book():
#     pass

# @pytest.fixture
# def make_member():
#     pass

# TODO: Implement a composed fixture for Library
# @pytest.fixture
# def library(make_book, make_member):
#     pass

# TODO: Write tests using the fixtures
# def test_member_can_borrow_book(library, make_book, make_member):
#     pass

# def test_book_availability_updates(library, make_book, make_member):
#     pass

# def test_member_borrow_limit(library, make_book, make_member):
#     pass

Solution

%%writefile test_library_system.py
import pytest

class Book:
    def __init__(self, title, author, isbn):
        self.title = title
        self.author = author
        self.isbn = isbn
        self.is_available = True

class Member:
    def __init__(self, name, member_id, max_books=3):
        self.name = name
        self.member_id = member_id
        self.borrowed_books = []
        self.max_books = max_books

class Library:
    def __init__(self):
        self.books = {}
        self.members = {}

    def add_book(self, book):
        self.books[book.isbn] = book

    def register_member(self, member):
        self.members[member.member_id] = member

    def borrow_book(self, member_id, isbn):
        member = self.members.get(member_id)
        book = self.books.get(isbn)
        if not member or not book:
            return False # or raise exception
        if not book.is_available:
            return False # or raise exception
        if len(member.borrowed_books) >= member.max_books:
            return False # or raise exception

        book.is_available = False
        member.borrowed_books.append(book)
        return True

    def return_book(self, member_id, isbn):
        member = self.members.get(member_id)
        book = self.books.get(isbn)
        if not member or not book:
            return False # or raise exception
        if book not in member.borrowed_books:
            return False # or raise exception

        book.is_available = True
        member.borrowed_books.remove(book)
        return True


# Factory fixture for Book
@pytest.fixture(scope='session')
def make_book():
    def _make_book(title="Default Title", author="Default Author", isbn=None):
        if isbn is None:
            import uuid
            isbn = str(uuid.uuid4())
        return Book(title=title, author=author, isbn=isbn)
    return _make_book

# Factory fixture for Member
@pytest.fixture(scope='session')
def make_member():
    def _make_member(name="Default Name", member_id=None, max_books=3):
        if member_id is None:
            import uuid
            member_id = str(uuid.uuid4())
        return Member(name=name, member_id=member_id, max_books=max_books)
    return _make_member


# Composed fixture for Library
@pytest.fixture
def library(make_book, make_member):
    lib = Library()
    # Add some default books and members
    book1 = make_book(title="Book One", author="Author A")
    book2 = make_book(title="Book Two", author="Author B")
    lib.add_book(book1)
    lib.add_book(book2)

    member1 = make_member(name="Member One")
    member2 = make_member(name="Member Two")
    lib.register_member(member1)
    lib.register_member(member2)
    return lib


def test_member_can_borrow_book(library, make_book, make_member):
    member = make_member(name="Test Member")
    library.register_member(member)
    book = make_book(title="Test Book")
    library.add_book(book)
    assert library.borrow_book(member.member_id, book.isbn) == True  # this assertion is optional given the scope of the test
    assert book in member.borrowed_books
    assert not book.is_available


def test_book_availability_updates(library, make_book, make_member):
    member = make_member(name="Test Member")
    library.register_member(member)
    book = make_book(title="Test Book")
    library.add_book(book)
    # Borrow the book
    library.borrow_book(member.member_id, book.isbn)
    assert not book.is_available
    # Return the book
    library.return_book(member.member_id, book.isbn)
    assert book.is_available
    assert book not in member.borrowed_books


def test_member_borrow_limit(library, make_book, make_member):
    member = make_member(name="Test Member", max_books=2)
    library.register_member(member)
    books = [make_book(title=f"Book {i}") for i in range(3)]
    for book in books:
        library.add_book(book)
    # Borrow first book
    assert library.borrow_book(member.member_id, books[0].isbn) == True
    # Borrow second book
    assert library.borrow_book(member.member_id, books[1].isbn) == True
    # Attempt to borrow third book should fail
    assert library.borrow_book(member.member_id, books[2].isbn) == False
    assert len(member.borrowed_books) == 2


# Challenge: Parameterize the max_books limit
@pytest.fixture(params=[1, 2, 3])
def member_with_limit(make_member, request):
    return make_member(max_books=request.param)


def test_member_borrow_limit_parametrized(library, make_book, member_with_limit):
    """
    library -> object
    make_book -> factory function
    member_with_limit -> object
    """
    library.register_member(member_with_limit)
    books = [make_book(title=f"Book {i}") for i in range(5)]
    for book in books:
        library.add_book(book)
    borrow_count = 0
    for book in books:
        success = library.borrow_book(member_with_limit.member_id, book.isbn)
        if success:
            borrow_count += 1
        else:
            break
    assert borrow_count == member_with_limit.max_books
    assert len(member_with_limit.borrowed_books) == member_with_limit.max_books
Overwriting test_library_system.py
! pytest test_library_system.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 6 items                                                              

test_library_system.py::test_member_can_borrow_book PASSED
test_library_system.py::test_book_availability_updates PASSED
test_library_system.py::test_member_borrow_limit PASSED
test_library_system.py::test_member_borrow_limit_parametrized[1] PASSED
test_library_system.py::test_member_borrow_limit_parametrized[2] PASSED
test_library_system.py::test_member_borrow_limit_parametrized[3] PASSED

============================== 6 passed in 0.08s ===============================

More on pytest

Grouping Tests

When writing tests, it’s important to organize them in a way that makes them easy to understand, maintain, and extend. Pytest supports grouping tests by encapsulating related tests into classes. This approach helps structure your test suite, especially for large projects, by logically grouping tests that share a common purpose or context.

%%writefile pytest_grouped_tests.py

def div(a, b):
    return a / b

def mul(a, b):
    return a * b

class TestDiv:
    def test_one(self):
        assert div(5, 2) == 2.5

    def test_two(self):
        assert div(4, 2) == 2

class TestMul:
    def test_one(self):
        assert mul(2, 2) == 4
Overwriting pytest_grouped_tests.py
! pytest pytest_grouped_tests.py -q
...                                                                      [100%]
3 passed in 0.03s

Skipping Tests

In some scenarios, you may want to conditionally skip certain tests or explicitly mark tests to be skipped. Pytest provides powerful decorators to skip tests based on conditions such as the Python version, operating system, or other custom logic. This feature is particularly useful when writing platform-specific tests or handling features that depend on external environments.

%%writefile pytest_skipping_tests.py
import sys

import pytest

@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")  # condition True -> skip, False -> run
def test_1():
    print("Test launched only on Python 3.6 or higher")

# sys.version_info is a 5-tuple, so (3, 6, x, ...) > (3, 6) is True even on 3.6; compare against (3, 7) instead
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="requires python3.6 or older")
def test_2():
    print("Test launched only on Python 3.6 and older")

@pytest.mark.skipif(sys.platform != 'win32', reason="requires windows")
def test_3():
    print("Testing win32 specific features")


only_win32 = pytest.mark.skipif(sys.platform != 'win32', reason="requires windows")

@only_win32
def test_4():
    print("Testing win32 specific features")

@only_win32
def test_5():
    print("Testing win32 specific features")

@pytest.mark.skip()
def test_6():
    print("Test never launched")


@pytest.mark.skipif(sys.platform != 'darwin', reason="requires macOS")
def test_7():
    print("Testing macOS specific features")
Overwriting pytest_skipping_tests.py
! pytest pytest_skipping_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 7 items                                                              

pytest_skipping_tests.py::test_1 PASSED                                  [ 14%]
pytest_skipping_tests.py::test_2 SKIPPED (requires python3.6 or older)   [ 28%]
pytest_skipping_tests.py::test_3 SKIPPED (requires windows)              [ 42%]
pytest_skipping_tests.py::test_4 SKIPPED (requires windows)              [ 57%]
pytest_skipping_tests.py::test_5 SKIPPED (requires windows)              [ 71%]
pytest_skipping_tests.py::test_6 SKIPPED (unconditional skip)            [ 85%]
pytest_skipping_tests.py::test_7 PASSED                                  [100%]

========================= 2 passed, 5 skipped in 0.10s =========================
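
The version conditions passed to skipif rely on ordinary tuple comparison, which is easy to get subtly wrong. The sketch below uses a hand-written tuple as a stand-in for what sys.version_info looks like on Python 3.6.0 (an assumption for illustration; the helper name is hypothetical) to show why "newer than 3.6" is safer written as `>= (3, 7)` than as `> (3, 6)`.

```python
import sys

# Stand-in for sys.version_info on Python 3.6.0 (illustrative value).
py_3_6_0 = (3, 6, 0, "final", 0)

# Tuples compare element by element; after an equal prefix the longer one is greater.
print(py_3_6_0 > (3, 6))       # True, although this *is* Python 3.6
print(py_3_6_0 >= (3, 7))      # False
print(py_3_6_0[:2] > (3, 6))   # False: comparing only (major, minor) also works


def is_newer_than_3_6(version_info=sys.version_info):
    """Hypothetical helper: True only for Python 3.7 and later."""
    return version_info >= (3, 7)
```

A condition built this way can be passed directly as the first argument of `pytest.mark.skipif`.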
 

Parametrized Tests

Parametrized tests in Pytest allow you to run the same test function multiple times with different sets of input data. This is a powerful feature for efficiently testing a wide range of input scenarios without duplicating code.

%%writefile pytest_parametrized_tests.py
import pytest

@pytest.mark.parametrize('number', [10, 20, 30])
@pytest.mark.parametrize('letter', ['a', 'b', 'c'])
def test_something(number, letter):
    print(number, letter)
Overwriting pytest_parametrized_tests.py
! pytest pytest_parametrized_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 9 items                                                              

pytest_parametrized_tests.py::test_something[a-10] PASSED                [ 11%]
pytest_parametrized_tests.py::test_something[a-20] PASSED                [ 22%]
pytest_parametrized_tests.py::test_something[a-30] PASSED                [ 33%]
pytest_parametrized_tests.py::test_something[b-10] PASSED                [ 44%]
pytest_parametrized_tests.py::test_something[b-20] PASSED                [ 55%]
pytest_parametrized_tests.py::test_something[b-30] PASSED                [ 66%]
pytest_parametrized_tests.py::test_something[c-10] PASSED                [ 77%]
pytest_parametrized_tests.py::test_something[c-20] PASSED                [ 88%]
pytest_parametrized_tests.py::test_something[c-30] PASSED                [100%]

============================== 9 passed in 0.12s ===============================
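
Stacking two parametrize decorators runs the test once per combination of the two argument lists, which is why 3 x 3 = 9 items were collected. The plain-Python sketch below only models the combinations and ids seen in the output above with itertools.product; it is not how pytest builds them internally.

```python
from itertools import product

letters = ['a', 'b', 'c']
numbers = [10, 20, 30]

# One test id per (letter, number) pair, in the order shown in the report above.
ids = [f"{letter}-{number}" for letter, number in product(letters, numbers)]

print(len(ids))   # 9
print(ids[:4])    # ['a-10', 'a-20', 'a-30', 'b-10']
```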
%%writefile pytest_parametrized_tests_1.py
import pytest

# Function to calculate area of a rectangle
def calculate_area(length, width):
    return length * width

@pytest.mark.parametrize('length,width,expected', [
    (5, 10, 50),    # Case 1: Normal rectangle
    (0, 10, 0),     # Case 2: Zero length
    (5, 0, 0),      # Case 3: Zero width
    (5, 5, 25),     # Case 4: Square
])
def test_calculate_area(length, width, expected):
    assert calculate_area(length, width) == expected
Overwriting pytest_parametrized_tests_1.py
! pytest pytest_parametrized_tests_1.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items                                                              

pytest_parametrized_tests_1.py::test_calculate_area[5-10-50] PASSED      [ 25%]
pytest_parametrized_tests_1.py::test_calculate_area[0-10-0] PASSED       [ 50%]
pytest_parametrized_tests_1.py::test_calculate_area[5-0-0] PASSED        [ 75%]
pytest_parametrized_tests_1.py::test_calculate_area[5-5-25] PASSED       [100%]

============================== 4 passed in 0.16s ===============================

Exercise: 🏠 Dividing 2

Rewrite Tests for div as Parametrized Tests

Solution

%%writefile pytest_dividing_2.py

# Best solution
import math

import pytest

def div(a, b):
    return a / b

class Dividable:
    def __truediv__(self, other):
        return 42

class MyClass:
    pass

inf = float('inf')
nan = float('nan')

# Test cases for returned values
@pytest.mark.parametrize('a, b, expected', [
    (4, 2, 2),
    (3, 2, 1.5),
    (-3, -2, 1.5),
    (True, 2, 0.5),
    (5e-324, 2, 0.0),
    (1e308, 0.5, inf),
    (Dividable(), Dividable(), 42),
    (inf, 2, inf),
    (inf, inf, nan),
])
def test_div_returned_values(a, b, expected):
    got = div(a, b)
    if math.isnan(expected):
        assert math.isnan(got)
    else:
        assert got == expected

# Test cases for exceptions
@pytest.mark.parametrize('a, b, exception', [
    (2, 0, ZeroDivisionError),
    (inf, 0, ZeroDivisionError),
    ([3, 2, 5], [1, 2, 3], TypeError),
    (MyClass(), MyClass(), TypeError),
])
def test_div_exceptions(a, b, exception):
    with pytest.raises(exception):
        div(a, b)

# Test cases for type of returned values
@pytest.mark.parametrize('a, b, expected_type', [
    (4, 2, float),
    (3, 2, float),
    (-3, -2, float),
    (True, 2, float),
    (5e-324, 2, float),
    (1e308, 0.5, float),
    (Dividable(), Dividable(), int),
    (inf, 2, float),
    (inf, inf, float),
])
def test_div_returned_type(a, b, expected_type):
    got = div(a, b)
    assert isinstance(got, expected_type)
Overwriting pytest_dividing_2.py
%%writefile pytest_dividing_2.py

# Shows how to use objects as parameters
# Shows how to generate test descriptions (ids) dynamically
from pydantic import BaseModel
import math
from typing import Union, Optional, Type

import pytest

def div(a, b):
    return a / b

# Preferred container for a case: pydantic v2, then dataclass, then pydantic v1
class Case(BaseModel):
    a: object
    b: object
    expected: Optional[Union[float, int]] = None
    exception: Optional[Type[Exception]] = None
    doc: Optional[str] = None

class Dividable:
    def __truediv__(self, other):
        return 42

    def __str__(self):
        return 'Dividable'

class MyClass:
    pass

inf = float('inf')
nan = float('nan')

cases = [
    Case(a=4, b=2, expected=2),
    Case(a=3, b=2, expected=1.5, doc="Dividing two integers should give a float"),
    Case(a=-3, b=-2, expected=1.5),
    Case(a=True, b=2, expected=0.5),
    Case(a=5e-324, b=2, expected=0.0),
    Case(a=1e308, b=0.5, expected=inf),
    Case(a=Dividable(), b=Dividable(), expected=42),
    Case(a=inf, b=2, expected=inf),
    Case(a=inf, b=inf, expected=nan),
    Case(a=2, b=0, exception=ZeroDivisionError),
    Case(a=inf, b=0, exception=ZeroDivisionError),
    Case(a=[3, 2, 5], b=[1, 2, 3], exception=TypeError),
    Case(a=MyClass(), b=MyClass(), exception=TypeError),
]

# @pytest.mark.parametrize('case', cases)
@pytest.mark.parametrize('case', cases, ids=lambda case: case.doc or f'div({case.a}, {case.b})')
def test_div_cases(case: Case):
    if case.exception is not None:
        with pytest.raises(case.exception):
            div(case.a, case.b)
    else:
        got = div(case.a, case.b)
        if math.isnan(case.expected):
            assert math.isnan(got)
        else:
            assert got == case.expected

def test_dividing_two_integers_returns_a_float_even_when_result_is_integer():
    assert isinstance(div(4, 2), float)
Overwriting pytest_dividing_2.py
! pytest pytest_dividing_2.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 14 items                                                             

pytest_dividing_2.py::test_div_cases[div(4, 2)] PASSED                   [  7%]
pytest_dividing_2.py::test_div_cases[Dividing two integers should give a float] PASSED [ 14%]
pytest_dividing_2.py::test_div_cases[div(-3, -2)] PASSED                 [ 21%]
pytest_dividing_2.py::test_div_cases[div(True, 2)] PASSED                [ 28%]
pytest_dividing_2.py::test_div_cases[div(5e-324, 2)] PASSED              [ 35%]
pytest_dividing_2.py::test_div_cases[div(1e+308, 0.5)] PASSED            [ 42%]
pytest_dividing_2.py::test_div_cases[div(Dividable, Dividable)] PASSED   [ 50%]
pytest_dividing_2.py::test_div_cases[div(inf, 2)] PASSED                 [ 57%]
pytest_dividing_2.py::test_div_cases[div(inf, inf)] PASSED               [ 64%]
pytest_dividing_2.py::test_div_cases[div(2, 0)] PASSED                   [ 71%]
pytest_dividing_2.py::test_div_cases[div(inf, 0)] PASSED                 [ 78%]
pytest_dividing_2.py::test_div_cases[div([3, 2, 5], [1, 2, 3])] PASSED   [ 85%]
pytest_dividing_2.py::test_div_cases[div(<pytest_dividing_2.MyClass object at 0x10e5479e0>, <pytest_dividing_2.MyClass object at 0x10e547650>)] PASSED [ 92%]
pytest_dividing_2.py::test_dividing_two_integers_returns_a_float_even_when_result_is_integer PASSED [100%]

============================== 14 passed in 0.60s ==============================
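
Several of the cases above depend on floating-point edge cases, and the NaN case is the reason the tests branch on math.isnan instead of using == alone. The sketch below reproduces those edge cases in isolation; the helper `same_float` is a hypothetical name used only for this illustration.

```python
import math

inf = float('inf')
nan = float('nan')

print(nan == nan)              # False: NaN never compares equal, even to itself
print(math.isnan(inf / inf))   # True
print(5e-324 / 2)              # 0.0: underflow below the smallest subnormal float
print(1e308 / 0.5)             # inf: overflow above the largest finite float


def same_float(got, expected):
    """Equality check that treats NaN as matching NaN."""
    if isinstance(expected, float) and math.isnan(expected):
        return isinstance(got, float) and math.isnan(got)
    return got == expected
```

With such a helper, the value assertion could be written as a single `assert same_float(got, expected)`.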

Launching Tests Programmatically

import os
import pytest

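os.chdir('.')  # run from the directory that holds the tests (here: the current one)
r = pytest.main(args=['-s'])  # runs the collected tests and returns an exit code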
print(type(r))
print(r)
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 20 items

katta/test_bowling_game.py .....
test_capsys.py out asdf

.
test_fixture_scope_session1.py Setting up resource for the session
..
test_fixture_scope_session2.py ..
test_library_system.py ......
test_resource_manager.py ..F.

=================================== FAILURES ===================================
__________________________ test_initialization_count ___________________________

    def test_initialization_count():
        """
        Test that the ResourceManager was initialized only once.
        """
>       assert ResourceManager.initialization_count == 1, f"Initialization count should be 1, got {ResourceManager.initialization_count}"
E       AssertionError: Initialization count should be 1, got 4
E       assert 4 == 1
E        +  where 4 = ResourceManager.initialization_count

test_resource_manager.py:48: AssertionError
=========================== short test summary info ============================
FAILED test_resource_manager.py::test_initialization_count - AssertionError: Initialization count should be 1, got 4
========================= 1 failed, 19 passed in 0.26s =========================
<enum 'ExitCode'>
1
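
The value returned above is an integer-valued enum, so it can be compared with plain integers or handed to sys.exit. The sketch below uses a stand-in IntEnum that mirrors the documented members of pytest.ExitCode, so that it runs without pytest installed; the `describe` helper is a hypothetical name. In real code, import pytest.ExitCode instead of redefining it.

```python
from enum import IntEnum


class ExitCode(IntEnum):
    """Stand-in mirroring the documented members of pytest.ExitCode."""
    OK = 0
    TESTS_FAILED = 1
    INTERRUPTED = 2
    INTERNAL_ERROR = 3
    USAGE_ERROR = 4
    NO_TESTS_COLLECTED = 5


def describe(code):
    """Turn a raw exit code into a readable label."""
    code = ExitCode(code)
    return f"{code.name} ({int(code)})"


print(describe(1))          # TESTS_FAILED (1)
print(ExitCode.OK == 0)     # True: IntEnum members compare equal to plain ints
```

A script that wraps a test run would typically end with `sys.exit(pytest.main([...]))`, so the shell sees the same code.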

Code Coverage

%%writefile abs_util.py
def absolute(x):
    if x >= 0:
        return x
    else:
        return -x
Writing abs_util.py
%%writefile pytest_coverage.py
from abs_util import absolute

def test_absolute():
    assert absolute(-2) == 2
Writing pytest_coverage.py

Coverage Package

! pip install coverage
Requirement already satisfied: coverage in ./.venv/lib/python3.12/site-packages (7.6.7)
! coverage run --source=. -m pytest pytest_coverage.py -q
.                                                                        [100%]
1 passed in 0.12s
! coverage report
Name                               Stmts   Miss  Cover
------------------------------------------------------
abs_util.py                            4      1    75%
calculator.py                         19     19     0%
conftest.py                            5      2    60%
doctests.py                           11     11     0%
mathutils.py                           6      6     0%
pytest_composing_fixture.py           50     50     0%
pytest_coverage.py                     3      0   100%
pytest_db.py                          27     27     0%
pytest_div.py                         48     48     0%
pytest_dividing_2.py                  33     33     0%
pytest_factory_fixture.py             28     28     0%
pytest_factory_fixture_1.py           28     28     0%
pytest_fixture_finalization.py        10     10     0%
pytest_fixture_finalization_2.py      19     19     0%
pytest_fixture_scope_class1.py        14     14     0%
pytest_fixture_scope_class.py         12     12     0%
pytest_fixture_scope_function.py      11     11     0%
pytest_fixture_scope_module.py         9      9     0%
pytest_fixture_scope_session1.py       6      6     0%
pytest_fixture_scope_session2.py       6      6     0%
pytest_fixture_scope_session.py        9      9     0%
pytest_fundamentals.py                11     11     0%
pytest_grouped_tests.py               14     14     0%
pytest_implementing_fixtures.py       15     15     0%
pytest_parametrized_fixtures.py        9      9     0%
pytest_parametrized_tests.py           5      5     0%
pytest_parametrized_tests_1.py         6      6     0%
pytest_shared_fixture.py              15     15     0%
pytest_skipping_tests.py              24     24     0%
pytest_suppressoutput.py              33     33     0%
pytest_tmpdir.py                       5      5     0%
recently_used_list.py                 14     14     0%
sftp.py                               13     13     0%
shape.py                              11     11     0%
suppress_output_module.py              9      9     0%
test.py                               11     11     0%
test_capsys.py                         5      5     0%
test_fixture_scope_session1.py         5      5     0%
test_fixture_scope_session2.py         5      5     0%
test_library_system.py               115    115     0%
test_resource_manager.py              27     27     0%
------------------------------------------------------
TOTAL                                710    701     1%
! coverage html
Wrote HTML report to htmlcov/index.html
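
The 75% for abs_util.py in the report above means one of its four statements never ran: the only test calls absolute(-2), so `return x` is missed. The sketch below is a toy line tracer built on sys.settrace that makes this visible. It illustrates the idea of line coverage only and is not a description of how coverage.py is implemented; `executed_lines` is a hypothetical helper.

```python
import sys


def absolute(x):
    if x >= 0:
        return x
    else:
        return -x


def executed_lines(func, *args):
    """Run func(*args) and return the body line offsets (relative to `def`) that executed."""
    code = func.__code__
    seen = set()

    def tracer(frame, event, arg):
        if frame.f_code is code and event == "line":
            offset = frame.f_lineno - code.co_firstlineno
            if offset > 0:  # the `def` line itself runs at import time
                seen.add(offset)
        return tracer

    previous = sys.gettrace()
    sys.settrace(tracer)
    try:
        func(*args)
    finally:
        sys.settrace(previous)  # always restore the previous tracer
    return seen


negative_only = executed_lines(absolute, -2)
both = negative_only | executed_lines(absolute, 2)
print(sorted(negative_only))  # [1, 4]: `if` and `return -x`; `return x` is missed
print(sorted(both))           # [1, 2, 4]: every statement in the body has run
```

Adding a second assertion such as `assert absolute(2) == 2` to the test would therefore be expected to bring abs_util.py to full statement coverage.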

Exercise: 🏠 Coverage of suppress_output

Measure code coverage on suppress_output.

pytest-cov Package
! pip install pytest-cov
Requirement already satisfied: pytest-cov in ./.venv/lib/python3.12/site-packages (6.0.0)
Requirement already satisfied: pytest>=4.6 in ./.venv/lib/python3.12/site-packages (from pytest-cov) (8.2.2)
Requirement already satisfied: coverage>=7.5 in ./.venv/lib/python3.12/site-packages (from coverage[toml]>=7.5->pytest-cov) (7.6.7)
Requirement already satisfied: iniconfig in ./.venv/lib/python3.12/site-packages (from pytest>=4.6->pytest-cov) (2.0.0)
Requirement already satisfied: packaging in ./.venv/lib/python3.12/site-packages (from pytest>=4.6->pytest-cov) (24.1)
Requirement already satisfied: pluggy<2.0,>=1.5 in ./.venv/lib/python3.12/site-packages (from pytest>=4.6->pytest-cov) (1.5.0)
! pytest --cov=. pytest_coverage.py -q
.                                                                        [100%]

---------- coverage: platform darwin, python 3.12.6-final-0 ----------
Name                               Stmts   Miss  Cover
------------------------------------------------------
abs_util.py                            4      1    75%
calculator.py                         19     19     0%
conftest.py                            5      2    60%
doctests.py                           11     11     0%
mathutils.py                           6      6     0%
pytest_composing_fixture.py           50     50     0%
pytest_coverage.py                     3      0   100%
pytest_db.py                          27     27     0%
pytest_div.py                         48     48     0%
pytest_dividing_2.py                  33     33     0%
pytest_factory_fixture.py             28     28     0%
pytest_factory_fixture_1.py           28     28     0%
pytest_fixture_finalization.py        10     10     0%
pytest_fixture_finalization_2.py      19     19     0%
pytest_fixture_scope_class1.py        14     14     0%
pytest_fixture_scope_class.py         12     12     0%
pytest_fixture_scope_function.py      11     11     0%
pytest_fixture_scope_module.py         9      9     0%
pytest_fixture_scope_session1.py       6      6     0%
pytest_fixture_scope_session2.py       6      6     0%
pytest_fixture_scope_session.py        9      9     0%
pytest_fundamentals.py                11     11     0%
pytest_grouped_tests.py               14     14     0%
pytest_implementing_fixtures.py       15     15     0%
pytest_parametrized_fixtures.py        9      9     0%
pytest_parametrized_tests.py           5      5     0%
pytest_parametrized_tests_1.py         6      6     0%
pytest_shared_fixture.py              15     15     0%
pytest_skipping_tests.py              24     24     0%
pytest_suppressoutput.py              33     33     0%
pytest_tmpdir.py                       5      5     0%
recently_used_list.py                 14     14     0%
sftp.py                               13     13     0%
shape.py                              11     11     0%
suppress_output_module.py              9      9     0%
test.py                               11     11     0%
test_capsys.py                         5      5     0%
test_fixture_scope_session1.py         5      5     0%
test_fixture_scope_session2.py         5      5     0%
test_library_system.py               115    115     0%
test_resource_manager.py              27     27     0%
------------------------------------------------------
TOTAL                                710    701     1%

1 passed in 0.46s
! pytest --cov=. --cov-report=html pytest_coverage.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

pytest_coverage.py .                                                     [100%]

---------- coverage: platform darwin, python 3.12.6-final-0 ----------
Coverage HTML written to dir htmlcov


============================== 1 passed in 1.40s ===============================
! pytest --cov=. --cov-report=term --cov-report=html pytest_coverage.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

pytest_coverage.py .                                                     [100%]

---------- coverage: platform darwin, python 3.12.6-final-0 ----------
Name                               Stmts   Miss  Cover
------------------------------------------------------
abs_util.py                            4      1    75%
calculator.py                         19     19     0%
conftest.py                            5      2    60%
doctests.py                           11     11     0%
mathutils.py                           6      6     0%
pytest_composing_fixture.py           50     50     0%
pytest_coverage.py                     3      0   100%
pytest_db.py                          27     27     0%
pytest_div.py                         48     48     0%
pytest_dividing_2.py                  33     33     0%
pytest_factory_fixture.py             28     28     0%
pytest_factory_fixture_1.py           28     28     0%
pytest_fixture_finalization.py        10     10     0%
pytest_fixture_finalization_2.py      19     19     0%
pytest_fixture_scope_class1.py        14     14     0%
pytest_fixture_scope_class.py         12     12     0%
pytest_fixture_scope_function.py      11     11     0%
pytest_fixture_scope_module.py         9      9     0%
pytest_fixture_scope_session1.py       6      6     0%
pytest_fixture_scope_session2.py       6      6     0%
pytest_fixture_scope_session.py        9      9     0%
pytest_fundamentals.py                11     11     0%
pytest_grouped_tests.py               14     14     0%
pytest_implementing_fixtures.py       15     15     0%
pytest_parametrized_fixtures.py        9      9     0%
pytest_parametrized_tests.py           5      5     0%
pytest_parametrized_tests_1.py         6      6     0%
pytest_shared_fixture.py              15     15     0%
pytest_skipping_tests.py              24     24     0%
pytest_suppressoutput.py              33     33     0%
pytest_tmpdir.py                       5      5     0%
recently_used_list.py                 14     14     0%
sftp.py                               13     13     0%
shape.py                              11     11     0%
suppress_output_module.py              9      9     0%
test.py                               11     11     0%
test_capsys.py                         5      5     0%
test_fixture_scope_session1.py         5      5     0%
test_fixture_scope_session2.py         5      5     0%
test_library_system.py               115    115     0%
test_resource_manager.py              27     27     0%
------------------------------------------------------
TOTAL                                710    701     1%
Coverage HTML written to dir htmlcov


============================== 1 passed in 1.76s ===============================

Doctests

Doctests are a convenient way to embed tests in the docstrings of your functions, methods, or modules. They serve both as documentation and as a way to ensure that your code behaves as expected. When you run doctests, Python executes the code examples in the docstrings and checks whether the output matches the expected results.

%%writefile doctests.py
import doctest
import math

def factorial(n):
    """Returns the factorial of n (n!).

    >>> factorial(3)
    6
    >>> factorial(30)
    265252859812191058636308480000000
    >>> factorial(-1)
    Traceback (most recent call last):
        ...
    ValueError: n must be >= 0
    >>> [factorial(n)
    ...  for n in range(6)]
    [1, 1, 2, 6, 24, 120]
    """
    if not n >= 0:
        raise ValueError("n must be >= 0")
    result = 1
    factor = 2
    while factor <= n:
        result *= factor
        factor += 1
    return result

# if __name__ == "__main__":
#     doctest.testmod()
Overwriting doctests.py
! python -m doctest doctests.py -v
Trying:
    factorial(3)
Expecting:
    6
ok
Trying:
    factorial(30)
Expecting:
    265252859812191058636308480000000
ok
Trying:
    factorial(-1)
Expecting:
    Traceback (most recent call last):
        ...
    ValueError: n must be >= 0
ok
Trying:
    [factorial(n)
     for n in range(6)]
Expecting:
    [1, 1, 2, 6, 24, 120]
ok
1 items had no tests:
    doctests
1 items passed all tests:
   4 tests in doctests.factorial
4 tests in 2 items.
4 passed and 0 failed.
Test passed.
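
Doctests can also be run from Python code, which makes it easy to see what happens when the expected output in a docstring is wrong. The sketch below uses the standard library's DocTestParser and DocTestRunner on a deliberately incorrect example; the function `double` is made up for this illustration.

```python
import doctest


def double(n):
    """Return n doubled.

    >>> double(2)
    4
    >>> double(3)
    7
    """
    return 2 * n


# Build a DocTest object from the docstring and run it, collecting the report text.
parser = doctest.DocTestParser()
test = parser.get_doctest(double.__doc__, {"double": double}, "double", None, 0)

messages = []
runner = doctest.DocTestRunner(verbose=False)
results = runner.run(test, out=messages.append)

print(results.failed, results.attempted)  # 1 2
print("".join(messages))                  # report showing the expected and actual value
```

The second example fails because double(3) returns 6, not 7; the runner compares the printed result with the docstring text character by character.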

Exercise: 🏠 Doctest: suppress_output

Rewrite the tests for suppress_output as doctests in its docstring.

Solution

%%writefile suppress_output_doctest.py
from contextlib import contextmanager
import doctest
import io
import sys

@contextmanager
def suppress_output():
    """
    Example Usage:

    >>> print('this will be displayed')
    this will be displayed
    >>> with suppress_output():
    ...     print("this won't be displayed")


    >>> before = sys.stdout
    >>> print('printed')
    printed
    >>> with suppress_output() as stream:
    ...     inside = sys.stdout
    ...     print('not printed')
    >>> after = sys.stdout
    >>> print('printed again')
    printed again
    >>> inside is before
    False
    >>> after is before
    True
    >>> stream is before
    True
    >>> before = sys.stdout
    >>> with suppress_output():
    ...     print("not printed")
    ...     raise ZeroDivisionError
    Traceback (most recent call last):
    ZeroDivisionError
    >>> after = sys.stdout
    >>> after is before
    True
    >>> print('still works')
    still works
    """

    stdout = sys.stdout
    sys.stdout = io.StringIO()  # Python 3
    # sys.stdout = io.BytesIO()  # Python 2
    try:
        yield stdout
    finally:
        sys.stdout = stdout

if __name__ == "__main__":
    doctest.testmod()
Writing suppress_output_doctest.py
! python -m doctest suppress_output_doctest.py -v
Trying:
    print('this will be displayed')
Expecting:
    this will be displayed
ok
Trying:
    with suppress_output():
        print("this won't be displayed")
Expecting nothing
ok
Trying:
    before = sys.stdout
Expecting nothing
ok
Trying:
    print('printed')
Expecting:
    printed
ok
Trying:
    with suppress_output() as stream:
        inside = sys.stdout
        print('not printed')
Expecting nothing
ok
Trying:
    after = sys.stdout
Expecting nothing
ok
Trying:
    print('printed again')
Expecting:
    printed again
ok
Trying:
    inside is before
Expecting:
    False
ok
Trying:
    after is before
Expecting:
    True
ok
Trying:
    stream is before
Expecting:
    True
ok
Trying:
    before = sys.stdout
Expecting nothing
ok
Trying:
    with suppress_output():
        print("not printed")
        raise ZeroDivisionError
Expecting:
    Traceback (most recent call last):
    ZeroDivisionError
ok
Trying:
    after = sys.stdout
Expecting nothing
ok
Trying:
    after is before
Expecting:
    True
ok
Trying:
    print('still works')
Expecting:
    still works
ok
1 items had no tests:
    suppress_output_doctest
1 items passed all tests:
  15 tests in suppress_output_doctest.suppress_output
15 tests in 2 items.
15 passed and 0 failed.
Test passed.

unittest

The unittest module is Python's built-in testing framework, offering a class-based approach to test organization and execution. While Pytest is more modern and flexible, unittest serves as a reference for understanding foundational testing principles in Python. It’s especially useful for scenarios that benefit from strict class-based organization or when working within legacy codebases that already rely on this framework.

%%writefile unittest_demo.py
import unittest

class TestCase(unittest.TestCase):
    def setUp(self):
        # Setup logic executed before each test
        self.stream = open('people.csv')
        print('setUp')

    def tearDown(self):
        # Teardown logic executed after each test
        self.stream.close()
        print('tearDown')

    def test_one(self):
        print('test_one')
        self.assertEqual(2 + 2, 4)  # Basic assertion

    def test_two(self):
        print('test_two')
        self.assertEqual(2 + 2, 4)  # Basic assertion

if __name__ == "__main__":
    unittest.main()
Writing unittest_demo.py
! python unittest_demo.py
setUp
test_one
tearDown
.setUp
test_two
tearDown
.
----------------------------------------------------------------------
Ran 2 tests in 0.001s

OK
! pytest unittest_demo.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

unittest_demo.py::TestCase::test_one PASSED                              [ 50%]
unittest_demo.py::TestCase::test_two PASSED                              [100%]

============================== 2 passed in 0.10s ===============================

Key Components

  • setUp:

    • Runs before each test method in the test case.
    • Typically used to initialize resources (e.g., opening files, setting up mock data).
  • tearDown:

    • Runs after each test method in the test case.
    • Typically used to clean up resources (e.g., closing files, resetting states).
  • Test Methods:

    • Methods starting with test_ are considered test methods.
    • Use assertEqual and other assertion methods to validate test conditions.
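
The demo above only uses assertEqual. As a minimal sketch of a few other assertion methods from unittest.TestCase (the class name AssertionExamples and the explicit runner at the bottom are illustrative choices, not part of the original demo):

```python
import unittest


class AssertionExamples(unittest.TestCase):
    def test_membership(self):
        # assertIn checks that the first argument is contained in the second
        self.assertIn('b', ['a', 'b', 'c'])

    def test_truthiness(self):
        # assertTrue passes when the expression is truthy
        self.assertTrue(10 > 5)

    def test_expected_exception(self):
        # assertRaises as a context manager passes only if the exception is raised
        with self.assertRaises(ZeroDivisionError):
            1 / 0

    def test_float_comparison(self):
        # assertAlmostEqual compares after rounding (7 decimal places by default)
        self.assertAlmostEqual(0.1 + 0.2, 0.3)


# Run the suite explicitly so the snippet also works inside a notebook cell.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(AssertionExamples)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Running the suite through a TextTestRunner instead of unittest.main() avoids the interpreter exit that unittest.main() triggers by default.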

Summary

The unittest framework provides a powerful, structured approach to testing in Python. While it requires more boilerplate code compared to Pytest, it excels in scenarios where strict class-based organization and setup/teardown logic are required.

Refactoring in Python

Refactoring is the process of restructuring existing code without changing its external behavior. It's a disciplined way to clean up code that minimizes the chances of introducing bugs. Refactoring improves the design of software, makes it easier to understand, and can help find bugs or enhance performance.

In this section, we'll explore several refactoring techniques using Python examples. Each example illustrates how to improve code readability, maintainability, and scalability. We'll start with the code before refactoring, discuss its limitations, and then present the refactored version along with an explanation of the improvements made.

General Steps for Refactoring in Python

  1. Understand the Current Code:

    • Analyze the existing functionality to fully understand its purpose and behavior.
    • Identify any potential redundancies, code smells, or areas for improvement.
  2. Write Comprehensive Tests Before Refactoring:

    • Ensure the code is fully covered by unit tests that verify its current behavior.
    • Tests should validate both the expected outputs and any side effects.
    • Without tests, refactoring may inadvertently introduce bugs.
  3. Identify the Target for Refactoring:

    • Locate specific code fragments that can be improved, such as:
      • Long or complex methods.
      • Repeated patterns or code duplication.
      • Code with unclear intent or excessive complexity.
  4. Break Down Refactoring Tasks:

    • Plan the steps needed to refactor the identified code without altering its functionality.
    • Focus on incremental changes to minimize the risk of errors.
  5. Refactor the Code Incrementally:

    • Perform changes step by step, testing after each modification.
    • Examples of refactoring techniques include:
      • Extract Method: Break down large methods into smaller, focused methods.
      • Introduce Variables: Replace complex expressions or repeated calculations with named variables.
      • Replace Magic Numbers: Replace hard-coded values with constants or configuration options.
      • Simplify Conditional Logic: Refactor complex conditionals into more readable structures.
  6. Run Tests Frequently:

    • After each refactoring step, run the tests to ensure the code still behaves as expected.
    • If tests fail, investigate and fix the issue before proceeding.
  7. Optimize and Clean Up:

    • Review the refactored code to ensure it adheres to best practices, such as the Single Responsibility Principle or the DRY (Don't Repeat Yourself) principle.
    • Remove unused code or redundant comments, and ensure the code is well-documented.
  8. Commit Changes with Context:

    • Clearly document the purpose and scope of the refactoring in your version control system.
    • This helps future developers (and your future self) understand why the changes were made.
  9. Perform a Code Review:

    • Share the refactored code with peers for review.
    • Gather feedback to ensure the changes align with team standards and project goals.
  10. Deploy and Monitor:

    • Once the refactored code passes all tests and reviews, deploy it to production.
    • Monitor its behavior to confirm that no unexpected issues arise.

By following these general steps, you ensure that refactoring improves the code without compromising its functionality or introducing regressions. This process promotes cleaner, more maintainable, and scalable Python codebases.

Extract Method

Extract Method is a fundamental refactoring technique where a fragment of code from an existing method or function is moved into a new, separate method. This new method is given a descriptive name that clearly communicates its purpose. The original code segment is then replaced with a call to the new method. This process enhances code clarity, promotes reusability, and aligns with several SOLID principles of object-oriented design.

Purpose of Extract Method

  • Improve Readability: Breaking down complex methods into smaller, well-named methods makes code easier to read and understand.
  • Encourage Reuse: Extracting code into methods allows for reuse in other parts of the codebase, reducing duplication.
  • Simplify Maintenance: Smaller methods are easier to test, debug, and maintain. Changes are localized, minimizing the impact on the rest of the system.
  • Enhance Abstraction: High-level methods can orchestrate the flow by calling lower-level methods, promoting modularity.

Alignment with SOLID Principles

The Extract Method technique closely relates to several SOLID principles:

  1. Single Responsibility Principle (SRP):

    • Definition: A class or method should have one, and only one, reason to change.
    • Alignment: By extracting methods, you ensure that each method has a single responsibility. This reduces complexity and makes the codebase more robust to changes.
  2. Open/Closed Principle (OCP):

    • Definition: Software entities should be open for extension but closed for modification.
    • Alignment: Extracting methods allows you to extend functionality by adding new methods rather than modifying existing ones. This minimizes the risk of introducing bugs in tested code.
  3. Liskov Substitution Principle (LSP):

    • While not directly impacted by Extract Method, having well-defined methods helps in creating subclasses that can override methods without altering the expected behavior, supporting LSP.
  4. Interface Segregation Principle (ISP):

    • Definition: Many client-specific interfaces are better than one general-purpose interface.
    • Alignment: By extracting methods, you effectively create smaller, focused interfaces (methods) that clients can use without depending on larger, monolithic methods.
  5. Dependency Inversion Principle (DIP):

    • Definition: Depend upon abstractions, not concretions.
    • Alignment: Extract Method can help in isolating dependencies and promoting the use of abstractions within methods, thus supporting DIP.

When to Use Extract Method

  • Long or Complex Methods: Simplify methods that are difficult to read or understand due to their length or complexity.
  • Duplicate Code: Eliminate code duplication by extracting common code into a single method.
  • Distinct Responsibilities: Separate concerns when a method performs multiple tasks.
  • Complex Conditionals or Loops: Extract complex logic into well-named methods to clarify intent.

How to Perform Extract Method

  1. Identify the Code Fragment: Locate a section of code within a method that serves a distinct purpose or represents a logical unit.
  2. Create a New Method: Move the identified code into a new method, naming it to reflect its functionality.
  3. Replace the Original Code: In the original method, replace the code fragment with a call to the new method.
  4. Adjust Parameters and Returns: Ensure the new method has the necessary parameters and returns any required values.
  5. Test Thoroughly: Run the existing tests to verify that the external behavior remains unchanged.
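
The five steps above can be sketched on a small function. This is an illustrative example only; the names report_before, report_after and calculate_average are hypothetical and do not come from the training exercises.

```python
# Before: one function both computes the average and formats the message.
def report_before(scores):
    total = 0
    for score in scores:
        total += score
    average = total / len(scores)
    return f"Average score: {average:.1f}"


# After: the averaging fragment is extracted into its own, well-named function.
def calculate_average(scores):
    return sum(scores) / len(scores)


def report_after(scores):
    # The original fragment is replaced with a call to the extracted function.
    average = calculate_average(scores)
    return f"Average score: {average:.1f}"


# Step 5: the behavior must be identical before and after the refactoring.
assert report_before([3, 4, 5]) == report_after([3, 4, 5])
```

The extracted calculate_average can now be tested and reused on its own.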

Exercise: 🏠 Calculator

Use the refactoring guidelines to optimize the code. Focus on using Extract Method.

%%writefile  calculator.py

class Calculator:
    def calculate(self, operation, a, b):
        if operation == 'add':
            result = a + b
            print(f"The result of adding {a} and {b} is {result}")
            return result
        elif operation == 'subtract':
            result = a - b
            print(f"The result of subtracting {b} from {a} is {result}")
            return result
        elif operation == 'multiply':
            result = a * b
            print(f"The result of multiplying {a} and {b} is {result}")
            return result
        elif operation == 'divide':
            if b == 0:
                raise ValueError("Cannot divide by zero")
            result = a / b
            print(f"The result of dividing {a} by {b} is {result}")
            return result
        else:
            raise ValueError(f"Unknown operation '{operation}'")
Overwriting calculator.py
%%writefile test_calculator.py

import pytest
from calculator import Calculator

def test_add():
    calc = Calculator()
    assert calc.calculate('add', 2, 3) == 5

def test_subtract():
    calc = Calculator()
    assert calc.calculate('subtract', 5, 2) == 3

def test_multiply():
    calc = Calculator()
    assert calc.calculate('multiply', 3, 4) == 12

def test_divide():
    calc = Calculator()
    assert calc.calculate('divide', 10, 2) == 5

def test_divide_by_zero():
    calc = Calculator()
    with pytest.raises(ValueError, match="Cannot divide by zero"):
        calc.calculate('divide', 10, 0)

def test_unknown_operation():
    calc = Calculator()
    with pytest.raises(ValueError, match="Unknown operation 'modulo'"):
        calc.calculate('modulo', 10, 2)
Writing test_calculator.py
! pytest test_calculator.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 6 items                                                              

test_calculator.py ......                                                [100%]

============================== 6 passed in 0.10s ===============================
! coverage run -m pytest test_calculator.py -q
! coverage report
......                                                                   [100%]
6 passed in 0.26s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py           21      0   100%
conftest.py              5      2    60%
test_calculator.py      22      0   100%
----------------------------------------
TOTAL                   48      2    96%
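
The tests above check return values only, while calculate also prints a message. Since the guidelines ask for tests that cover side effects too, a test along the following lines could pin down the printed output before refactoring. This is a sketch: the Calculator here is a trimmed copy with only the 'add' branch, and test_add_prints_message is a hypothetical test name.

```python
import io
from contextlib import redirect_stdout


class Calculator:
    # Trimmed copy of the exercise class: only the 'add' branch is kept.
    def calculate(self, operation, a, b):
        if operation == 'add':
            result = a + b
            print(f"The result of adding {a} and {b} is {result}")
            return result
        raise ValueError(f"Unknown operation '{operation}'")


def test_add_prints_message():
    calc = Calculator()
    buffer = io.StringIO()
    # Capture everything printed while the call runs.
    with redirect_stdout(buffer):
        result = calc.calculate('add', 2, 3)
    assert result == 5
    assert buffer.getvalue() == "The result of adding 2 and 3 is 5\n"


test_add_prints_message()
```

Under pytest, the built-in capsys fixture is an alternative way to capture the same output.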
%%writefile  calculator.py

class Calculator:
    def calculate(self, operation, a, b):
        if operation == 'add':
            return self._add(a, b)
        elif operation == 'subtract':
            return self._subtract(a, b)
        elif operation == 'multiply':
            return self._multiply(a, b)
        elif operation == 'divide':
            return self._divide(a, b)
        else:
            raise ValueError(f"Unknown operation '{operation}'")

    def _add(self, a, b):
        result = a + b
        print(f"The result of adding {a} and {b} is {result}")
        return result

    def _subtract(self, a, b):
        result = a - b
        print(f"The result of subtracting {b} from {a} is {result}")
        return result

    def _multiply(self, a, b):
        result = a * b
        print(f"The result of multiplying {a} and {b} is {result}")
        return result

    def _divide(self, a, b):
        if b == 0:
            raise ValueError("Cannot divide by zero")
        result = a / b
        print(f"The result of dividing {a} by {b} is {result}")
        return result
Overwriting calculator.py
! coverage run -m pytest test_calculator.py -q
! coverage report
......                                                                   [100%]
6 passed in 0.17s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py           29      0   100%
conftest.py              5      2    60%
test_calculator.py      22      0   100%
----------------------------------------
TOTAL                   56      2    96%
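
A possible further step, going beyond Extract Method, is to simplify the remaining if/elif chain with a dictionary that maps operation names to the extracted methods. The sketch below is one option under that assumption; the local name handlers is hypothetical.

```python
class Calculator:
    def calculate(self, operation, a, b):
        # Map each operation name to the bound method that handles it.
        handlers = {
            'add': self._add,
            'subtract': self._subtract,
            'multiply': self._multiply,
            'divide': self._divide,
        }
        if operation not in handlers:
            raise ValueError(f"Unknown operation '{operation}'")
        return handlers[operation](a, b)

    def _add(self, a, b):
        result = a + b
        print(f"The result of adding {a} and {b} is {result}")
        return result

    def _subtract(self, a, b):
        result = a - b
        print(f"The result of subtracting {b} from {a} is {result}")
        return result

    def _multiply(self, a, b):
        result = a * b
        print(f"The result of multiplying {a} and {b} is {result}")
        return result

    def _divide(self, a, b):
        if b == 0:
            raise ValueError("Cannot divide by zero")
        result = a / b
        print(f"The result of dividing {a} by {b} is {result}")
        return result
```

The public interface and the error messages are unchanged, so the existing test_calculator.py should keep passing against this version.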

Example: Registration Form

import unittest

class RegistrationFormTests(unittest.TestCase):
    def test_successful_submission(self):
        form = self.app.get('/register').form

        # Fill in the form
        form['personal'] = 'John'
        form['family'] = 'Smith'
        form['email'] = 'john@smith.com'
        form['location'] = 'Cracow'
        form['country'] = 'PL'
        form['agreed_to_code_of_conduct'] = True
        form['recaptcha_response_field'] = 'PASSED'

        response = form.submit()

        self.assertEqual(response.status_code, 200)
# We've extracted the code responsible for filling in the form into a separate function `fill_in_form`.

class RegistrationFormTests(unittest.TestCase):
    def test_successful_submission(self):
        form = self.app.get('/register').form  # form behaves like a dictionary
        fill_in_form(form)
        response = form.submit()
        self.assertEqual(response.status_code, 200)

def fill_in_form(form):
    form['personal'] = 'John'
    form['family'] = 'Smith'
    form['email'] = 'john@smith.com'
    form['location'] = 'Cracow'
    form['country'] = 'PL'
    form['agreed_to_code_of_conduct'] = True
    form['recaptcha_response_field'] = 'PASSED'

Variables

Refactoring with the "Variables" technique involves simplifying complex expressions by introducing meaningful variables to represent intermediate values or logical components. This approach enhances readability, reduces redundancy, and makes the code easier to maintain and debug. It also helps clarify the intent of calculations and logic by breaking them into smaller, understandable parts.

Steps for Refactoring with Variables

  1. Identify Complex Expressions:

    • Locate sections of the code with:
      • Nested or chained expressions.
      • Repeated calculations.
      • Hard-to-read logic.
    • These are candidates for breaking into smaller, meaningful components.
  2. Assign Intermediate Values to Variables:

    • Extract repeated or complex expressions into descriptive variables.
    • Ensure the variable names clearly convey the meaning of the value they represent.

    Example: Replace:

    total = price * (1 + tax_rate) - discount if price > threshold else price
    

    With:

    tax = price * tax_rate
    discounted_price = price - discount if price > threshold else price
    total = discounted_price + tax
    

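    Note: the first expression hides an issue. The conditional expression binds more loosely than the arithmetic operators, so the one-liner is evaluated as (price * (1 + tax_rate) - discount) if price > threshold else price, and no tax is applied when price <= threshold. The version with variables always adds tax, which makes the difference visible; confirm the intended behavior with tests before changing it.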

  3. Replace Magic Numbers with Constants:

    • Replace hard-coded numbers with named constants to clarify their purpose.
    • This improves readability and makes future adjustments easier.

    Example:
    Replace:

    if price > 100:
        discount = price * 0.1
    

    With:

    DISCOUNT_THRESHOLD = 100
    DISCOUNT_RATE = 0.1
    if price > DISCOUNT_THRESHOLD:
        discount = price * DISCOUNT_RATE
    
  4. Simplify Conditional Logic:

    • Break down complex conditions into variables or helper functions with descriptive names.

    Example:
    Replace:

    if user.is_active and (user.age > 18 or user.has_permission):
        return "Allowed"
    

    With:

    is_adult = user.age > 18
    is_allowed = user.is_active and (is_adult or user.has_permission)
    if is_allowed:
        return "Allowed"
    
  5. Group Related Calculations:

    • Combine calculations or logical steps into sequential, meaningful variables to highlight their relationships.

    Example:
    Replace:

    net_income = revenue - (expenses + taxes)
    

    With:

    total_costs = expenses + taxes
    net_income = revenue - total_costs
    
  6. Validate Functionality with Tests:

    • Ensure the behavior remains consistent after refactoring by writing or running unit tests.
    • Cover all possible scenarios, including edge cases, to ensure the refactoring did not alter the logic.
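
Step 6 matters in practice: the one-line total from step 2 and its multi-variable replacement do not agree for every input. The following sketch uses made-up sample values (price, tax_rate, discount and threshold are assumptions chosen for illustration) with a price below the threshold.

```python
price, tax_rate, discount, threshold = 50, 0.2, 5, 100

# One-liner from step 2: the conditional applies to the whole expression,
# so the result is just `price` when price <= threshold.
original = price * (1 + tax_rate) - discount if price > threshold else price

# Version with intermediate variables from step 2: tax is always added.
tax = price * tax_rate
discounted_price = price - discount if price > threshold else price
refactored = discounted_price + tax

print(original)    # 50
print(refactored)  # 60.0
```

A test written against the original expression before refactoring would catch this change in behavior.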

Exercise: 🏠 Calculator 2

Use the refactoring guidelines to optimize the code. Focus on using the Variables technique.

%%writefile  calculator.py

def calculate_subtotal():
    print('Calculate subtotal')
    return 400.0

def calculate_taxable_subtotal():
    print('Calculate taxable subtotal')
    return 300.0

def calculate_total():
    print('Calculate total')
    return (calculate_subtotal() + calculate_taxable_subtotal() * 0.15
            - (calculate_subtotal() * 0.1 if calculate_subtotal() > 100 else 0))

# if __name__ == "__main__":
#     total = calculate_total()
#     print(f'Total: {total}')
Overwriting calculator.py
%%writefile test_calculator.py

# test_calculator.py

import pytest
from calculator import (
    calculate_subtotal,
    calculate_taxable_subtotal,
    calculate_total,
)

def test_calculate_subtotal():
    subtotal = calculate_subtotal()
    assert subtotal == 400.0

def test_calculate_taxable_subtotal():
    taxable_subtotal = calculate_taxable_subtotal()
    assert taxable_subtotal == 300.0

def test_calculate_total():
    total = calculate_total()
    expected_subtotal = 400.0
    expected_taxable_subtotal = 300.0
    expected_discount = expected_subtotal * 0.1 if expected_subtotal > 100 else 0
    expected_tax = expected_taxable_subtotal * 0.15
    expected_total = expected_subtotal + expected_tax - expected_discount
    assert total == expected_total
Overwriting test_calculator.py
! coverage run -m pytest test_calculator.py -q
! coverage report
...                                                                      [100%]
3 passed in 0.06s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py            9      0   100%
conftest.py              5      2    60%
test_calculator.py      16      0   100%
----------------------------------------
TOTAL                   30      2    93%
%%writefile  calculator.py

"""
We've introduced variables to store intermediate results:

- **Variable Assignment**: Store the result of `calculate_subtotal()` in `subtotal`.
- **Simplified Logic**: Use an `if` statement to determine the discount.
- **Named Variables**: `discount` and `tax` make the code more understandable.
"""

def calculate_subtotal():
    print('Calculate subtotal')
    return 400.0

def calculate_taxable_subtotal():
    print('Calculate taxable subtotal')
    return 300.0

# Refactored function
def calculate_total():
    print('Calculate total')
    subtotal = calculate_subtotal()
    if subtotal > 100:
        discount = subtotal * 0.1
    else:
        discount = 0
    tax = calculate_taxable_subtotal() * 0.15
    return subtotal + tax - discount

if __name__ == "__main__":
    total = calculate_total()
    print(f'Total: {total}')
Overwriting calculator.py
! coverage run -m pytest test_calculator.py -q
! coverage report
...                                                                      [100%]
3 passed in 0.06s
Name                 Stmts   Miss  Cover
----------------------------------------
calculator.py           17      3    82%
conftest.py              5      2    60%
test_calculator.py      16      0   100%
----------------------------------------
TOTAL                   38      5    87%
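
A possible next step applies the "Replace Magic Numbers with Constants" guideline to the same function. This is only a sketch; the constant names DISCOUNT_THRESHOLD, DISCOUNT_RATE and TAX_RATE are assumptions introduced here.

```python
# Named constants replace the literals 100, 0.1 and 0.15.
DISCOUNT_THRESHOLD = 100
DISCOUNT_RATE = 0.1
TAX_RATE = 0.15


def calculate_subtotal():
    print('Calculate subtotal')
    return 400.0


def calculate_taxable_subtotal():
    print('Calculate taxable subtotal')
    return 300.0


def calculate_total():
    print('Calculate total')
    subtotal = calculate_subtotal()
    if subtotal > DISCOUNT_THRESHOLD:
        discount = subtotal * DISCOUNT_RATE
    else:
        discount = 0
    tax = calculate_taxable_subtotal() * TAX_RATE
    return subtotal + tax - discount
```

The arithmetic is unchanged, so test_calculate_total should still pass against this version.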

Exercise: 🏠 Recently Used List

%%writefile  recently_used_list.py
class RecentlyUsedList(list):
    def append(self, elem):
        if elem in self:
            self.remove(elem)
        super().append(elem)

# if __name__ == "__main__":
#     rul = RecentlyUsedList()
#     rul.append('first')
#     rul.append('second')
#     rul.append('third')
#     rul.append('second')
#     print(rul)

#     rul.extend(['a', 'b', 'first'])  # (!)
#     print(rul)
Overwriting recently_used_list.py

Explanation

The RecentlyUsedList class is intended to keep track of recently used items, moving any existing items to the end when re-added. However, inheriting from list introduces unintended behavior:

  • Unintended Methods: Methods like extend() and insert() are available but may not behave correctly.
  • Violation of Liskov Substitution Principle: The subclass does not behave consistently with the superclass.
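
The problem with inherited methods can be reproduced in a few lines. In this sketch, the overridden append moves a repeated item to the end, but the inherited extend does not go through append, so a duplicate appears.

```python
class RecentlyUsedList(list):
    def append(self, elem):
        if elem in self:
            self.remove(elem)
        super().append(elem)


rul = RecentlyUsedList()
rul.append('first')
rul.append('second')
rul.append('first')    # moved to the end, no duplicate
rul.extend(['first'])  # inherited extend() bypasses our append()
print(rul)             # ['second', 'first', 'first']
```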
%%writefile test_recently_used_list.py
import pytest
from recently_used_list import RecentlyUsedList

def test_initialization():
    rul = RecentlyUsedList()
    assert str(rul) == "[]"

def test_append_new_element():
    rul = RecentlyUsedList()
    rul.append('a')
    assert str(rul) == "['a']"

def test_append_existing_element():
    rul = RecentlyUsedList()
    rul.append('a')
    rul.append('b')
    rul.append('a')
    assert str(rul) == "['b', 'a']"

def test_get_item():
    rul = RecentlyUsedList()
    rul.append('a')
    rul.append('b')
    assert rul[0] == 'a'
    assert rul[1] == 'b'

def test_get_item_out_of_range():
    rul = RecentlyUsedList()
    rul.append('a')
    with pytest.raises(IndexError):
        _ = rul[1]
Overwriting test_recently_used_list.py
! coverage run  -m pytest test_recently_used_list.py -q
! coverage report
.....                                                                    [100%]
5 passed in 0.13s
Name                         Stmts   Miss  Cover
------------------------------------------------
conftest.py                      5      2    60%
recently_used_list.py            5      0   100%
test_recently_used_list.py      26      0   100%
------------------------------------------------
TOTAL                           36      2    94%
%%writefile  recently_used_list.py
class RecentlyUsedList:
    def __init__(self):
        self._list = []

    def append(self, elem):
        if elem in self._list:
            self._list.remove(elem)
        self._list.append(elem)

    def __str__(self):
        return str(self._list)

    def __getitem__(self, index):
        return self._list[index]
Overwriting recently_used_list.py
! coverage run  -m pytest test_recently_used_list.py -q
! coverage report
.....                                                                    [100%]
5 passed in 0.09s
Name                         Stmts   Miss  Cover
------------------------------------------------
conftest.py                      5      2    60%
recently_used_list.py           11      0   100%
test_recently_used_list.py      26      0   100%
------------------------------------------------
TOTAL                           42      2    95%

Explanation

By switching to composition:

  • Controlled Interface: Only the methods we define are available.
  • Encapsulation: Internal implementation details are hidden.
  • Flexibility: We can change how we store items without affecting external code.
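
As a quick check of the controlled interface, the sketch below repeats the composition-based class and shows that list methods we did not define, such as extend, are simply not available.

```python
class RecentlyUsedList:
    def __init__(self):
        self._list = []

    def append(self, elem):
        if elem in self._list:
            self._list.remove(elem)
        self._list.append(elem)

    def __str__(self):
        return str(self._list)

    def __getitem__(self, index):
        return self._list[index]


rul = RecentlyUsedList()
rul.append('a')
rul.append('b')
rul.append('a')

print(rul)                     # ['b', 'a']
print(hasattr(rul, 'extend'))  # False
```

Any additional list-like behavior has to be added deliberately, one method at a time.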

Exercise: 🏠 Type Code

# %%writefile shape.py
class Shape:
    def __init__(self, shape):
        assert shape in ('circle', 'dot')
        self.shape = shape

    def draw(self):
        if self.shape == 'circle':
            print('o')
        elif self.shape == 'dot':
            print('.')
circle = Shape('circle')
circle.draw()
o

Explanation

The Shape class uses a type code (shape) to determine which shape to draw, using conditional logic.

Issues

  • Limited Extensibility: Adding new shapes requires modifying the Shape class.
  • Violation of Open/Closed Principle: The class is not closed for modification.
  • Inefficient Design: Uses if statements to handle different behaviors.
%%writefile test_shape.py
import pytest
from contextlib import redirect_stdout
import io
import sys
from shape import Shape

def _test_output(shape, expected_output):
    f = io.StringIO()
    with redirect_stdout(f):
        shape.draw()
    output = f.getvalue()
    assert output == expected_output + '\n'

def test_circle():
    _test_output(shape=Shape('circle'), expected_output='o')

def test_dot():
    _test_output(shape=Shape('dot'), expected_output='.')
Overwriting test_shape.py
! coverage run  -m pytest test_shape.py -q
! coverage report
..                                                                       [100%]
2 passed in 0.09s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py            9      0   100%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              29      2    93%
%%writefile shape.py
class Circle:
    def draw(self):
        print('o')

class Dot:
    def draw(self):
        print('.')


# Shape factory function
shapes = {
    'circle': Circle,
    'dot': Dot,
}

def Shape(shape):
    factory = shapes[shape]
    return factory()
Overwriting shape.py
! coverage run  -m pytest test_shape.py -q
! coverage report
..                                                                       [100%]
2 passed in 0.11s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py           10      0   100%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              30      2    93%
%%writefile shape.py
class Shape:
    def __new__(cls, shape):
        factory = shapes[shape]
        return super().__new__(factory)

class Circle(Shape):
    def draw(self):
        print('o')

class Dot(Shape):
    def draw(self):
        print('.')

# Shapes dictionary
shapes = {
    'circle': Circle,
    'dot': Dot,
}
Overwriting shape.py
! coverage run  -m pytest test_shape.py -q
! coverage report
..                                                                       [100%]
2 passed in 0.12s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py           11      0   100%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              31      2    94%
%%writefile shape.py
class Shape:
    _shapes = {}

    def __new__(cls, shape):
        cls._register_subclasses()
        if shape in cls._shapes:
            factory = cls._shapes[shape]
            return super().__new__(factory)
        raise ValueError(f"Shape '{shape}' is not registered.")

    @classmethod
    def _register_subclasses(cls):
        for subclass in cls.__subclasses__():
            shape_name = subclass.__name__.lower()
            cls._shapes[shape_name] = subclass

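# Bad usage: evaluated at this point, the comprehension would produce an empty dict,
# because no subclasses of Shape exist yet.
# shapes = {subclass.__name__.lower(): subclass for subclass in Shape.__subclasses__()}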

class Circle(Shape):
    def draw(self):
        print('o')

class Dot(Shape):
    def draw(self):
        print('.')
Overwriting shape.py
! coverage run  -m pytest test_shape.py -q
! coverage report
..                                                                       [100%]
2 passed in 0.14s
Name            Stmts   Miss  Cover
-----------------------------------
conftest.py         5      2    60%
shape.py           19      1    95%
test_shape.py      15      0   100%
-----------------------------------
TOTAL              39      3    92%

Explanation

  • Overriding __new__: The Shape class overrides the __new__ method to instantiate the correct subclass.
  • Inheritance Hierarchy: Circle and Dot inherit from Shape, allowing isinstance checks.
  • Unified Interface: All shapes are instances of Shape, maintaining consistency.
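
The sketch below illustrates these points with the registry-based Shape from the last version. Square is a hypothetical new shape added only for this illustration; no change to Shape is needed to support it.

```python
class Shape:
    _shapes = {}

    def __new__(cls, shape):
        cls._register_subclasses()
        if shape in cls._shapes:
            factory = cls._shapes[shape]
            return super().__new__(factory)
        raise ValueError(f"Shape '{shape}' is not registered.")

    @classmethod
    def _register_subclasses(cls):
        # Register every direct subclass under its lowercased class name.
        for subclass in cls.__subclasses__():
            cls._shapes[subclass.__name__.lower()] = subclass


class Circle(Shape):
    def draw(self):
        print('o')


class Square(Shape):  # hypothetical new shape
    def draw(self):
        print('#')


square = Shape('square')
print(isinstance(square, Square))  # True
print(isinstance(square, Shape))   # True

try:
    Shape('triangle')
except ValueError as error:
    print(error)  # Shape 'triangle' is not registered.
```

The ValueError branch is the one statement the coverage report above marks as missed, so a test for an unregistered shape would close that gap.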

Conclusion

Refactoring is an essential practice for improving code quality, readability, and maintainability. By applying refactoring techniques, developers can transform complex or inefficient code into cleaner, more efficient, and more understandable versions.

In this training material, we've covered several key refactoring techniques:

  • Extract Method: Simplifies code by moving repeated or complex code into separate methods.
  • Variable Extraction: Improves readability and efficiency by storing intermediate results in variables.
  • Composition Over Inheritance: Enhances flexibility and reduces unintended behaviors by using composition instead of inheritance.
  • Replacing Type Code with Classes: Utilizes polymorphism to eliminate type codes and conditional logic, making the code more extensible and maintainable.

Understanding and applying these techniques allows developers to write better Python code that is easier to understand, test, and maintain. Refactoring should be a continuous part of the development process, helping teams deliver high-quality software.


Key Takeaways for Trainers

  • Highlight the Importance of Refactoring: Emphasize how refactoring leads to better code quality without changing external behavior.
  • Use Practical Examples: The provided before-and-after code snippets serve as concrete examples of how refactoring improves code.
  • Encourage Best Practices: Stress the importance of readability, maintainability, and adherence to design principles like the Open/Closed Principle.
  • Demonstrate Incremental Changes: Show how small, incremental refactorings can lead to significant improvements.
  • Promote Continuous Refactoring: Encourage developers to make refactoring a regular part of their workflow.

By focusing on these points, trainers can effectively convey the value of refactoring and equip learners with practical skills to improve their codebases.

Test-Driven Development (TDD)

Introduction to Test-Driven Development (TDD)

Test-Driven Development (TDD) is a software development process in which a test is written first, followed by the bare minimum of code required to make that test pass. It emphasizes writing small, incremental tests and code in short cycles.

Why Use TDD?

  • Improves Code Quality: By writing tests first, developers think about the design and requirements before implementation.
  • Simplifies Debugging: When a test fails, it's easy to locate the issue since it likely lies in the code added since the last passing test.
  • Documentation: Tests serve as documentation of the code's intended behavior.
  • Refactoring Support: Since tests are in place, code can be refactored with confidence, ensuring existing functionality remains intact.

TDD Theory

The TDD Cycle

The TDD process can be broken down into the following steps:

  1. Write the Simplest Failing Test: Start by writing a test that fails because the functionality isn't implemented yet.
  2. Write the Simplest Code to Pass the Test: Write just enough code to make the failing test pass.
  3. Refactor Code and Tests: Improve the code's structure and readability without changing its external behavior, ensuring all tests still pass.

Writing the Simplest Failing Test

  • Purpose: To define a new functionality or requirement.
  • Approach:
    • Write a test for a specific behavior you want to implement.
    • Ensure the test fails initially to verify it's testing the right thing.

Writing the Simplest Code to Pass the Test

  • Purpose: To implement just enough code to make the failing test pass.
  • Approach:
    • Focus on minimal implementation.
    • Avoid adding any additional functionality beyond what's necessary for the test to pass.

Refactoring Code and Tests

  • Purpose: To improve the codebase without altering its external behavior.
  • Approach:
    • Clean up code, eliminate duplication, and improve design.
    • Refactor tests if necessary to improve clarity and maintainability.
    • Ensure all tests continue to pass after refactoring.
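
As a minimal sketch of one full cycle, consider a hypothetical is_leap_year function (it is not part of the training material; the name and the rules are chosen only for illustration). The comments mark which phase each piece of code belongs to.

```python
# RED: the test is written first; with no implementation it fails with NameError.
def test_year_divisible_by_4_is_leap():
    assert is_leap_year(2024) is True


# GREEN: the simplest code that passes the test above (nothing more).
def is_leap_year(year):
    return year % 4 == 0


# Next RED: a new test forces the century rule into the code.
def test_century_not_divisible_by_400_is_not_leap():
    assert is_leap_year(1900) is False


# GREEN + REFACTOR: this final version replaces the minimal one above.
# The century rule is added and the conditions get descriptive names.
def is_leap_year(year):
    divisible_by_4 = year % 4 == 0
    century_exception = year % 100 == 0 and year % 400 != 0
    return divisible_by_4 and not century_exception
```

In a real project the test and the implementation would live in separate files, and the second definition would simply be an edit of the first.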

TDD in Practice: The Bowling Game Kata (Step-by-Step Guide)

Overview

The Bowling Game Kata is a programming exercise designed to practice TDD. We'll implement a BowlingGame class that can calculate the score of a bowling game based on rolls provided.

Bowling Scoring Rules

  • A game consists of 10 frames.
  • Each frame can be one of the following:
    • Normal frame: The player has two attempts to knock down 10 pins.
    • Spare: The player knocks down all 10 pins in two attempts. The score for that frame is 10 plus the number of pins knocked down in the next roll.
    • Strike: The player knocks down all 10 pins on the first attempt. The score for that frame is 10 plus the number of pins knocked down in the next two rolls.
  • In the 10th frame, a spare earns one extra roll and a strike earns two extra rolls; these rolls count only toward the 10th frame's bonus.
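
To make the bonus rules concrete, here is a small worked example on a hypothetical roll sequence (a strike, a spare, and a normal frame). It is computed by hand and does not use the BowlingGame class built later.

```python
# Rolls for three frames: a strike, a spare (7 + 3), and a normal frame (9 + 0).
rolls = [10, 7, 3, 9, 0]

# Frame 1 (strike): 10 plus the next two rolls (7 and 3).
frame_1 = 10 + rolls[1] + rolls[2]

# Frame 2 (spare): 10 plus the next roll (9).
frame_2 = 10 + rolls[3]

# Frame 3 (normal frame): just the pins knocked down in its two rolls.
frame_3 = rolls[3] + rolls[4]

total = frame_1 + frame_2 + frame_3
print(frame_1, frame_2, frame_3, total)  # 20 19 9 48
```

Note that the rolls 7, 3 and 9 are each counted twice: once in their own frame and once as a bonus for the frame before.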

Step 1: Setting Up the Project

Before we start, ensure you have the following setup:

  • Python 3.x installed.
  • pytest installed (pip install pytest).

We'll create two files:

  • bowling_game.py: Contains the BowlingGame class.
  • test_bowling_game.py: Contains the tests for the BowlingGame class.

Step 2: Write the First Failing Test

Test: Gutter Game (All Zeros)

Objective: Ensure that when the player rolls all zeros, the total score is zero.

Writing the Test

Create test_bowling_game.py and add the following code:

# test_bowling_game.py
import pytest
from bowling_game import BowlingGame

@pytest.fixture
def game():
    return BowlingGame()

def test_gutter_game(game):

    for _ in range(20):
        game.roll(0)
    assert game.score() == 0

Explanation:

  • We use a fixture game to create a new instance of BowlingGame before each test.
  • In test_gutter_game, we simulate 20 rolls of 0 pins.
  • We assert that the final score should be 0.

Running the Test

Run the test with:

pytest test_bowling_game.py

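Expected Result: The test should fail because BowlingGame is not implemented yet. At this stage pytest reports the failure as a collection error (ImportError), because the test module cannot import BowlingGame.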

!pytest katta/test_bowling_game.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 0 items / 1 error                                                    

==================================== ERRORS ====================================
_________________ ERROR collecting katta/test_bowling_game.py __________________
ImportError while importing test module '/Users/a563420/python_training/testing/katta/test_bowling_game.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.12.6/lib/python3.12/importlib/__init__.py:90: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
katta/test_bowling_game.py:2: in <module>
    from bowling_game import BowlingGame
E   ImportError: cannot import name 'BowlingGame' from 'bowling_game' (/Users/a563420/python_training/testing/katta/bowling_game.py)
=========================== short test summary info ============================
ERROR katta/test_bowling_game.py
!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!
=============================== 1 error in 0.54s ===============================

Step 3: Implement Minimal Code to Pass the Test

Implementing BowlingGame

Create bowling_game.py with the minimal implementation:

class BowlingGame:
    def roll(self, pins):
        pass

    def score(self):
        return 0

Explanation:

  • We define the BowlingGame class with two methods:
    • roll(pins): Records a roll. Currently, it does nothing.
    • score(): Returns the total score. Currently, it always returns 0.

Running the Test Again

Run:

pytest test_bowling_game.py

Expected Result: The test should now pass.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

katta/test_bowling_game.py::test_gitter_gane PASSED

============================== 1 passed in 0.08s ===============================

Step 4: Write the Next Failing Test

Test: All Ones

Objective: Ensure that when the player rolls all ones, the total score is 20.

Writing the Test

Add the following test to test_bowling_game.py:

def test_all_ones(game):
    for _ in range(20):
        game.roll(1)
    assert game.score() == 20

Explanation:

  • We simulate 20 rolls of 1 pin each.
  • We assert that the final score should be 20.

Running the Test

Run:

pytest test_bowling_game.py

Expected Result: The test should fail because score() always returns 0.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones FAILED

=================================== FAILURES ===================================
________________________________ test_all_ones _________________________________

game = <bowling_game.BowlingGame object at 0x10be6e390>

    def test_all_ones(game):
        for _ in range(20):
            game.roll(1)
>       assert game.score() == 20
E       assert 0 == 20
E        +  where 0 = <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x10be6e390>>()
E        +    where <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x10be6e390>> = <bowling_game.BowlingGame object at 0x10be6e390>.score

katta/test_bowling_game.py:17: AssertionError
=========================== short test summary info ============================
FAILED katta/test_bowling_game.py::test_all_ones - assert 0 == 20
========================= 1 failed, 1 passed in 0.44s ==========================

Step 5: Update the Code to Pass the Test

Modifying BowlingGame

Update bowling_game.py:

class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):
        return sum(self.rolls)

Explanation:

  • In __init__, we initialize an empty list rolls to keep track of all the rolls.
  • In roll(pins), we append each roll to the rolls list.
  • In score(), we return the sum of all rolls.

Running the Test Again

Run:

pytest test_bowling_game.py

Expected Result: Both tests (test_gutter_game and test_all_ones) should now pass.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED

============================== 2 passed in 0.09s ===============================

Step 6: Write a Test for Spares

Test: One Spare

Objective: Ensure that a spare is scored correctly.

  • When a spare is rolled, the frame score is 10 plus the number of pins knocked down in the next roll.

Writing the Test

Add the following test to test_bowling_game.py:

def test_one_spare(game):
    game.roll(5)
    game.roll(5)

    for _ in range(18):
        game.roll(1)

    assert game.score() == 29

Explanation:

  • Rolls: 5, 5 (spare), then 18 rolls of 1.
  • The spare bonus is the next roll, which is 1.
  • First frame score: 10 + 1 = 11.
  • Total score: 11 (first frame) + 18 (remaining nine frames) = 29.

Running the Test

Run:

pytest test_bowling_game.py

Expected Result: The test should fail because the current implementation doesn't handle spares.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare FAILED

=================================== FAILURES ===================================
________________________________ test_one_spare ________________________________

game = <bowling_game.BowlingGame object at 0x110ed8b90>

    def test_one_spare(game):
        game.roll(5)
        game.roll(5)
    
        for _ in range(18):
            game.roll(1)
    
>       assert game.score() == 29
E       assert 28 == 29
E        +  where 28 = <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x110ed8b90>>()
E        +    where <bound method BowlingGame.score of <bowling_game.BowlingGame object at 0x110ed8b90>> = <bowling_game.BowlingGame object at 0x110ed8b90>.score

katta/test_bowling_game.py:27: AssertionError
=========================== short test summary info ============================
FAILED katta/test_bowling_game.py::test_one_spare - assert 28 == 29
========================= 1 failed, 2 passed in 0.33s ==========================

Step 7: Update Code to Handle Spares

Modifying BowlingGame

Update bowling_game.py:

class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):

        total = 0
        roll_index = 0

        for frame in range(10):
            if self.is_spare(roll_index):
                total += 10 + self.rolls[roll_index + 2]
                roll_index += 2
            else:
                total += self.rolls[roll_index] + self.rolls[roll_index + 1]
                roll_index += 2

        return total


    def is_spare(self, roll_index):
        return self.rolls[roll_index] + self.rolls[roll_index + 1] == 10

Explanation:

  • We iterate over 10 frames.
  • is_spare(roll_index): Checks if the sum of two rolls is 10.
  • If it's a spare, we add the spare bonus (next roll).
  • Otherwise, we sum the two rolls.
  • We advance roll_index by 2 in both cases, since a spare and a normal frame each use two rolls.

Running the Test Again

Run:

pytest test_bowling_game.py

Expected Result: The spare test should pass, along with the previous tests.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED

============================== 3 passed in 0.09s ===============================

Step 8: Write a Test for Strikes

Test: One Strike

Objective: Ensure that a strike is scored correctly.

  • When a strike is rolled, the frame score is 10 plus the number of pins knocked down in the next two rolls.

Writing the Test

Add the following test to test_bowling_game.py:

def test_one_strike(game):

    game.roll(10)
    game.roll(3)
    game.roll(2)

    for _ in range(16):
        game.roll(1)

    assert game.score() == 36

Explanation:

  • Rolls: 10 (strike), then 3 and 2, then 16 rolls of 1.
  • The strike bonus is the next two rolls: 3 + 2.
  • First frame score: 10 + 3 + 2 = 15.
  • Second frame: 3 + 2 = 5.
  • Total score: 15 + 5 + 16 (remaining eight frames) = 36.

Running the Test

Run:

pytest test_bowling_game.py

Expected Result: The test should fail because the current implementation doesn't handle strikes.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike FAILED

=================================== FAILURES ===================================
_______________________________ test_one_strike ________________________________

game = <bowling_game.BowlingGame object at 0x103facbc0>

    def test_one_strike(game):
    
        game.roll(10)
        game.roll(3)
        game.roll(2)
    
        for _ in range(16):
            game.roll(1)
    
>       assert game.score() == 36

katta/test_bowling_game.py:38: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
katta/bowling_game.py:15: in score
    if self.is_spare(roll_index):
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <bowling_game.BowlingGame object at 0x103facbc0>, roll_index = 18

    def is_spare(self, roll_index):
>       return self.rolls[roll_index] + self.rolls[roll_index + 1] == 10
E       IndexError: list index out of range

katta/bowling_game.py:26: IndexError
=========================== short test summary info ============================
FAILED katta/test_bowling_game.py::test_one_strike - IndexError: list index out of range
========================= 1 failed, 3 passed in 0.38s ==========================
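
The IndexError comes from the frame loop always advancing two rolls per frame. The sketch below replays the index arithmetic for the rolls used in test_one_strike; it only illustrates the failure and is not part of bowling_game.py.

```python
# The rolls from test_one_strike: one strike, then 3 and 2, then sixteen 1s.
rolls = [10, 3, 2] + [1] * 16
print(len(rolls))  # 19: a strike uses a single roll, so there are fewer than 20

# The spare-only loop assumes two rolls per frame, so frame n starts at index 2 * n.
frame_starts = [2 * frame for frame in range(10)]
last_start = frame_starts[-1]  # 18

# is_spare(18) reads rolls[18] and rolls[19]; index 19 does not exist.
try:
    rolls[last_start] + rolls[last_start + 1]
except IndexError as error:
    print(f"IndexError: {error}")
```

This matches the traceback above, where is_spare is called with roll_index = 18.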

Step 9: Update Code to Handle Strikes

Modifying BowlingGame

Update bowling_game.py:

class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)


    def score(self):

        total = 0
        roll_index = 0
        for frame in range(10):
            if self.is_strike(roll_index):
                total += 10 + self.rolls[roll_index + 1] + self.rolls[roll_index + 2]
                roll_index += 1

            elif self.is_spare(roll_index):
                total += 10 + self.rolls[roll_index + 2]
                roll_index += 2
            else:
                total += self.rolls[roll_index] + self.rolls[roll_index + 1]
                roll_index += 2
        return total


    def is_spare(self, roll_index):
        return self.rolls[roll_index] + self.rolls[roll_index + 1] == 10


    def is_strike(self, roll_index):
        return self.rolls[roll_index] == 10

Explanation:

  • Added is_strike(roll_index) to check if the roll is a strike.
  • If it's a strike, we calculate the strike bonus (next two rolls) and increment roll_index by 1.
  • The rest remains the same.
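
The effect of advancing roll_index by 1 after a strike and by 2 otherwise can be seen by listing where each frame starts. frame_starts below is a hypothetical helper written only for this illustration; it mirrors the index bookkeeping in score() but is not part of the BowlingGame class.

```python
def frame_starts(rolls):
    """Return the index of the first roll of each of the 10 frames."""
    starts = []
    roll_index = 0
    for _ in range(10):
        starts.append(roll_index)
        # A strike finishes the frame in one roll; any other frame takes two.
        roll_index += 1 if rolls[roll_index] == 10 else 2
    return starts


# Rolls from test_one_strike: frame 2 starts right after the strike, at index 1.
print(frame_starts([10, 3, 2] + [1] * 16))
# [0, 1, 3, 5, 7, 9, 11, 13, 15, 17]

# Perfect game: every frame is a single roll.
print(frame_starts([10] * 12))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```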

Running the Test Again

Run:

pytest test_bowling_game.py

Expected Result: The strike test should pass, along with the previous tests.

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED

============================== 4 passed in 0.16s ===============================

Step 10: Refactor the Code

At this point, our code handles basic scoring, spares, and strikes. We can refactor to improve readability.

Refactoring BowlingGame

Update bowling_game.py:

class BowlingGame:
    def __init__(self):
        self.rolls = []

    def roll(self, pins):
        self.rolls.append(pins)

    def score(self):
        total = 0
        roll_index = 0
        for frame in range(10):
            if self.is_strike(roll_index):
                total += self.strike_score(roll_index)
                roll_index += 1
            elif self.is_spare(roll_index):
                total += self.spare_score(roll_index)
                roll_index += 2
            else:
                total += self.frame_score(roll_index)
                roll_index += 2
        return total

    def is_spare(self, roll_index):
        return self.frame_score(roll_index) == 10

    def is_strike(self, roll_index):
        return self.rolls[roll_index] == 10

    def frame_score(self, roll_index):
        return self.rolls[roll_index] + self.rolls[roll_index + 1]

    def spare_score(self, roll_index):
        return 10 + self.rolls[roll_index + 2]

    def strike_score(self, roll_index):
        return 10 + self.rolls[roll_index + 1] + self.rolls[roll_index + 2]

Explanation:

  • Extracted methods frame_score, spare_score, and strike_score for clarity.
  • This makes the score() method cleaner and easier to understand.

Run the tests again to confirm that the refactoring did not change behavior:

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 4 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED

============================== 4 passed in 0.10s ===============================

Step 11: Write a Test for a Perfect Game

Test: Perfect Game

Objective: Ensure that a perfect game (12 strikes) scores 300.

Writing the Test

Add the following test to test_bowling_game.py:

def test_perfect_game(game):
    for _ in range(12):
        game.roll(10)
    assert game.score() == 300

Explanation:

  • In a perfect game, the player rolls 12 strikes (the extra two strikes are for the 10th frame bonus).
  • The total score should be 300.
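
The 300 can be checked by hand: each of the 10 frames starts with a strike, and its bonus is the next two rolls, which are also strikes. A short sketch of that arithmetic, independent of the BowlingGame class:

```python
rolls = [10] * 12  # 10 frame strikes plus 2 bonus rolls for the 10th frame

# In a perfect game frame n starts at index n, and its score is
# the strike itself plus the two rolls that follow it.
frame_scores = [rolls[i] + rolls[i + 1] + rolls[i + 2] for i in range(10)]

print(frame_scores)       # ten frames of 30 each
print(sum(frame_scores))  # 300
```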

Running the Test

Run:

pytest test_bowling_game.py

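Expected Result: The test should pass without any code changes, because the current implementation already handles the 10th frame: the loop stops after 10 frames and reads the strike bonus from the two extra rolls.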

!pytest katta/test_bowling_game.py -sv
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 5 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED
katta/test_bowling_game.py::test_perfect_game PASSED

============================== 5 passed in 0.10s ===============================

Step 12: Additional Tests and Edge Cases

Test: All Spares

Objective: Ensure that a game of all spares scores correctly.

Writing the Test

Add to test_bowling_game.py:

def test_all_spares(game):
    for _ in range(21):
        game.roll(5)
    assert game.score() == 150

Explanation:

  • 21 rolls of 5 pins each (the extra roll is for the 10th frame spare).
  • Each frame score: 10 + next roll (5).
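  • Total score: 15 per frame over 10 frames: 15 * 10 = 150.

Expected Result: The test should pass without code changes, because the spare bonus for the 10th frame is read from the 21st roll.

!pytest katta/test_bowling_game.py -sv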
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 6 items                                                              

katta/test_bowling_game.py::test_gitter_gane PASSED
katta/test_bowling_game.py::test_all_ones PASSED
katta/test_bowling_game.py::test_one_spare PASSED
katta/test_bowling_game.py::test_one_strike PASSED
katta/test_bowling_game.py::test_perfect_game PASSED
katta/test_bowling_game.py::test_all_spares PASSED

============================== 6 passed in 0.12s ===============================

The complete test_bowling_game.py at this point:

import pytest
from bowling_game import BowlingGame

@pytest.fixture
def game():
    return BowlingGame()

def test_gutter_game(game):

    for _ in range(20):
        game.roll(0)
    assert game.score() == 0

def test_all_ones(game):
    for _ in range(20):
        game.roll(1)
    assert game.score() == 20


def test_one_spare(game):
    game.roll(5)
    game.roll(5)

    for _ in range(18):
        game.roll(1)

    assert game.score() == 29

def test_one_strike(game):

    game.roll(10)
    game.roll(3)
    game.roll(2)

    for _ in range(16):
        game.roll(1)

    assert game.score() == 36

def test_perfect_game(game):
    for _ in range(12):
        game.roll(10)
    assert game.score() == 300


def test_all_spares(game):
    for _ in range(21):
        game.roll(5)
    assert game.score() == 150

🏠 Exercises

  1. Handle Invalid Inputs:

    • Modify the roll method to handle invalid inputs (e.g., negative pins, pins greater than 10).
    • Write tests to ensure that invalid inputs raise appropriate exceptions.
  2. Write a parametrized test that covers different paths of the game.
  3. Random Game Simulation:

    • Write a function to simulate a random bowling game.
    • Ensure that the total score calculated matches the expected score.
  4. Handle Incomplete Game:

    • Modify the code to handle cases where roll_index may go out of bounds:
    def test_incomplete_game_with_spare_and_strike(game):
      # Frame 1: Strike (10 points)
      game.roll(10)
      # Frame 2: Spare (5 + 5 = 10 points)
      game.roll(5)
      game.roll(5)
      # Frame 3: Partial frame (only one roll)
      game.roll(3)
      # No further rolls provided, leaving the game incomplete

      # Expected score:
      # Frame 1: Strike = 10 + 5 + 5 = 20
      # Frame 2: Spare = 10 + 3 = 13
      # Frame 3: 3 points
      # Total = 20 + 13 + 3 = 36
      assert game.score() == 36

  5. Implement a Command-Line Interface:

    • Allow users to input rolls and display the score after each frame.
    • Use the BowlingGame class as the backend.

Discussion

Pros:

  • Tests serve as documentation, are readable, and independent of each other (orthogonal).
  • Design focuses on interfaces rather than implementation, leading to more convenient interfaces.
  • It’s easy to revert to the last working version because tests are consistently maintained, and progress is made in small steps.
  • TDD is a methodology that is language-agnostic and tool-agnostic.

Cons:

  • It takes more time when starting from scratch, although the difference becomes negligible in large projects. As a result, TDD is not particularly useful for prototypes or short-lived projects (where testing in general may not pay off).
  • You cannot follow TDD rules 100% of the time; they need to be filtered through personal experience.
  • TDD doesn’t answer all questions: What about E2E tests? What about multi-layered architecture? (Hint: BDD).

The FIRST Principles of Testing

In this chapter, we delve into the FIRST principles of testing—Fast, Independent, Repeatable, Self-validating, and Timely/Thorough. Understanding and applying these principles will help you write effective and efficient tests for your Python applications. We'll explore each principle in detail, providing practical examples and best practices to enhance your testing methodologies.

Introduction to FIRST Principles

The FIRST principles serve as guidelines to create high-quality tests that are maintainable and reliable. They ensure that your test suite is:

  • Fast: Encourages frequent execution of tests without hesitation.
  • Independent: Ensures tests do not rely on each other's state or the environment.
  • Repeatable: Guarantees consistent results every time the tests are run.
  • Self-validating: Eliminates the need for manual result inspection.
  • Timely/Thorough: Promotes writing tests at the right time and covering all necessary scenarios.

Fast

Importance of Speed in Testing

Tests should execute swiftly to encourage developers to run them frequently. If tests are slow, developers may hesitate to run them often, leading to less frequent integration and delayed detection of issues.

Best Practices for Fast Tests

  • Optimize Setup and Teardown: Minimize the time spent in setting up and tearing down tests.
  • Mock External Dependencies: Use mocking to simulate external systems or services.
  • Run Tests in Memory: Avoid disk I/O operations if possible.
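
As an illustration of the "run tests in memory" advice, the sketch below tests a hypothetical count_non_empty_lines function against an io.StringIO object instead of a file on disk. The function name and its behavior are assumptions made for this example.

```python
import io


def count_non_empty_lines(stream):
    """Count lines that contain something other than whitespace."""
    return sum(1 for line in stream if line.strip())


def test_count_non_empty_lines_in_memory():
    # Arrange: StringIO behaves like an open text file but lives in memory,
    # so the test needs no disk I/O and no cleanup.
    stream = io.StringIO("first\n\n  \nsecond\n")

    # Act
    result = count_non_empty_lines(stream)

    # Assert
    assert result == 2
```

Because the function accepts any iterable of lines, production code can still pass it a real file object.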

Python Example

import unittest
from unittest.mock import MagicMock

class FastTest(unittest.TestCase):
    def test_fast_execution(self):
        # Arrange
        external_service = MagicMock()
        external_service.get_data.return_value = {'key': 'value'}

        # Act
        result = external_service.get_data()

        # Assert
        self.assertEqual(result, {'key': 'value'})

Independent

The 3 As: Arrange, Act, Assert

  • Arrange: Set up the data and environment for the test.
  • Act: Execute the functionality under test.
  • Assert: Verify the outcome.

Ensuring Independence

  • Isolate Test Data: Each test should create its own data.
  • No Order Dependency: Tests should pass or fail regardless of the execution order.
  • Avoid Shared State: Do not rely on data modified by other tests.

Best Practices

  • Use fixtures or setup methods to initialize data.
  • Clean up any changes made during the test in the teardown phase.
  • Use mocking to isolate external dependencies.
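
A minimal sketch of the "avoid shared state" rule, using a hypothetical make_cart helper (in pytest this role is usually played by a fixture). Each test builds its own list, so the tests pass in any order.

```python
def make_cart():
    """Return a fresh cart for every caller, so no state leaks between tests."""
    return []


def test_adding_item_grows_cart():
    cart = make_cart()        # Arrange: private data for this test
    cart.append("apple")      # Act
    assert cart == ["apple"]  # Assert


def test_new_cart_is_empty():
    cart = make_cart()
    # Passes whether it runs before or after the test above,
    # because that test modified its own cart, not a shared one.
    assert cart == []
```

Had both tests used a single module-level list, test_new_cart_is_empty would pass or fail depending on execution order.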

Python Example

import pytest
def process_data(data):
    # Sample function to double the input value
    return data['input'] * 2

def test_independent_behavior():
    # Arrange
    data = {'input': 10}

    # Act
    result = process_data(data)

    # Assert
    assert result == 20

Repeatable

Deterministic Results

Tests should produce the same results every time they are run, regardless of the environment or timing.

Avoiding Environmental Dependencies

  • Control Randomness: Seed random number generators.
  • Mock Time-dependent Functions: Replace actual time functions with fixed values.

Best Practices

  • Use Mocking: Replace non-deterministic functions with mocks.
  • Data Helpers: Use helper functions or classes to set up consistent data.

Python Examples

Controlling Randomness:

import random

def test_random_number():
    random.seed(0)
    assert random.random() == 0.8444218515250481

The random module in Python uses a pseudo-random number generator (PRNG) algorithm. When you set the seed using random.seed(0), it initializes the PRNG to a specific state. This ensures that the sequence of random numbers generated is reproducible.

The number 0.8444218515250481 is the first number in the sequence generated by the PRNG when the seed is set to 0. The PRNG algorithm used by Python's random module is deterministic, meaning that given the same initial seed, it will always produce the same sequence of numbers.
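
random.seed changes the state of the module-level generator, which is shared by all code in the process. One alternative, sketched here under the assumption that the code under test can accept a generator as a parameter, is to pass a dedicated random.Random instance. roll_dice is a hypothetical function used only for illustration.

```python
import random


def roll_dice(rng, count):
    """Roll `count` six-sided dice using the supplied random generator."""
    return [rng.randint(1, 6) for _ in range(count)]


def test_same_seed_gives_same_rolls():
    # Two generators seeded identically produce identical sequences,
    # without touching the global state used by random.random().
    first = roll_dice(random.Random(42), 5)
    second = roll_dice(random.Random(42), 5)

    assert first == second
    assert all(1 <= value <= 6 for value in first)
```

The test asserts properties of the result (repeatability and range) rather than specific numbers, so it does not depend on the exact sequence a given seed produces.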

Mocking Date and Time:

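import datetime
from unittest.mock import patch

def get_current_time():
    return datetime.datetime.now()

def test_time_dependent_function():
    # Build the expected value first, while datetime.datetime is still the real class
    fixed_time = datetime.datetime(2020, 1, 1)
    # Patch only for the duration of the call, so now() returns the fixed value
    with patch('datetime.datetime') as mock_datetime:
        mock_datetime.now.return_value = fixed_time
        assert get_current_time() == fixed_time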
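
Patching is not the only option. If the function can take the clock as an argument, the test can pass a fixed one directly. The sketch below assumes a hypothetical greeting function with a now parameter; neither is part of the material above.

```python
import datetime


def greeting(now=datetime.datetime.now):
    """Return a greeting based on the hour reported by the `now` callable."""
    return "Good morning" if now().hour < 12 else "Good afternoon"


def test_greeting_in_the_morning():
    def fixed_clock():
        return datetime.datetime(2020, 1, 1, 9, 0)

    assert greeting(now=fixed_clock) == "Good morning"


def test_greeting_in_the_afternoon():
    def fixed_clock():
        return datetime.datetime(2020, 1, 1, 15, 0)

    assert greeting(now=fixed_clock) == "Good afternoon"
```

Production code calls greeting() with no arguments and gets the real clock; only the tests supply a replacement.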

Self-Validating

Automated Validation

Tests should automatically determine whether they pass or fail without requiring manual inspection.

Best Practices

  • Use Assertions: Rely on assertions to validate outcomes.
  • Avoid Print Statements: Do not use print statements for validation.
  • Consistent Assertion Messages: Provide clear messages for assertion failures.

Python Example

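def perform_calculation(a, b):  # sample function under test (assumed for this example)
    return a + b

def test_self_validating():
    result = perform_calculation(5, 5)
    assert result == 10, "The calculation result should be 10."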

Timely/Thorough

Writing Tests at the Right Time

  • Test-Driven Development (TDD): Write tests before the actual code.
  • Behavior-Driven Development (BDD): Focus on the behavior and write tests accordingly.

Ensuring Thoroughness

  • Cover Edge Cases: Test boundary conditions and corner cases.
  • Test with Large Data Sets: Assess performance and scalability.
  • Security Testing: Validate behavior with different user roles and permissions.
  • Error Handling: Test exceptions, errors, and invalid inputs.

Best Practices

  • Aim Beyond 100% Coverage: Focus on meaningful test cases rather than just coverage metrics.
  • Use Parameterized Tests: Test multiple inputs with the same test logic.
  • Include Negative Tests: Ensure the system behaves correctly with invalid inputs.
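
A sketch of a negative test, written with unittest so that it runs on the standard library alone. parse_age is a hypothetical function invented for this example, and its valid range of 0 to 150 is an arbitrary assumption.

```python
import unittest


def parse_age(text):
    """Convert text to an age; reject non-numeric and out-of-range values."""
    age = int(text)  # raises ValueError for non-numeric text
    if not 0 <= age <= 150:
        raise ValueError(f"age out of range: {age}")
    return age


class ParseAgeNegativeTests(unittest.TestCase):
    def test_rejects_invalid_input(self):
        # subTest reports each failing input separately, similar in spirit
        # to a parametrized test.
        for bad_input in ["abc", "", "-1", "151"]:
            with self.subTest(bad_input=bad_input):
                with self.assertRaises(ValueError):
                    parse_age(bad_input)

    def test_accepts_boundaries(self):
        # Boundary values sit next to the invalid ones and must still work.
        self.assertEqual(parse_age("0"), 0)
        self.assertEqual(parse_age("150"), 150)
```

The same checks can be expressed in pytest with pytest.raises and pytest.mark.parametrize.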

Python Examples

Testing Edge Cases:

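def divide(a, b):  # sample function under test (assumed): returns None instead of raising
    return a / b if b != 0 else None

def test_edge_case_zero():
    result = divide(10, 0)
    assert result is None, "Division by zero should return None."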

Parameterized Tests with pytest:

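import pytest
from mathutils import factorial  # the SUT from the pytest Fundamentals section (module name assumed)

@pytest.mark.parametrize("number,expected", [
    (0, 1),
    (1, 1),
    (5, 120),
])
def test_factorial(number, expected):
    assert factorial(number) == expected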

Summary

Applying the FIRST principles ensures that your tests are efficient, reliable, and maintainable. By writing tests that are fast, independent, repeatable, self-validating, and timely/thorough, you improve the quality of your software and the confidence in your codebase.


Further Reading

  • Test-Driven Development by Example by Kent Beck
  • Python Testing with pytest by Brian Okken
  • Clean Code by Robert C. Martin (Uncle Bob)

Mocking in Python Testing

In this chapter, we explore the concept of mocking in Python testing. Mocking allows you to isolate the system under test (SUT) by replacing parts of the system that are external or not easily testable with mock objects. This is especially useful when the SUT interacts with external resources like databases, file systems, or network services.

Introduction to Mocking

What is Mocking?

Mocking is a technique used in unit testing where you replace real objects with mock objects that simulate the behavior of the real ones. This helps in isolating the code under test and controlling the environment in which the tests run.

Why Use Mocking?

  • Isolation: Test components in isolation without dependencies.
  • Control: Simulate specific scenarios and edge cases.
  • Performance: Avoid slow operations like network calls or file I/O.
  • Reliability: Eliminate flakiness due to external factors.
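
The sketch below illustrates these points with unittest.mock.Mock. fetch_username and the client object with its get_user method are hypothetical names invented for this example; only the Mock API itself comes from the standard library.

```python
from unittest.mock import Mock


def fetch_username(client, user_id):
    """Return the user's name, or 'unknown' if the service is unreachable."""
    try:
        return client.get_user(user_id)["name"]
    except ConnectionError:
        return "unknown"


def test_returns_name_from_service():
    client = Mock()
    client.get_user.return_value = {"name": "Ada"}  # control: canned response

    assert fetch_username(client, 7) == "Ada"
    client.get_user.assert_called_once_with(7)  # isolation: verify the call


def test_falls_back_when_service_is_down():
    client = Mock()
    client.get_user.side_effect = ConnectionError  # simulate a rare failure

    assert fetch_username(client, 7) == "unknown"
```

Neither test opens a network connection, so both run quickly and give the same result on every machine.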

System Under Test (SUT)

Let's consider a simple module as our system under test.

The my_remove_module.py Module

%%writefile my_remove_module.py
# my_remove_module.py
import os

DEFAULT_EXTENSION = '.txt'

def my_remove(filename):
    if '.' not in filename:
        filename += DEFAULT_EXTENSION
    os.remove(filename)
Overwriting my_remove_module.py

Explanation:

  • Imports: The module imports the os module to interact with the operating system.
  • Constants: DEFAULT_EXTENSION is set to '.txt'.
  • Function: my_remove takes a filename and removes it from the file system.
    • If the filename does not contain an extension (no . character), it appends the default extension.

Testing Without Mocking

Initially, we'll write tests without using mocking to understand the limitations.

Test Cases Without Mocking

%%writefile test_my_remove_module.py
# test_my_remove_module.py
import os
from my_remove_module import my_remove

def test_provided_extension_should_be_used():
    filename = 'file.md'
    # Create the file
    open(filename, 'w').close()
    assert os.path.isfile(filename)
    # Call the function
    my_remove(filename)
    assert not os.path.isfile(filename)

def test_when_extension_is_missing_then_use_default_one():
    filename = 'file.txt'
    # Create the file
    open(filename, 'w').close()
    assert os.path.isfile(filename)
    # Call the function with filename without extension
    my_remove('file')
    assert not os.path.isfile(filename)
Overwriting test_my_remove_module.py
!pytest test_my_remove_module.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

test_my_remove_module.py::test_provided_extension_should_be_used PASSED  [ 50%]
test_my_remove_module.py::test_when_extension_is_missing_then_use_default_one PASSED [100%]

============================== 2 passed in 0.12s ===============================

Limitations Without Mocking:

  • Dependency on File System: The tests interact with the actual file system.
  • Side Effects: Creates and deletes real files, which can be risky.
  • Performance: File I/O operations can slow down the tests.
  • Environment Sensitivity: Tests may fail if the environment doesn't permit file operations.

Testing With Mocking

To overcome the limitations, we'll use mocking to simulate the file system operations.

Using unittest.mock

Python's unittest.mock library provides tools for mocking objects in tests.

Refactoring Tests with Mocking

%%writefile test_my_remove_module_mock.py
# test_my_remove_module_mock.py
from unittest import mock
from my_remove_module import my_remove

@mock.patch('my_remove_module.os')
def test_provided_extension_should_be_used(mock_os):
    filename = 'file.md'
    my_remove(filename)
    mock_os.remove.assert_called_once_with(filename)

@mock.patch('my_remove_module.os')
def test_when_extension_is_missing_then_use_default_one(mock_os):
    my_remove('file')
    mock_os.remove.assert_called_once_with('file.txt')
Overwriting test_my_remove_module_mock.py
!pytest test_my_remove_module_mock.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 2 items                                                              

test_my_remove_module_mock.py::test_provided_extension_should_be_used PASSED [ 50%]
test_my_remove_module_mock.py::test_when_extension_is_missing_then_use_default_one PASSED [100%]

============================== 2 passed in 0.19s ===============================

Explanation:

  • Decorators: Use @mock.patch('my_remove_module.os') to replace the os module in my_remove_module with a mock.
  • Mock Objects:
    • mock_os is the mocked version of the os module.
  • Assertions:
    • assert_called_once_with verifies that os.remove was called exactly once with the specified argument.

Understanding mock.patch

How mock.patch Works

  • Target: The string 'my_remove_module.os' specifies the exact location to patch.
    • It means, "In the module my_remove_module, replace os with a mock."
  • Replacement: The original os module is temporarily replaced with a mock during the test.

Common Pitfall

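It's crucial to patch a name in the module where it is looked up, not where it is defined. my_remove_module holds its own reference named os, so replacing os in any other namespace (for example in the test module) has no effect on my_remove_module.os. The same rule matters even more with from-imports: if a module does from os import remove, patching os.remove leaves that module's own remove untouched.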
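
The pitfall is easiest to see with a from-import. The sketch below builds a small throwaway module in memory (demo_lookup and has_file are hypothetical names used only for this illustration) and patches exists in two different places:

```python
import sys
import types
from unittest import mock

# Hypothetical module created in memory so the sketch is self-contained.
source = """
from os.path import exists

def has_file(path):
    return exists(path)
"""
demo = types.ModuleType("demo_lookup")
exec(source, demo.__dict__)
sys.modules["demo_lookup"] = demo

missing = "/no/such/directory/demo.txt"

# Patching the name where it is defined does not reach demo_lookup.exists.
with mock.patch("os.path.exists", return_value=True):
    wrong_target = demo.has_file(missing)

# Patching the name where it is looked up does.
with mock.patch("demo_lookup.exists", return_value=True):
    right_target = demo.has_file(missing)

del sys.modules["demo_lookup"]
print(wrong_target, right_target)
```

Only the second patch changes what has_file sees, because has_file resolves exists in its own module's namespace.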

Manual Patching

Sometimes, you might need to patch objects manually without decorators.

Manual Patching Example

%%writefile test_my_remove_module_manual.py
# test_my_remove_module_manual.py
from unittest import mock
import my_remove_module
from my_remove_module import my_remove

def test_when_extension_is_missing_then_use_default_one():
    # Save the real os module
    real_os = my_remove_module.os
    # Replace os with a mock
    my_remove_module.os = mock.MagicMock()
    try:
        # The filename has no extension, so the default one should be appended
        my_remove('file')
        # Assert os.remove was called with the default extension
        my_remove_module.os.remove.assert_called_once_with('file.txt')
    finally:
        # Restore the real os module
        my_remove_module.os = real_os
Overwriting test_my_remove_module_manual.py
!pytest test_my_remove_module_manual.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

test_my_remove_module_manual.py::test_when_extension_is_missing_then_use_default_one PASSED [100%]

============================== 1 passed in 0.18s ===============================

Explanation:

  • Manual Replacement:
    • Save the real os module.
    • Replace my_remove_module.os with a mock object.
  • Try-Finally Block:
    • Ensure the real os module is restored after the test, even if an exception occurs.
  • Assertion:
    • Check that os.remove was called with the expected filename.

When to Use Manual Patching

  • When you need more control over the patching process.
  • When decorators or context managers are not suitable.
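
A middle ground between decorators and fully manual replacement is the patcher object returned by mock.patch, which has start() and stop() methods. The following is a minimal sketch that patches os.remove directly; the filename is arbitrary:

```python
import os
from unittest import mock

# Create the patcher without activating it yet.
patcher = mock.patch("os.remove")
mock_remove = patcher.start()   # os.remove is now a MagicMock
try:
    os.remove("file.txt")       # recorded by the mock, nothing is deleted
    mock_remove.assert_called_once_with("file.txt")
finally:
    patcher.stop()              # always restore the real function

restored = os.remove is not mock_remove
print(restored)
```

Unlike the hand-written version, the patcher remembers the original object itself, so there is no separate variable to save and restore.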

Mocking Functions

Essentials

The unittest.mock library provides the MagicMock class, which is a powerful and flexible mock object that can mimic any Python object.

Creating a MagicMock instance:

from unittest import mock

m = mock.MagicMock()
print(m)
print(m())
print(m().revert(0).delete())
<MagicMock id='4398390144'>
<MagicMock name='mock()' id='4397114688'>
<MagicMock name='mock().revert().delete()' id='4393699280'>
  • m: A MagicMock instance.
  • m(): Calling the mock as if it were a function returns another MagicMock instance.

Setting Return Values:

You can specify the return value of a mock function using the return_value attribute.

m = mock.MagicMock()
m.return_value = 42

result = m(84, foo=3)

assert result == 42
print(result)
42

Asserting Calls:

You can assert that a mock was called with specific arguments.

m.assert_called_once_with(84, foo=3)

If the mock was not called with the specified arguments exactly once, an AssertionError is raised.

Checking Call Arguments:

You can inspect how a mock was called using call_args.

print(m.call_args)
call(84, foo=3)
assert m.call_args == mock.call(84, foo=3)
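
On Python 3.8 and later, the positional and keyword parts of call_args are also available separately as args and kwargs. A short sketch, reusing the values from above:

```python
from unittest import mock

m = mock.MagicMock(return_value=42)
m(84, foo=3)

positional = m.call_args.args    # tuple of positional arguments
keyword = m.call_args.kwargs     # dict of keyword arguments

assert positional == (84,)
assert keyword == {"foo": 3}
```

This is handy when only one argument matters and the rest of the call should be ignored.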

Example: Mocking a Function in a Module

Let's consider a module power_reset.py that performs a POST request to reset power.

power_reset.py:

%%writefile power_reset.py
# power_reset.py
import requests

def power_reset():
    ret = requests.post('https://127.0.0.1:8000/power_reset', json={})
    print(ret)
    if ret != 42:
        raise Exception("Power reset failed")
Overwriting power_reset.py

Explanation:

  • The power_reset function sends a POST request to a local URL.
  • It prints the response and raises an exception if the response is not 42.

Writing a Test with Mocking:

We can mock the requests module to simulate the POST request without actually performing it.

test_power_reset.py:

%%writefile test_power_reset.py
# test_power_reset.py
from unittest import mock
from power_reset import power_reset

@mock.patch('power_reset.requests')
def test_power_reset(mock_requests):
    # Arrange
    mock_requests.post.return_value = 42  # Mock the return value of requests.post

    # Act
    power_reset()

    # Assert
    mock_requests.post.assert_called_once_with(
        'https://127.0.0.1:8000/power_reset', json={}
    )
Overwriting test_power_reset.py

Explanation:

  • We use @mock.patch('power_reset.requests') to replace the requests module in power_reset with a mock.
  • We set the return_value of mock_requests.post to 42.
  • We call power_reset() and assert that requests.post was called once with the specified URL and JSON data.

Running the Test:

!pytest test_power_reset.py -vs
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

test_power_reset.py::test_power_reset 42
PASSED

============================== 1 passed in 0.61s ===============================
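
The failure path can be exercised the same way. The sketch below is an assumption-laden variant: to stay self-contained it uses a hypothetical power_reset(http) that receives its HTTP client as a parameter instead of importing requests, and 500 is just an arbitrary non-42 return value:

```python
from unittest import mock

def power_reset(http):
    # Hypothetical variant of power_reset: the HTTP client is passed in.
    ret = http.post('https://127.0.0.1:8000/power_reset', json={})
    if ret != 42:
        raise Exception("Power reset failed")

http_mock = mock.MagicMock()
http_mock.post.return_value = 500   # anything other than 42 means failure

try:
    power_reset(http_mock)
    failed = False
    message = None
except Exception as exc:
    failed = True
    message = str(exc)

http_mock.post.assert_called_once_with(
    'https://127.0.0.1:8000/power_reset', json={}
)
print(failed, message)
```

In a pytest test the try/except would normally be replaced with pytest.raises.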

Raising Exceptions with Mocks

You can configure a mock to raise an exception when called, using the side_effect attribute.

Example:

m = mock.MagicMock()
m.side_effect = ZeroDivisionError

try:
    m()
except ZeroDivisionError:
    print("Caught ZeroDivisionError as expected.")
Caught ZeroDivisionError as expected.

Explanation:

  • Setting m.side_effect = ZeroDivisionError causes the mock to raise ZeroDivisionError when called.
  • This is useful for testing how your code handles exceptions from dependencies.
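
As a minimal sketch of that use case, consider a hypothetical read_config function (not part of the course modules) that falls back to an empty dictionary when its loader fails:

```python
from unittest import mock

def read_config(loader):
    # Hypothetical function under test: falls back to an empty config.
    try:
        return loader()
    except FileNotFoundError:
        return {}

# The dependency is replaced by a mock that always raises.
loader = mock.MagicMock(side_effect=FileNotFoundError("config is missing"))

result = read_config(loader)

loader.assert_called_once_with()
print(result)
```

The test covers the error-handling branch without having to create the error condition for real.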

Returning Multiple Values

You can use side_effect to specify a list of return values or exceptions for successive calls.

Example:

m = mock.MagicMock()
m.side_effect = [1, 2, KeyError, 3]

print(m())  # Outputs: 1
print(m())  # Outputs: 2

try:
    m()      # Raises KeyError
except KeyError:
    print("Caught KeyError as expected.")

print(m())  # Outputs: 3
1
2
Caught KeyError as expected.
3

Explanation:

  • Each call to m() returns the next value in the side_effect list.
  • If an exception is encountered in the list, it is raised when the mock is called.
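
A list of side effects is a natural fit for testing retry logic. The call_with_retry helper below is hypothetical and exists only to illustrate the idea:

```python
from unittest import mock

def call_with_retry(func, attempts=3):
    # Hypothetical helper: retry on ConnectionError, re-raise on the last attempt.
    for attempt in range(attempts):
        try:
            return func()
        except ConnectionError:
            if attempt == attempts - 1:
                raise

# Two failures followed by a success.
flaky = mock.MagicMock(side_effect=[ConnectionError, ConnectionError, "ok"])

result = call_with_retry(flaky)

print(result, flaky.call_count)
```

The mock fails twice and then succeeds, so the helper must have called it three times.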

Mocking with Functions and Lambdas

You can set side_effect to a function or lambda to dynamically determine the return value based on the input arguments.

Using a Lambda Function:

m = mock.MagicMock()
m.side_effect = lambda x: x + 2

print(m(5))  # Outputs: 7
7

Using a Defined Function:

def add_two(x):
    return x + 2

m = mock.MagicMock()
m.side_effect = add_two

print(m(5))  # Outputs: 7
7
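
The mock forwards its call arguments to the side_effect function unchanged, so a call that does not match the function's signature fails just as it would on the real function:

def add_two(x):
    return x + 2

m = mock.MagicMock()
m.side_effect = add_two

print(m(5,6))  # Raises TypeError: add_two() accepts only one argument
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[82], line 7
      4 m = mock.MagicMock()
      5 m.side_effect = add_two
----> 7 print(m(5,6))  # Raises TypeError: add_two() accepts only one argument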

File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:1137, in CallableMixin.__call__(self, *args, **kwargs)
   1135 self._mock_check_sig(*args, **kwargs)
   1136 self._increment_mock_call(*args, **kwargs)
-> 1137 return self._mock_call(*args, **kwargs)

File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:1141, in CallableMixin._mock_call(self, *args, **kwargs)
   1140 def _mock_call(self, /, *args, **kwargs):
-> 1141     return self._execute_mock_call(*args, **kwargs)

File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:1202, in CallableMixin._execute_mock_call(self, *args, **kwargs)
   1200         raise result
   1201 else:
-> 1202     result = effect(*args, **kwargs)
   1204 if result is not DEFAULT:
   1205     return result

TypeError: add_two() takes 1 positional argument but 2 were given

Giving the function a matching signature makes the call succeed:

def add_two(x, y):
    return x + y + 2

m = mock.MagicMock()
m.side_effect = add_two

print(m(5, 6))  # Outputs: 13
13

Explanation:

  • The mock will call the function specified in side_effect when invoked.
  • This allows for more complex logic in determining the return value.
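
One common use of a side_effect function is a lookup table that maps arguments to results. In this sketch the rates dictionary and the get_rate name are made up for illustration:

```python
from unittest import mock

# Hypothetical exchange rates, used only for this illustration.
rates = {"EUR": 4.30, "USD": 4.00}

def lookup(code):
    return rates[code]

get_rate = mock.MagicMock(side_effect=lookup)

eur = get_rate("EUR")

try:
    get_rate("XXX")          # unknown key: the KeyError propagates
    unknown_failed = False
except KeyError:
    unknown_failed = True

get_rate.assert_any_call("EUR")
print(eur, unknown_failed)
```

The mock still records every call, so the usual assertions work alongside the custom logic.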

Assertions on Mocks

Mocks provide several assertion methods to check how they were called.

Example:

m = mock.MagicMock()
m(10)
m(20)

# Assert that the last call was with 20
m.assert_called_with(20)

# This will raise an AssertionError because the mock was called twice
try:
    m.assert_called_once_with(20)
except AssertionError:
    print("AssertionError as expected: The mock was called more than once.")

# Assert that the mock was called at least once with 10
m.assert_any_call(10)

# Assert that the mock was called at least once with specific arguments
m.assert_any_call(20)
AssertionError as expected: The mock was called more than once.

Common Assertion Methods:

  • assert_called_with(*args, **kwargs): Asserts the last call's arguments.
  • assert_called_once_with(*args, **kwargs): Asserts the mock was called exactly once with the specified arguments.
  • assert_any_call(*args, **kwargs): Asserts the mock was called with the specified arguments at least once.
  • assert_not_called(): Asserts the mock was never called.
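
A short sketch of assert_not_called, together with assert_has_calls, which is part of unittest.mock but not listed above and checks for a sequence of calls:

```python
from unittest import mock

m = mock.MagicMock()
m.assert_not_called()        # passes: nothing has happened yet

m(10)
m(20)
m(30)

# Passes: these two calls appear consecutively in the call history.
m.assert_has_calls([mock.call(10), mock.call(20)])

try:
    m.assert_not_called()    # fails now that the mock has been used
    raised = False
except AssertionError:
    raised = True

print(raised)
```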

Inspecting Mock Attributes

Mocks keep track of how they were called, providing useful attributes for inspection.

Example:

m = mock.MagicMock()
m(5)
m(7)

print(m.called)         # Outputs: True
print(m.call_count)     # Outputs: 2
print(m.call_args)      # Outputs: call(7)
print(m.call_args_list) # Outputs: [call(5), call(7)]

# Assert the call history
assert m.call_args_list == [mock.call(5), mock.call(7)]
True
2
call(7)
[call(5), call(7)]

Attributes Explained:

  • called: True if the mock has been called at least once.
  • call_count: Number of times the mock has been called.
  • call_args: Arguments of the most recent call.
  • call_args_list: List of all call arguments in order.

Practical Example with Assertions

Let's revisit the power_reset example and add more assertions.

Updated test_power_reset.py:

from unittest import mock
from power_reset import power_reset

@mock.patch('power_reset.requests')
def test_power_reset(mock_requests):
    # Arrange
    mock_requests.post.return_value = 42

    # Act
    power_reset()

    # Assert
    assert mock_requests.post.called
    assert mock_requests.post.call_count == 1
    mock_requests.post.assert_called_once_with(
        'https://127.0.0.1:8000/power_reset', json={}
    )
    # extra
    assert mock_requests.post.call_args == mock.call(
        'https://127.0.0.1:8000/power_reset', json={}
    )

Explanation:

  • We assert that requests.post was called.
  • We check the call count to ensure it was called exactly once.
  • We use assert_called_once_with to verify the arguments.
  • We inspect call_args to see the details of the call.

Common Pitfalls

  • Patching the Wrong Target: Always patch the object in the module where it is looked up, not where it is defined.
  • Not Restoring Original Objects: If you manually replace objects, ensure you restore them after the test.
  • Mocks Return Mocks: By default, calling a mock returns another mock. Set return_value or side_effect to control this behavior.

🏠 Exercise: Mocking the open Function

Objective

Write a test that mocks the built-in open function to verify the behavior of the print_file function. Your test should:

  1. Ensure that print_file('file.txt') calls open('file.txt').
  2. Verify that the print function is called (without checking its arguments).
  3. Confirm that the print function is called with whatever open().__enter__().read() returns.

Instructions

  • Use the unittest.mock library to mock the open and print functions.
  • Implement the test in such a way that it passes all the assertions.
  • Remember to handle the context manager methods __enter__ and __exit__ when mocking open.

Starting Code

Here is the code for the print_file function:

# mocking_open.py

def print_file(filename):
    with open(filename) as stream:
        content = stream.read()
    # Equivalent to:
    # context_manager = open(filename)
    # stream = context_manager.__enter__()
    # content = stream.read()
    # context_manager.__exit__(None, None, None)

    print(content)

Your task is to write a test for the print_file function that satisfies the objectives listed above.

Solution

%%writefile mocking_open.py
# mocking_open.py

def print_file(filename):
    with open(filename) as stream:
        content = stream.read()
    # Equivalent to:
    # context_manager = open(filename)
    # stream = context_manager.__enter__()
    # content = stream.read()
    # context_manager.__exit__(None, None, None)

    print(content)
Overwriting mocking_open.py
%%writefile test_mocking_open.py
# test_mocking_open.py

from unittest import mock
from mocking_open import print_file

@mock.patch('mocking_open.open')
@mock.patch('mocking_open.print')
def test_print_file(print_mock, open_mock):
    content = '42'

    # Arrange
    # Mock the context manager returned by open()
    context_manager = open_mock.return_value
    # Mock the stream returned by __enter__()
    stream = context_manager.__enter__.return_value
    # Mock the return value of stream.read()
    stream.read.return_value = content

    # The same arrangement can be written as a single chained line, which is harder to read:
    # open_mock.return_value.__enter__.return_value.read.return_value = content

    # Act
    print_file('file.txt')

    # Assert
    # Ensure open was called once with 'file.txt'
    open_mock.assert_called_once_with('file.txt')
    # Ensure print was called (without checking arguments)
    assert print_mock.called
    # Ensure print was called with the content read from the file
    print_mock.assert_called_once_with(content)
    # Optionally, check that __enter__ and __exit__ were called
    context_manager.__enter__.assert_called_once_with()
    context_manager.__exit__.assert_called_once_with(None, None, None)
Overwriting test_mocking_open.py
!pytest test_mocking_open.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

test_mocking_open.py::test_print_file PASSED                             [100%]

============================== 1 passed in 0.18s ===============================
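
Side note: calling the same attribute of a mock always returns the same child mock, whatever the arguments, which is why return values can be configured up front in the Arrange step:

m = mock.MagicMock()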

print(m.foo(42, 'bar'))
print(m.foo(42, 'bar'))
print(m.foo())
print(m.boo())

print(m.foo().bar())

a = m.foo()

print(a.bar())
<MagicMock name='mock.foo()' id='4506124160'>
<MagicMock name='mock.foo()' id='4506124160'>
<MagicMock name='mock.foo()' id='4506124160'>
<MagicMock name='mock.boo()' id='4398380624'>
<MagicMock name='mock.foo().bar()' id='4398737760'>
<MagicMock name='mock.foo().bar()' id='4398737760'>

Explanation

Import Statements

from unittest import mock
from mocking_open import print_file
  • unittest.mock: Provides a library for mocking objects in tests.
  • print_file: The function we're testing.

Mocking with @mock.patch

@mock.patch('mocking_open.open')
@mock.patch('mocking_open.print')
def test_print_file(print_mock, open_mock):
  • @mock.patch('mocking_open.open'): Mocks the open function in the mocking_open module.
  • @mock.patch('mocking_open.print'): Mocks the print function in the mocking_open module.
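  • The order of decorators is important: decorators are applied bottom-up, so the mock from the bottom decorator (print) is passed as the first argument and the mock from the top decorator (open) as the last.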

Arrange

content = '42'

# Mock the context manager returned by open()
context_manager = open_mock.return_value
# Mock the stream returned by __enter__()
stream = context_manager.__enter__.return_value
# Mock the return value of stream.read()
stream.read.return_value = content
  • content: The mocked content to be read from the file.
  • open_mock.return_value: Represents the context manager returned by open().
  • context_manager.__enter__.return_value: Represents the file stream returned by the context manager's __enter__ method.
  • stream.read.return_value = content: Specifies the return value when read() is called on the stream.

Act

print_file('file.txt')
  • Calls the function under test with the specified filename.

Assert

# Ensure open was called once with 'file.txt'
open_mock.assert_called_once_with('file.txt')

# Ensure print was called (without checking arguments)
assert print_mock.called

# Ensure print was called with the content read from the file
print_mock.assert_called_once_with(content)

# Optionally, check that __enter__ and __exit__ were called
context_manager.__enter__.assert_called_once_with()
context_manager.__exit__.assert_called_once_with(None, None, None)
  • open_mock.assert_called_once_with('file.txt'): Verifies that open was called exactly once with 'file.txt'.
  • assert print_mock.called: Verifies that print was called.
  • print_mock.assert_called_once_with(content): Verifies that print was called with the expected content.
  • Context Manager Assertions:
    • context_manager.__enter__.assert_called_once_with(): Ensures the __enter__ method was called.
    • context_manager.__exit__.assert_called_once_with(None, None, None): Ensures the __exit__ method was called with the correct arguments (no exceptions).

Running the Test

To run the test, use the following command:

pytest test_mocking_open.py

Make sure you have pytest installed and that both mocking_open.py and test_mocking_open.py are in the same directory.

Additional Notes

  • Handling the Context Manager:

    • When mocking open, you need to handle the fact that it returns a context manager.
    • The context manager's __enter__ method returns the file stream, which you can mock to control the return value of read().
  • Order of Mocks:

    • In the test function parameters, print_mock comes before open_mock because the mocks are applied in the reverse order of the decorators.
  • Why Mock print?:

    • Mocking print allows you to verify that it's called and to check the arguments it was called with, without actually printing to the console during the test.
  • Verifying Calls Without Arguments:

    • If you only want to verify that a function was called, you can use assert mock.called or mock.assert_called().
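
For file mocking specifically, unittest.mock also ships the mock_open helper, which prepares the context-manager plumbing for you. The sketch below is an alternative to the solution above, not a replacement for it; it redefines print_file inline so that it runs on its own, and patches builtins directly:

```python
from unittest import mock

def print_file(filename):
    # Same function as in mocking_open.py, repeated here for self-containment.
    with open(filename) as stream:
        content = stream.read()
    print(content)

# mock_open builds a mock whose read() returns read_data.
with mock.patch("builtins.open", mock.mock_open(read_data="42")) as open_mock:
    with mock.patch("builtins.print") as print_mock:
        print_file("file.txt")

open_mock.assert_called_once_with("file.txt")
print_mock.assert_called_once_with("42")
```

Writing the chain by hand, as in the exercise, is still worth doing once, since it shows what mock_open sets up behind the scenes.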

Advanced Mocking Techniques

In this section, we'll delve deeper into Python's unittest.mock library, exploring more advanced features and nuances of mocks. Understanding these concepts will help you write more precise and effective tests.

Returning a Mock from a Mock

When you call a MagicMock instance, it returns another MagicMock instance by default. This allows you to chain calls or access attributes on the returned mock.

Example:

from unittest import mock

m = mock.MagicMock()
print(m)      # Outputs: <MagicMock id='...'>
print(m())    # Outputs: <MagicMock name='mock()' id='...'>
print(m())    # Outputs the same as above
print(m().m())  # Outputs: <MagicMock name='mock().m()' id='...'>
print(m().m())  # Outputs the same as above
<MagicMock id='4506071936'>
<MagicMock name='mock()' id='4506068384'>
<MagicMock name='mock()' id='4506068384'>
<MagicMock name='mock().m()' id='4512817200'>
<MagicMock name='mock().m()' id='4512817200'>
  • Explanation:
    • m: A MagicMock instance.
    • m(): Calling m returns another MagicMock instance named 'mock()'.
    • Each call to m() returns the same mock object unless configured otherwise.
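
The object returned by every call is the mock's return_value attribute, which a quick identity check confirms. The status attribute here is an arbitrary name chosen for the illustration:

```python
from unittest import mock

m = mock.MagicMock()
first = m()
second = m("other", key="value")

# Every call returns the very same child mock, regardless of arguments.
same_object = first is second and first is m.return_value

# Configuration made on one result is therefore visible on later calls.
first.status = "ready"
seen_later = m().status

print(same_object, seen_later)
```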

Setting Return Values:

You can specify what a mock should return when called.

m = mock.MagicMock()
m.return_value = 42

print(m)      # Outputs: <MagicMock id='...'>
print(m())    # Outputs: 42
print(m())    # Outputs: 42
<MagicMock id='4506117296'>
42
42
  • Explanation:
    • m.return_value = 42: Sets the return value when m is called.
    • Subsequent calls to m() return 42.

Mock Attributes and Methods

Mocks can have attributes and methods that you can configure and inspect.

Example with Attributes:

m = mock.MagicMock()
print(m)
print(m.field)
<MagicMock id='4512848768'>
<MagicMock name='mock.field' id='4512919536'>
m = mock.MagicMock()
m.field = 42



print(m)        # Outputs: <MagicMock id='...'>
print(m.field)  # Outputs: 42
print(m.field)  # Outputs: 42
<MagicMock id='4513093712'>
42
42
m = mock.MagicMock()

# Configure the value returned by calling m.field(), not m.field itself
m.field.return_value = 42

print(m)          # Outputs: <MagicMock id='...'>
print(m.field())  # Outputs: 42
print(m.field)    # Outputs: <MagicMock name='mock.field' id='...'>
<MagicMock id='4512935872'>
42
<MagicMock name='mock.field' id='4506073520'>
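  • Explanation:
    • You can set attributes on a mock as you would on a regular object.
    • Setting m.field.return_value configures what calling m.field() returns; m.field itself is still a mock.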

Mocking Methods:

If you access an attribute or method that hasn't been set, the mock creates it on the fly as a child mock of the same type (a Mock in the example below, a MagicMock when the parent is a MagicMock).

m = mock.Mock()
# No arrangement for m.field or m.method

print(m)             # Outputs: <Mock id='...'>
print(m.field)       # Outputs: <Mock name='mock.field' id='...'>
print(m.method)      # Outputs: <Mock name='mock.method' id='...'>
print(m.method(3))   # Outputs: <Mock name='mock.method()' id='...'>
print(m())           # Outputs: <Mock name='mock()' id='...'>
print(m.mock_calls)  # Outputs the list of calls made on the mock
<Mock id='4399320048'>
<Mock name='mock.field' id='4513092560'>
<Mock name='mock.method' id='4512935680'>
<Mock name='mock.method()' id='4397134528'>
<Mock name='mock()' id='4398737856'>
[call.method(3), call()]

MagicMock vs. Mock

The MagicMock class is a subclass of Mock that includes default implementations of the magic methods (dunder methods) in Python, such as __len__, __getitem__, etc.

Differences in Behavior:

mm = mock.MagicMock()
m = mock.Mock()

# Using magic methods on MagicMock
print(mm['a'])

print(m['a'])
<MagicMock name='mock.__getitem__()' id='4513010688'>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[107], line 7
      4 # Using magic methods on MagicMock
      5 print(mm['a'])
----> 7 print(m['a'])

TypeError: 'Mock' object is not subscriptable
mm = mock.MagicMock()
m = mock.Mock()

# Using magic methods on MagicMock
print(len(mm))

print(len(m))
0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[109], line 7
      4 # Using magic methods on MagicMock
      5 print(len(mm))
----> 7 print(len(m))

TypeError: object of type 'Mock' has no len()
mm = mock.MagicMock()
m = mock.Mock()

# Using magic methods on MagicMock
print( 42 in mm)

print(42 in m)
False
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[111], line 7
      4 # Using magic methods on MagicMock
      5 print( 42 in mm)
----> 7 print(42 in m)

TypeError: argument of type 'Mock' is not iterable
  • Explanation:
    • MagicMock supports magic methods by default, making it suitable for mocking objects that use them.
    • Mock does not support magic methods unless you configure them explicitly.
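
Magic methods can be configured like any other mocked method. This sketch shows both cases, with arbitrary values: adjusting the defaults on a MagicMock and adding a single magic method to a plain Mock:

```python
from unittest import mock

# MagicMock: the magic methods already exist and can be configured.
mm = mock.MagicMock()
mm.__len__.return_value = 3
mm.__getitem__.side_effect = lambda key: key.upper()

length = len(mm)
item = mm["a"]

# Mock: a magic method works once it is assigned explicitly.
m = mock.Mock()
m.__len__ = mock.Mock(return_value=5)
plain_length = len(m)

print(length, item, plain_length)
```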

Using spec and autospec

The spec parameter in mocks restricts the mock to only have attributes and methods that exist on the specified object.

Example without Specifying spec:

dumb_os_mock = mock.MagicMock()
dumb_os_mock.remoev()  # Typo in method name, but no error
dumb_os_mock.remove.asser_called_once_with('another file')  # Typo in method name, but no error
<MagicMock name='mock.remove.asser_called_once_with()' id='4506093424'>
  • Issue:
    • Since we didn't specify a spec, the mock allows any attribute or method, leading to potential silent failures.

Using spec to Enforce the Interface:

import os

os_mock = mock.MagicMock(spec=os)
os_mock.remove('file')  # Correct usage

# Accessing a non-existent attribute raises AttributeError
# os_mock.remoev  # Raises AttributeError

# Typo in method name during assertion is still not caught
os_mock.remove.asser_called_once_with('another_file')  # Typo in method name
<MagicMock name='mock.remove.asser_called_once_with()' id='4505391360'>
  • Explanation:
    • spec=os restricts the mock to have only attributes and methods that exist in the os module.
    • Accessing os_mock.remoev (with a typo) raises an AttributeError.

Limitation:

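  • spec applies only to the top-level mock. os_mock.remove is an ordinary, unrestricted MagicMock, so a misspelled assertion such as asser_called_once_with (instead of assert_called_once_with) silently creates a new attribute and checks nothing.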
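
Autospeccing goes further than spec because it also copies call signatures. The sketch below uses mock.create_autospec on a hypothetical stand-in function named remove, so it does not depend on the real os module:

```python
from unittest import mock

def remove(path):
    # Hypothetical stand-in for a real function; it must never run in a test.
    raise RuntimeError("the real function should not be called")

remove_mock = mock.create_autospec(remove)

remove_mock("file.txt")                       # matches the signature
remove_mock.assert_called_once_with("file.txt")

try:
    remove_mock()                             # missing the required argument
    signature_checked = False
except TypeError:
    signature_checked = True

print(signature_checked)
```

A call that the real function would reject is rejected by the mock too, which catches a class of mistakes that a plain MagicMock would accept silently.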

spec and Instance Attributes

Using spec or autospec with classes can sometimes lead to unexpected behavior, especially with instance attributes.

Example:

class Something:
    foo: int
    def __init__(self):
        self.foo = 42

something_mock = mock.Mock(spec=Something)
something_mock.foo  # Raises AttributeError
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[122], line 7
      4         self.foo = 42
      6 something_mock = mock.Mock(spec=Something)
----> 7 something_mock.foo  # Raises AttributeError

File ~/.pyenv/versions/3.12.6/lib/python3.12/unittest/mock.py:658, in NonCallableMock.__getattr__(self, name)
    656 elif self._mock_methods is not None:
    657     if name not in self._mock_methods or name in _all_magics:
--> 658         raise AttributeError("Mock object has no attribute %r" % name)
    659 elif _is_magic(name):
    660     raise AttributeError(name)

AttributeError: Mock object has no attribute 'foo'
  • Issue:
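    • The mock has no foo attribute because spec is built from the attributes the class itself exposes (dir(Something)). Attributes assigned in __init__ exist only on instances, and a bare annotation such as foo: int does not create a class attribute either.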

Solution:

  • Manually set the attribute on the mock:
something_mock.foo = 42
print(something_mock.foo)
42

The spec and autospec arguments can also be passed to mock.patch. Here are the earlier my_remove tests again, once with spec and once with autospec:

# test_my_remove_module_mock.py
from unittest import mock
from my_remove_module import my_remove
import my_remove_module


@mock.patch('my_remove_module.os', spec=my_remove_module.os)
def test_provided_extension_should_be_used(mock_os):
    filename = 'file.md'
    my_remove(filename)
    mock_os.remove.assert_called_once_with(filename)

@mock.patch('my_remove_module.os', autospec=True)
def test_when_extension_is_missing_then_use_default_one(mock_os):
    my_remove('file')
    mock_os.remove.assert_called_once_with('file.txt')
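
For the instance-attribute problem, another option is to pass an instance rather than the class as spec, so that attributes created in __init__ are part of the specification. A minimal sketch, assuming it is cheap and safe to construct the real object in a test:

```python
from unittest import mock

class Something:
    def __init__(self):
        self.foo = 42

# Spec taken from the class: instance attributes are unknown.
class_mock = mock.Mock(spec=Something)
class_has_foo = hasattr(class_mock, "foo")

# Spec taken from an instance: foo is part of the interface.
instance_mock = mock.Mock(spec=Something())
instance_has_foo = hasattr(instance_mock, "foo")

print(class_has_foo, instance_has_foo)
```

Note that instance_mock.foo is a child mock, not 42; the spec controls which names exist, not their values.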

Using mock.patch as a Context Manager

The mock.patch function can be used as a context manager to temporarily replace an object during a specific block of code.

Example:

# import builtins

with mock.patch('builtins.sum') as sum_mock:
    print(sum)       # Outputs: <MagicMock name='sum' id='...'>
    print(sum_mock)  # Outputs: <MagicMock name='sum' id='...'>

print(sum)  # Outputs: <built-in function sum>
<MagicMock name='sum' id='4399299376'>
<MagicMock name='sum' id='4399299376'>
<built-in function sum>

Using mock.ANY

When you need to assert a call was made with certain arguments but want to ignore some arguments, you can use mock.ANY.

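Conceptually, mock.ANY behaves like an instance of the following class, which compares equal to everything:

class ANYClass:
    def __eq__(self, other):
        return True

ANY = ANYClass()

Example:

# '__main__' is the notebook's own namespace
with mock.patch('__main__.open') as open_mock:
    with open('qwer', 'r') as s:
        s.read()
open_mock.assert_called_once_with(mock.ANY, 'r')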
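
mock.ANY is most useful for arguments that cannot be predicted, such as timestamps. In this sketch, log_event is a hypothetical callable standing in for some logging dependency:

```python
import time
from unittest import mock

log_event = mock.MagicMock()

# The second argument changes on every run.
log_event("request finished", time.time())

# Check the message and accept any timestamp.
log_event.assert_called_once_with("request finished", mock.ANY)

# ANY also works inside call objects when comparing call_args.
matches = log_event.call_args == mock.call(mock.ANY, mock.ANY)
print(matches)
```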

Using wraps in Mocks

The wraps parameter allows a mock to pass calls through to the original object while still recording how it was used.

Example:

import requests
from unittest import mock

requests_mock = mock.MagicMock(wraps=requests)
response = requests_mock.get('https://api.github.com')

print(response.status_code)       # Outputs: 200 (assuming the request succeeds)
print(requests_mock.mock_calls)   # Records the calls made on the mock
print(requests_mock.get.call_args)
200
[call.get('https://api.github.com')]
call('https://api.github.com')
  • Explanation:
    • wraps=requests means the mock will delegate method calls to the real requests module.
    • The mock still records the calls, allowing you to assert how it was used.

Use Cases for wraps:

  • When you want to monitor interactions with an object without completely mocking out its behavior.
  • Useful for spying on real objects.
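
The requests example needs network access. As an offline sketch of the same technique, the standard-library math module can be wrapped instead (the choice of math is arbitrary and made only for this illustration). It also shows that an explicitly set return_value takes precedence over the wrapped object.

```python
import math
from unittest import mock

math_spy = mock.MagicMock(wraps=math)

result = math_spy.sqrt(16)                 # delegated to the real math.sqrt
math_spy.sqrt.assert_called_once_with(16)  # ...and still recorded

# An explicit return_value takes precedence over the wrapped object
math_spy.sqrt.return_value = -1
overridden = math_spy.sqrt(16)
```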

Exercise: Implementing and Testing an NBP API Wrapper with Mocking

Objective

In this exercise, you will implement the get_exchange_rate function, which retrieves currency exchange rates from the National Bank of Poland (NBP) API. You will also write unit tests for this function using pytest and unittest.mock to ensure it behaves correctly under various scenarios without making actual HTTP requests.

Description

You are provided with a partial implementation of the get_exchange_rate function in the nbp_api.py module. This function is designed to fetch the exchange rate for a specified currency and date from the NBP API. Your tasks are:

  1. Implement the get_exchange_rate Function:

    • Complete the function to make an HTTP GET request to the NBP API.
    • Parse the JSON response to extract the exchange rate for the specified currency.
    • Handle cases where no data is available for the given date by raising a NoData exception.
    • Handle invalid currency codes by raising a ValueError.
  2. Write Unit Tests for get_exchange_rate:

    • Use pytest and unittest.mock to mock external dependencies (requests.get).
    • Test the function's behavior under different scenarios:
      • Successful retrieval of exchange rates.
      • No data available for the given date.
      • Invalid currency codes.
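
To get a feel for the parsing step before touching HTTP, here is a minimal sketch that works on an in-memory payload. The payload is a trimmed-down copy of the sample response used later in this section; find_rate is a hypothetical helper name, not part of the exercise skeleton.

```python
# Trimmed-down payload in the shape returned by the NBP "tables/a" endpoint
payload = [{
    'table': 'A',
    'effectiveDate': '2019-11-08',
    'rates': [
        {'currency': 'dolar amerykański', 'code': 'USD', 'mid': 3.8625},
        {'currency': 'euro', 'code': 'EUR', 'mid': 4.2638},
    ],
}]


def find_rate(payload, currency):
    """Hypothetical helper: return the 'mid' rate for a currency code."""
    for rate in payload[0]['rates']:
        if rate['code'] == currency.upper():
            return rate['mid']
    raise ValueError('Invalid currency')


print(find_rate(payload, 'usd'))
```

The response is a list with a single table, so the rates live under payload[0]['rates'].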

Initial Code

nbp_api.py

# nbp_api.py

from datetime import date
import requests

class NoData(Exception):
    """Exception raised when no data is available for the given date."""
    pass

API_ENDPOINT = 'http://api.nbp.pl/api/exchangerates/tables/a/{date}/?format=json'
DATE_FORMAT = '%Y-%m-%d'

def get_exchange_rate(currency: str, date: date) -> float:
    """Fetches the exchange rate for a given currency and date from the NBP API.

    Args:
        currency (str): The currency code (e.g., 'USD', 'EUR').
        date (date): The date for which to retrieve the exchange rate.

    Returns:
        float: The exchange rate.

    Raises:
        NoData: If no data is available for the given date.
        ValueError: If the provided currency code is invalid.
    """
    currency = currency.upper()
    date_as_str = date.strftime(DATE_FORMAT)
    url = API_ENDPOINT.format(date=date_as_str)

    # TODO: Make an HTTP GET request to the API_ENDPOINT.

    # TODO: Check if the response status code indicates success.

    # TODO: Parse the JSON response.


    # TODO: Extract the exchange rate for the specified currency.

    # TODO: If the currency code is not found, raise ValueError.

test_nbp_api.py

# test_nbp_api.py

import pytest
from unittest import mock
from datetime import date
from nbp_api import get_exchange_rate, NoData

# You will write your tests here.

Advice for Participants

  • Understand the Functionality: Before writing tests, ensure you understand what the get_exchange_rate function is supposed to do, including its inputs, outputs, and how it handles errors.

  • Mock External Calls: Use unittest.mock to mock the requests.get method. This prevents actual HTTP requests during testing and allows you to simulate different responses.

  • Use autospec: When mocking requests.get, use the autospec=True parameter. This ensures that the mock object matches the signature of the real requests.get method, helping catch errors like incorrect arguments.

  • Test Different Scenarios:

    • Successful Response: Simulate a successful API response with valid exchange rate data.
    • No Data Available: Simulate a response with a non-200 status code to trigger the NoData exception.
    • Invalid Currency Code: Simulate a successful response that does not include the requested currency, triggering a ValueError.
    • Network Exceptions: Simulate network-related exceptions (e.g., connection errors) to ensure your function handles them gracefully.
  • Use mock.ANY: When asserting calls where some arguments are irrelevant or variable, use mock.ANY to ignore those specific arguments.

  • Handle Exceptions in Tests: Use pytest.raises to verify that your function raises the appropriate exceptions under error conditions.

  • Keep Tests Isolated: Ensure that each test is independent and does not rely on the state or side effects from other tests.

  • Run Tests Frequently: As you implement the function and write tests, run them frequently to catch and fix issues early.
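
Two of the points above, autospec and simulated network failures, can be tried without requests installed. The sketch below uses a hypothetical fetch function as a stand-in for requests.get; the function, its signature and the URL are assumptions made for illustration.

```python
from unittest import mock


def fetch(url, timeout=10):
    """Hypothetical stand-in for requests.get; never meant to run in tests."""
    raise RuntimeError('real network call')


fetch_mock = mock.create_autospec(fetch)

# A call that matches the real signature is accepted and recorded
fetch_mock('http://example.invalid/rates', timeout=5)
fetch_mock.assert_called_once_with('http://example.invalid/rates', timeout=5)

# A call that the real function would reject is rejected by the mock too
try:
    fetch_mock()                      # missing the required 'url' argument
    signature_checked = False
except TypeError:
    signature_checked = True

# side_effect simulates a network failure without touching the network
fetch_mock.side_effect = ConnectionError('network down')
try:
    fetch_mock('http://example.invalid/rates')
    failure_simulated = False
except ConnectionError:
    failure_simulated = True
```

In a real test module the try/except blocks would normally be written with pytest.raises; they are spelled out here only to keep the sketch free of third-party imports.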

Expected Behaviour

Successful Retrieval

get_exchange_rate('USD', date(2021, 1, 13))
3.7142

No Data Available

get_exchange_rate('USD', date(2021, 1, 10))
Traceback (most recent call last):

  File "<ipython-input-164-c1f5ecb2786e>", line 1, in <module>
    get_exchange_rate('USD', date(2021, 1, 10))

  File "<ipython-input-162-71c2eb3d0b2f>", line 24, in get_exchange_rate
    raise NoData

NoData

Invalid Currency Code

get_exchange_rate('XXX', date(2021, 1, 13))
Traceback (most recent call last):

  File "<ipython-input-165-0d9eb2e858cd>", line 1, in <module>
    get_exchange_rate('XXX', date(2021, 1, 13))

  File "<ipython-input-162-71c2eb3d0b2f>", line 37, in get_exchange_rate
    raise ValueError("Invalid currency")

ValueError: Invalid currency

Solution

# %%writefile nbp_api.py

from datetime import date
import os
import sys

import requests


class NoData(Exception):
    pass


API_ENDPOINT = 'http://api.nbp.pl/api/exchangerates/tables/a/{date}/?format=json'
DATE_FORMAT = '%Y-%m-%d'


def get_exchange_rate(currency: str, date: date) -> float:
    """ May raise NoData. """

    currency = currency.upper()
    date_as_str = date.strftime(DATE_FORMAT)
    url = API_ENDPOINT.format(date=date_as_str)
    response = requests.get(url)
    if response.status_code != 200:
        raise NoData
    json = response.json()

#     rates = (r['mid'] for r in json[0]['rates'] if r['code'] == currency)
#     try:
#         rate = next(rates)
#     except StopIteration:
#         raise ValueError("Invalid currency")
#     else:
#         return rate

    for rate in json[0]['rates']:
        if rate['code'] == currency:
            return rate['mid']
    raise ValueError("Invalid currency")
get_exchange_rate('USD', date(2021, 1, 13))
3.7142
%%writefile nbp_api_tests.py
import datetime
import json
from unittest import mock

import pytest

from nbp_api import get_exchange_rate, NoData

RAW_JSON: str = """[{"table":"A","no":"217/A/NBP/2019","effectiveDate":"2019-11-08","rates":[{"currency":"bat (Tajlandia)","code":"THB","mid":0.1272},{"currency":"dolar amerykański","code":"USD","mid":3.8625},{"currency":"dolar australijski","code":"AUD","mid":2.6533},{"currency":"dolar Hongkongu","code":"HKD","mid":0.4935},{"currency":"dolar kanadyjski","code":"CAD","mid":2.9263},{"currency":"dolar nowozelandzki","code":"NZD","mid":2.4520},{"currency":"dolar singapurski","code":"SGD","mid":2.8406},{"currency":"euro","code":"EUR","mid":4.2638},{"currency":"forint (Węgry)","code":"HUF","mid":0.01278},{"currency":"frank szwajcarski","code":"CHF","mid":3.8797},{"currency":"funt szterling","code":"GBP","mid":4.9476},{"currency":"hrywna (Ukraina)","code":"UAH","mid":0.1577},{"currency":"jen (Japonia)","code":"JPY","mid":0.035324},{"currency":"korona czeska","code":"CZK","mid":0.1669},{"currency":"korona duńska","code":"DKK","mid":0.5706},{"currency":"korona islandzka","code":"ISK","mid":0.030964},{"currency":"korona norweska","code":"NOK","mid":0.4221},{"currency":"korona szwedzka","code":"SEK","mid":0.3991},{"currency":"kuna (Chorwacja)","code":"HRK","mid":0.5739},{"currency":"lej rumuński","code":"RON","mid":0.8956},{"currency":"lew (Bułgaria)","code":"BGN","mid":2.1800},{"currency":"lira turecka","code":"TRY","mid":0.6713},{"currency":"nowy izraelski szekel","code":"ILS","mid":1.1053},{"currency":"peso chilijskie","code":"CLP","mid":0.005206},{"currency":"peso filipińskie","code":"PHP","mid":0.0764},{"currency":"peso meksykańskie","code":"MXN","mid":0.2012},{"currency":"rand (Republika Południowej Afryki)","code":"ZAR","mid":0.2603},{"currency":"real (Brazylia)","code":"BRL","mid":0.9419},{"currency":"ringgit (Malezja)","code":"MYR","mid":0.9344},{"currency":"rubel rosyjski","code":"RUB","mid":0.0605},{"currency":"rupia indonezyjska","code":"IDR","mid":0.00027562},{"currency":"rupia indyjska","code":"INR","mid":0.054185},{"currency":"won południowokoreański","code":"KRW","mid":0.003337},{"currency":"yuan renminbi (Chiny)","code":"CNY","mid":0.5523},{"currency":"SDR (MFW)","code":"XDR","mid":5.2969}]}]"""
JSON: list = json.loads(RAW_JSON)


@mock.patch('nbp_api.requests.get')
def test_should_return_exchange_rate_for_a_working_date(requests_get_mock):
    response_mock = requests_get_mock.return_value
    response_mock.status_code = 200
    response_mock.json.return_value = JSON

    date = datetime.date(2019, 11, 8)
    rate = get_exchange_rate('USD', date)

    assert rate == 3.8625
    requests_get_mock.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')

@mock.patch('nbp_api.requests')
def test_should_raise_NoData_for_a_nonworking_date(requests_mock):
    response_mock = requests_mock.get.return_value
    response_mock.status_code = 404

    date = datetime.date(2019, 11, 9)
    with pytest.raises(NoData):
        get_exchange_rate('USD', date)

    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-09/?format=json')
    response_mock.json.assert_not_called()

@mock.patch('nbp_api.requests')
def test_should_raise_ValueError_when_currency_not_found(requests_mock):
    response_mock = requests_mock.get.return_value
    response_mock.status_code = 200
    response_mock.json.return_value = JSON

    date = datetime.date(2019, 11, 8)
    with pytest.raises(ValueError):
        get_exchange_rate('XXX', date)
Overwriting nbp_api_tests.py
!pytest nbp_api_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

nbp_api_tests.py::test_should_return_exchange_rate_for_a_working_date PASSED [ 33%]
nbp_api_tests.py::test_should_raise_NoData_for_a_nonworking_date PASSED  [ 66%]
nbp_api_tests.py::test_should_raise_ValueError_when_currency_not_found PASSED [100%]

============================== 3 passed in 0.40s ===============================
%%writefile nbp_api_tests.py
import datetime
import json
from unittest import mock

import pytest

from nbp_api import get_exchange_rate, NoData

RAW_JSON: str = """[{"table":"A","no":"217/A/NBP/2019","effectiveDate":"2019-11-08","rates":[{"currency":"bat (Tajlandia)","code":"THB","mid":0.1272},{"currency":"dolar amerykański","code":"USD","mid":3.8625},{"currency":"dolar australijski","code":"AUD","mid":2.6533},{"currency":"dolar Hongkongu","code":"HKD","mid":0.4935},{"currency":"dolar kanadyjski","code":"CAD","mid":2.9263},{"currency":"dolar nowozelandzki","code":"NZD","mid":2.4520},{"currency":"dolar singapurski","code":"SGD","mid":2.8406},{"currency":"euro","code":"EUR","mid":4.2638},{"currency":"forint (Węgry)","code":"HUF","mid":0.01278},{"currency":"frank szwajcarski","code":"CHF","mid":3.8797},{"currency":"funt szterling","code":"GBP","mid":4.9476},{"currency":"hrywna (Ukraina)","code":"UAH","mid":0.1577},{"currency":"jen (Japonia)","code":"JPY","mid":0.035324},{"currency":"korona czeska","code":"CZK","mid":0.1669},{"currency":"korona duńska","code":"DKK","mid":0.5706},{"currency":"korona islandzka","code":"ISK","mid":0.030964},{"currency":"korona norweska","code":"NOK","mid":0.4221},{"currency":"korona szwedzka","code":"SEK","mid":0.3991},{"currency":"kuna (Chorwacja)","code":"HRK","mid":0.5739},{"currency":"lej rumuński","code":"RON","mid":0.8956},{"currency":"lew (Bułgaria)","code":"BGN","mid":2.1800},{"currency":"lira turecka","code":"TRY","mid":0.6713},{"currency":"nowy izraelski szekel","code":"ILS","mid":1.1053},{"currency":"peso chilijskie","code":"CLP","mid":0.005206},{"currency":"peso filipińskie","code":"PHP","mid":0.0764},{"currency":"peso meksykańskie","code":"MXN","mid":0.2012},{"currency":"rand (Republika Południowej Afryki)","code":"ZAR","mid":0.2603},{"currency":"real (Brazylia)","code":"BRL","mid":0.9419},{"currency":"ringgit (Malezja)","code":"MYR","mid":0.9344},{"currency":"rubel rosyjski","code":"RUB","mid":0.0605},{"currency":"rupia indonezyjska","code":"IDR","mid":0.00027562},{"currency":"rupia indyjska","code":"INR","mid":0.054185},{"currency":"won południowokoreański","code":"KRW","mid":0.003337},{"currency":"yuan renminbi (Chiny)","code":"CNY","mid":0.5523},{"currency":"SDR (MFW)","code":"XDR","mid":5.2969}]}]"""
JSON: list = json.loads(RAW_JSON)

@pytest.fixture
def preconfigured_requests_mock():
    with mock.patch('nbp_api.requests') as requests_mock:
        response_mock = requests_mock.get.return_value
        response_mock.status_code = 200
        response_mock.json.return_value = JSON
        yield requests_mock

def test_should_return_exchange_rate_for_a_working_date(preconfigured_requests_mock):
    date = datetime.date(2019, 11, 8)
    rate = get_exchange_rate('USD', date)

    assert rate == 3.8625
    preconfigured_requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')

def test_should_raise_NoData_for_a_nonworking_date(preconfigured_requests_mock):
    response_mock = preconfigured_requests_mock.get.return_value
    response_mock.status_code = 404

    date = datetime.date(2019, 11, 9)
    with pytest.raises(NoData):
        get_exchange_rate('USD', date)

    preconfigured_requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-09/?format=json')
    response_mock.json.assert_not_called()

@mock.patch('nbp_api.os')
@mock.patch('nbp_api.sys')
def test_should_raise_ValueError_when_currency_not_found(sys_mock, os_mock, preconfigured_requests_mock):
    date = datetime.date(2019, 11, 8)
    with pytest.raises(ValueError):
        get_exchange_rate('XXX', date)
Overwriting nbp_api_tests.py
!pytest nbp_api_tests.py -v
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0 -- /Users/a563420/python_training/testing/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

nbp_api_tests.py::test_should_return_exchange_rate_for_a_working_date PASSED [ 33%]
nbp_api_tests.py::test_should_raise_NoData_for_a_nonworking_date PASSED  [ 66%]
nbp_api_tests.py::test_should_raise_ValueError_when_currency_not_found PASSED [100%]

============================== 3 passed in 0.18s ===============================
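
The last test in the file above stacks two mock.patch decorators on top of a fixture. The ordering rule can be sketched in isolation: decorators are applied bottom-up, so the mock created by the decorator closest to the function is passed first, and in a pytest test any fixture arguments come after the injected mocks. The patched targets below (os.getcwd and os.getpid) are arbitrary choices for illustration.

```python
import os
from unittest import mock


@mock.patch('os.getpid')   # outermost decorator -> last mock argument
@mock.patch('os.getcwd')   # innermost decorator -> first mock argument
def call_patched(getcwd_mock, getpid_mock):
    getcwd_mock.return_value = '/fake/dir'
    getpid_mock.return_value = 1234
    return os.getcwd(), os.getpid()


result = call_patched()
print(result)
```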

Behaviour-Driven Development (BDD)

Behaviour-Driven Development (BDD) is an Agile software development process that encourages collaboration among developers, quality assurance, and non-technical or business participants in a software project. It extends Test-Driven Development (TDD) by writing test cases in natural language that non-programmers can read.

Key Concepts

  • User Stories: Descriptions of features from the perspective of the end-user.
  • Scenarios: Specific situations that demonstrate how a feature should behave.
  • Acceptance Tests: Tests written in a way that describes the desired behavior of the system from the user's perspective.

BDD Workflow

The BDD process involves several iterative steps:

[Figure: BDD Workflow]

Planning for BDD

Before diving into coding, it's essential to plan out the features and how they will be tested. This includes:

  • Identifying user stories and scenarios.
  • Planning the structure of your application.
  • Deciding on the tools and frameworks you'll use.

BDD Demo: Calculator Application

requirements.txt

%%writefile calculator/requirements.txt
Flask
WebTest
behave
pytest
Overwriting calculator/requirements.txt

backend.py

%%writefile calculator/backend.py
def add(a, b):
    return a + b
Overwriting calculator/backend.py

frontend.py

%%writefile calculator/frontend.py
from flask import request, Flask

import backend


app = Flask(__name__)

HOME_PAGE = """
<h1>Home Page</h1>
<form method="POST">
<input name="first" /> +
<input name="second" /> =
<input type="submit" value="?" />
</form>
"""
FORM_SENT_TEMPLATE = "{first} + {second} = {sum}"

@app.route('/', methods=['GET', 'POST'])
def home():
    if request.method == 'GET':
        return HOME_PAGE
    else:
        first = float(request.form['first'])
        second = float(request.form['second'])
        sum_ = backend.add(first, second)
        return FORM_SENT_TEMPLATE.format(
            first=first, second=second, sum=sum_)
Overwriting calculator/frontend.py

test_backend.py

%%writefile calculator/test_backend.py
from backend import add

def test_addition():
    assert add(50.0, 70.0) == 120.0
Overwriting calculator/test_backend.py
!pytest calculator/test_backend.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 1 item                                                               

calculator/test_backend.py .                                             [100%]

============================== 1 passed in 0.16s ===============================

test_frontend.py

%%writefile calculator/test_frontend.py
from unittest import mock

import frontend

@mock.patch('frontend.request', spec=object())
def test_should_display_home_page(request_mock):
    request_mock.method = 'GET'
    html = frontend.home()
    assert 'Home Page' in html

@mock.patch('frontend.request', spec=object())
def test_should_display_addition_form(request_mock):
    request_mock.method = 'GET'
    html = frontend.home()
    assert '<form' in html
    assert 'name="first"' in html
    assert 'name="second"' in html
    assert 'type="submit"' in html

@mock.patch('frontend.request', spec=object())
@mock.patch('frontend.backend')
def test_should_delegate_request_parameters_to_backend(
        backend_mock, request_mock):
    # arrange (given)
    request_mock.method = 'POST'
    request_mock.form = {
        'first': '50',
        'second': '70',
    }
    backend_mock.add.return_value = 120

    # action (when)
    html = frontend.home()

    # assert (then)
    backend_mock.add.assert_called_once_with(50.0, 70.0)
    assert '120' in html
Overwriting calculator/test_frontend.py
!pytest calculator/test_frontend.py
============================= test session starts ==============================
platform darwin -- Python 3.12.6, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/a563420/python_training/testing
plugins: cov-6.0.0, anyio-4.4.0
collected 3 items                                                              

calculator/test_frontend.py ...                                          [100%]

============================== 3 passed in 0.64s ===============================

features/addition.feature

%%writefile calculator/features/addition.feature
Feature: add numbers
    In order to avoid silly mistakes
    I want to be told the sum of two numbers

    Scenario: Navigation to Home Page
        When a user navigates to home page
        Then Home Page should be displayed

    Scenario: Add two numbers
        Given a user navigates to home page
        And I have entered 50 as first number
        And I have entered 70 as second number
        When I press add
        Then 120 should be displayed

features/steps/addition_steps.py

%%writefile calculator/features/steps/addition_steps.py
from behave import *

@step('a user navigates to home page')
def step_impl(context):
    context.resp = context.client.get('/')

@then('{text} should be displayed')
def step_impl(context, text):
    assert text in context.resp

@given('I have entered {value} as {field} number')
def step_impl(context, value, field):
    context.resp.form[field] = value

@when('I press add')
def step_impl(context):
    context.resp = context.resp.form.submit()

features/environment.py

%%writefile calculator/features/environment.py
from webtest import TestApp

from frontend import app


def before_scenario(context, scenario):
    context.client = TestApp(app)


def after_scenario(context, scenario):
    del context.client
Overwriting calculator/features/environment.py
! cd calculator && behave --no-source
Feature: add numbers
  In order to avoid silly mistakes
  I want to be told the sum of two numbers
  Scenario: Navigation to Home Page
    When a user navigates to home page # 0.005s
    Then Home Page should be displayed # 0.000s

  Scenario: Add two numbers
    Given a user navigates to home page    # 0.001s
    And I have entered 50 as first number  # 0.002s
    And I have entered 70 as second number # 0.000s
    When I press add                       # 0.003s
    Then 120 should be displayed           # 0.000s

1 feature passed, 0 failed, 0 skipped
2 scenarios passed, 0 failed, 0 skipped
7 steps passed, 0 failed, 0 skipped, 0 undefined
Took 0m0.012s

Exercise: 🏠 NBP

If you're using Python 3.8+, specify the spec parameter when mocking flask.request as a workaround for a bug:

@mock.patch(..., spec=object())

NBP API Description

Everyone! Experiment with the calculator. Run unit and acceptance tests.

Everyone! Create the directory and file structure for NBP from scratch, using the same structure as the calculator.

BDD for NBP: work in pairs and swap roles after the first acceptance test.


Explanation

  • Mocking flask.request in Python 3.8+:

    • Due to a known bug in Python 3.8 and above, when mocking flask.request, it's necessary to specify the spec parameter to ensure the mock object behaves correctly. This can be done using the @mock.patch decorator with spec=object().
  • Exercise steps:

    • Hands-On Practice: Participants are encouraged to engage with the calculator application by running both unit tests and acceptance tests. This hands-on approach reinforces understanding of testing methodologies.

    • Project Setup: Participants should create the necessary directory and file structure for the NBP (National Bank of Poland) API wrapper from scratch. The structure should mirror that of the calculator application to maintain consistency and leverage established best practices.

    • Behaviour-Driven Development (BDD) for NBP:

      • Team Rotation: After completing the first acceptance test, participants should rotate roles. This ensures that different team members gain experience in various aspects of BDD, fostering a more collaborative and well-rounded development process.
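
To see what the spec=object() workaround gives a test to work with, the sketch below builds roughly the same kind of mock by hand, without Flask. That this matches what mock.patch injects is an assumption based on how the calculator tests use it. Attributes set by the test behave like plain values, while anything else raises AttributeError.

```python
from unittest import mock

# Roughly what @mock.patch('frontend.request', spec=object()) passes to a test
request_mock = mock.MagicMock(spec=object())

# Attributes the test sets explicitly behave like plain values
request_mock.method = 'POST'
request_mock.form = {'first': '50', 'second': '70'}

# Anything the test did not set is rejected instead of returning a new mock
try:
    request_mock.args
    unknown_attribute_rejected = False
except AttributeError:
    unknown_attribute_rejected = True
```

A side benefit is that a view reading an attribute the test forgot to prepare fails loudly instead of receiving a silent MagicMock.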

backend.py

%%writefile nbp/backend.py
import datetime

import requests


class NoData(Exception):
    pass


API_ENDPOINT = 'http://api.nbp.pl/api/exchangerates/tables/a/{}/?format=json'
DATE_FORMAT = '%Y-%m-%d'

def get_exchange_rate(currency, date):
    currency = currency.upper()
    date_as_str = date.strftime(DATE_FORMAT)
    url = API_ENDPOINT.format(date_as_str)
    response = requests.get(url)
    if response.status_code == 404:
        raise NoData
    json = response.json()
    rates = [rate['mid'] for rate in json[0]['rates']
             if rate['code'] == currency]
    try:
        rate = rates[0]
    except IndexError:
        raise ValueError("Invalid currency")
    else:
        return rate

frontend.py

%%writefile nbp/frontend.py
from datetime import datetime

from flask import Flask, request

import backend

app = Flask(__name__)

HOME_PAGE_TEMPLATE = """
<h1>Home Page</h1>
<p>{msg}</p>
<form method="POST">
  <p>Date: <input type="text" name="date" /></p>
  <p>Currency
    <select name="currency">
      <option value="USD">dolar amerykanski USD</option>
      <option value="THB">bat (Tajlandia) THB</option>
      <option value="ISK">korona islandzka ISK</option>
    </select>
  </p>
  <p><input type="submit" value="Get exchange rate!"></p>
</form>
"""
EXCHANGE_RATE_TEMPLATE = \
    '1.00 {currency} = {rate} PLN'
NO_DATA_MSG = 'No data for this day.'
DATE_FORMAT = '%Y/%m/%d'

@app.route('/', methods=['GET', 'POST'])
def home():
    raise NotImplementedError  # todo: handle GET and POST requests

if __name__ == "__main__":
    app.run()

test_backend.py

%%writefile nbp/test_backend.py
import datetime
import json
from unittest import mock

import pytest

from backend import get_exchange_rate, NoData

JSON = """[{"table":"A","no":"217/A/NBP/2019","effectiveDate":"2019-11-08","rates":[{"currency":"bat (Tajlandia)","code":"THB","mid":0.1272},{"currency":"dolar amerykański","code":"USD","mid":3.8625},{"currency":"dolar australijski","code":"AUD","mid":2.6533},{"currency":"dolar Hongkongu","code":"HKD","mid":0.4935},{"currency":"dolar kanadyjski","code":"CAD","mid":2.9263},{"currency":"dolar nowozelandzki","code":"NZD","mid":2.4520},{"currency":"dolar singapurski","code":"SGD","mid":2.8406},{"currency":"euro","code":"EUR","mid":4.2638},{"currency":"forint (Węgry)","code":"HUF","mid":0.01278},{"currency":"frank szwajcarski","code":"CHF","mid":3.8797},{"currency":"funt szterling","code":"GBP","mid":4.9476},{"currency":"hrywna (Ukraina)","code":"UAH","mid":0.1577},{"currency":"jen (Japonia)","code":"JPY","mid":0.035324},{"currency":"korona czeska","code":"CZK","mid":0.1669},{"currency":"korona duńska","code":"DKK","mid":0.5706},{"currency":"korona islandzka","code":"ISK","mid":0.030964},{"currency":"korona norweska","code":"NOK","mid":0.4221},{"currency":"korona szwedzka","code":"SEK","mid":0.3991},{"currency":"kuna (Chorwacja)","code":"HRK","mid":0.5739},{"currency":"lej rumuński","code":"RON","mid":0.8956},{"currency":"lew (Bułgaria)","code":"BGN","mid":2.1800},{"currency":"lira turecka","code":"TRY","mid":0.6713},{"currency":"nowy izraelski szekel","code":"ILS","mid":1.1053},{"currency":"peso chilijskie","code":"CLP","mid":0.005206},{"currency":"peso filipińskie","code":"PHP","mid":0.0764},{"currency":"peso meksykańskie","code":"MXN","mid":0.2012},{"currency":"rand (Republika Południowej Afryki)","code":"ZAR","mid":0.2603},{"currency":"real (Brazylia)","code":"BRL","mid":0.9419},{"currency":"ringgit (Malezja)","code":"MYR","mid":0.9344},{"currency":"rubel rosyjski","code":"RUB","mid":0.0605},{"currency":"rupia indonezyjska","code":"IDR","mid":0.00027562},{"currency":"rupia indyjska","code":"INR","mid":0.054185},{"currency":"won południowokoreański","code":"KRW","mid":0.003337},{"currency":"yuan renminbi (Chiny)","code":"CNY","mid":0.5523},{"currency":"SDR (MFW)","code":"XDR","mid":5.2969}]}]"""


def preconfigure_requests_mock(requests_mock):
    response = requests_mock.get.return_value
    response.status_code = 200
    response.json.return_value = json.loads(JSON)

@mock.patch('backend.requests')
def test_should_return_exchange_rate_for_a_working_date(requests_mock):
    preconfigure_requests_mock(requests_mock)
    date = datetime.date(2019, 11, 8)
    rate = get_exchange_rate('USD', date)
    assert rate == 3.8625
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')

@mock.patch('backend.requests')
def test_should_raise_NoData_for_a_nonworking_date(requests_mock):
    response = requests_mock.get.return_value
    response.status_code = 404
    date = datetime.date(2019, 11, 9)
    with pytest.raises(NoData):
        get_exchange_rate('USD', date)
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-09/?format=json')

@mock.patch('backend.requests')
def test_should_raise_ValueError_when_currency_not_found(requests_mock):
    preconfigure_requests_mock(requests_mock)
    date = datetime.date(2019, 11, 8)
    with pytest.raises(ValueError):
        get_exchange_rate('XYZ', date)
    requests_mock.get.assert_called_once_with(
        'http://api.nbp.pl/api/exchangerates/tables/a/2019-11-08/?format=json')

test_frontend.py

%%writefile nbp/test_frontend.py
import datetime
from unittest import mock

import frontend
import backend

@mock.patch # todo
class TestHomePage:
    def test_home_page_should_have_exchange_rate_form(self, request_mock):
        # todo

    @mock.patch # todo
    def test_home_page_should_work_for_working_days(self, backend_mock, request_mock):
        # todo

    @mock.patch # todo
    def test_home_page_should_work_for_non_working_days(self, get_exchange_rate_mock, request_mock):
        # todo

features/exchange_rate.feature

%%writefile nbp/features/exchange_rate.feature
Feature: Exchange rate form
    In order to become rich
    As a future currency speculator
    I'd like to know currency exchange rates

    Scenario: Navigation to Home Page
        When I navigate to Home Page
        Then Home Page should be displayed

    Scenario Outline: Exchange rate form works
    # todo


    Examples:
        | date       | currency | expected output         |
        | 2017/02/03 | USD      | 1.00 USD = 4.0014 PLN   |
        | 2017/02/03 | ISK      | 1.00 ISK = 0.035263 PLN |
        | 2017/02/05 | USD      | No data for this day    |

features/environment.py

%%writefile nbp/features/environment.py
from webtest import TestApp
from frontend import app

def before_scenario(context, scenario):
    context.client = TestApp(app)

features/steps/exchange_rate_steps.py

%%writefile nbp/features/steps/exchange_rate_steps.py

# todo

SOLUTION

frontend.py

%%writefile nbp/frontend.py
from datetime import datetime

from flask import Flask, request

import backend

app = Flask(__name__)

HOME_PAGE_TEMPLATE = """
<h1>Home Page</h1>
<p>{msg}</p>
<form method="POST">
  <p>Date: <input type="text" name="date" /></p>
  <p>Currency
    <select name="currency">
      <option value="USD">dolar amerykanski USD</option>
      <option value="THB">bat (Tajlandia) THB</option>
      <option value="ISK">korona islandzka ISK</option>
    </select>
  </p>
  <p><input type="submit" value="Get exchange rate!"></p>
</form>
"""
EXCHANGE_RATE_TEMPLATE = \
    '1.00 {currency} = {rate} PLN'
NO_DATA_MSG = 'No data for this day.'
DATE_FORMAT = '%Y/%m/%d'

@app.route('/', methods=['GET', 'POST'])
def home():
    if request.method == 'POST':
        date_as_str = request.form['date']
        date = datetime.strptime(date_as_str, DATE_FORMAT).date()
        currency = request.form['currency']

        try:
            rate = backend.get_exchange_rate(
                date=date, currency=currency)
        except backend.NoData:
            msg = NO_DATA_MSG
        else:
            msg = EXCHANGE_RATE_TEMPLATE.format(
                currency=currency,
                rate=rate)
    else:
        msg = ''
    return HOME_PAGE_TEMPLATE.format(msg=msg)

if __name__ == "__main__":
    app.run()

test_frontend.py

%%writefile nbp/test_frontend.py
import datetime
from unittest import mock

import frontend
import backend

@mock.patch('frontend.request', spec=object())
class TestHomePage:
    def test_home_page_should_have_exchange_rate_form(self, request_mock):
        request_mock.method = 'GET'
        html = frontend.home()
        assert '<form' in html
        assert 'name="date"' in html

    @mock.patch('frontend.backend')
    def test_home_page_should_work_for_working_days(self, backend_mock, request_mock):
        request_mock.method = 'POST'
        request_mock.form = dict(date='2019/11/08', currency='USD')
        backend_mock.get_exchange_rate.return_value = 4.321
        html = frontend.home()
        assert '1.00 USD = 4.321 PLN' in html
        backend_mock.get_exchange_rate.assert_called_once_with(
            date=datetime.date(2019, 11, 8),
            currency='USD')

    @mock.patch('frontend.backend.get_exchange_rate')
    def test_home_page_should_work_for_non_working_days(self, get_exchange_rate_mock, request_mock):
        request_mock.method = 'POST'
        request_mock.form = dict(date='2019/11/09', currency='USD')
        get_exchange_rate_mock.side_effect = backend.NoData
        html = frontend.home()
        assert 'No data for this day.' in html
        get_exchange_rate_mock.assert_called_once_with(
            date=datetime.date(2019, 11, 9),
            currency='USD')

features/exchange_rate.feature

%%writefile nbp/features/exchange_rate.feature
Feature: Exchange rate form
    In order to become rich
    As a future currency speculator
    I'd like to know currency exchange rates

    Scenario: Navigation to Home Page
        When I navigate to Home Page
        Then Home Page should be displayed

    Scenario Outline: Exchange rate form works
        Given I navigate to Home Page
          And I enter <date> as date
          And I enter <currency> as currency
         When I sent the form
         Then <expected output> should be displayed

    Examples:
      | date       | currency | expected output         |
      | 2017/02/03 | USD      | 1.00 USD = 4.0014 PLN   |
      | 2017/02/03 | ISK      | 1.00 ISK = 0.035263 PLN |
      | 2017/02/05 | USD      | No data for this day    |

features/steps/exchange_rate_steps.py

%%writefile nbp/features/steps/exchange_rate_steps.py
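from behave import given, when, then, step

# @step matches Given, When and Then alike, so this single definition
# replaces the two decorators below:
# @given('I navigate to Home Page')
# @when('I navigate to Home Page')
@step('I navigate to Home Page')
def step_impl(ctx):
    ctx.resp = ctx.client.get('/')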

@given('I enter {value} as {field}')
def step_impl(ctx, value, field):
    ctx.resp.form[field] = value

@when('I sent the form')
def step_impl(ctx):
    ctx.resp = ctx.resp.form.submit()

@then('{text} should be displayed')
def step_impl(ctx, text):
    assert text in ctx.resp

Redis

This module takes you from installing Redis to using it from Python. It is planned to fit into half a day and leads up to the twitter.py exercise.

Redis (Remote Dictionary Server) is an open-source, in-memory data structure store that functions as a versatile database, cache, and message broker. Renowned for its speed and efficiency, Redis is widely used in modern applications to enhance performance, manage real-time data, and facilitate complex data operations. Its ability to handle various data structures with atomic operations makes it a powerful tool for developers seeking both simplicity and performance.

Key Features of Redis

Redis boasts a rich set of features that cater to a wide range of use cases. Below are its core functionalities:

In-Memory Data Storage

  • Speed: By storing data in RAM, Redis provides ultra-fast read and write operations, often completing them in sub-millisecond times.
  • Latency: Minimal latency makes Redis ideal for applications requiring real-time data processing and quick response times.

Persistence Options

While Redis is primarily an in-memory store, it offers multiple persistence mechanisms to ensure data durability:

  • Snapshotting (RDB): Periodically saves the dataset to disk, allowing for point-in-time recovery.
  • Append-Only File (AOF): Logs every write operation received by the server, enabling a more granular recovery of data.
  • Hybrid Approach: Combines both RDB and AOF for optimal performance and data safety.

Rich Data Structures

Redis supports a variety of data structures, each optimized for specific operations:

  • Strings: Simple key-value pairs, suitable for caching and simple data storage.
  • Hashes: Maps between string fields and string values, ideal for representing objects.
  • Lists: Ordered collections of strings, useful for queues and stacks.
  • Sets: Unordered collections of unique strings, perfect for membership testing and uniqueness constraints.
  • Sorted Sets (ZSets): Similar to sets but with an associated score for each member, enabling ranking and ordering.
  • Bitmaps, HyperLogLogs, and Geospatial Indexes: Specialized data structures for advanced use cases: tracking individual bits, estimating the number of unique items, and handling location-based data.

Atomic Operations

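Each individual Redis command is atomic: the server executes commands one at a time, so a command always runs to completion without other commands interleaving. This protects single commands such as INCR in concurrent environments. A sequence of separate commands (for example GET followed by SET) is not atomic as a whole.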

High Availability and Scalability

  • Replication: Redis supports master-replica replication (called master-slave in older documentation), in which data is copied from one server to one or more replicas for redundancy and read scaling.
  • Sentinel: Provides high availability by monitoring Redis instances and performing automatic failover in case of failures.
  • Redis Cluster: Enables horizontal scaling by partitioning data across multiple Redis nodes, ensuring seamless scalability and fault tolerance.

Pub/Sub Messaging

Redis includes a Publish/Subscribe messaging paradigm, allowing messages to be broadcast to multiple subscribers. This feature is useful for real-time messaging, event notification systems, and inter-service communication.
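
A toy in-process model of the publish/subscribe idea described above. It does not talk to Redis, and ToyBroker with its methods is a hypothetical name chosen for this sketch. It mirrors two properties of Redis Pub/Sub: every current subscriber of a channel receives each message, and a message published while nobody is subscribed is not stored.

```python
from collections import defaultdict


class ToyBroker:
    """In-memory stand-in for a Pub/Sub broker (illustration only)."""

    def __init__(self):
        # channel name -> list of subscriber callbacks
        self._subscribers = defaultdict(list)

    def subscribe(self, channel, callback):
        self._subscribers[channel].append(callback)

    def publish(self, channel, message):
        """Deliver to current subscribers; return how many received it."""
        receivers = self._subscribers.get(channel, [])
        for callback in receivers:
            callback(message)
        return len(receivers)


broker = ToyBroker()
inbox_a, inbox_b = [], []

# Nobody is listening yet, so this message is simply dropped.
print(broker.publish('news', 'too early'))  # 0

broker.subscribe('news', inbox_a.append)
broker.subscribe('news', inbox_b.append)
print(broker.publish('news', 'hello'))  # 2
print(inbox_a, inbox_b)  # ['hello'] ['hello']
```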

Transactions

Redis transactions (MULTI/EXEC) queue a group of commands and then execute them sequentially as one isolated unit: no command from another client runs in between. Redis does not roll back, however. If a queued command fails at run time, the remaining commands are still executed.

Lightweight and Portable

Redis is lightweight, easy to install, and runs on various platforms, including Linux, macOS, and Windows (via WSL or Docker). Its portability makes it accessible for diverse development environments.

Common Use Cases for Redis

Redis's flexibility and performance make it suitable for a wide array of applications:

  1. Caching:

    • Purpose: Reduce latency and offload database queries by storing frequently accessed data in Redis.
    • Example: Caching API responses, session data, or user profiles.
  2. Real-Time Analytics:

    • Purpose: Handle high-speed data ingestion and provide immediate insights.
    • Example: Tracking website metrics, monitoring application performance, or analyzing user behavior in real-time.
  3. Session Management:

    • Purpose: Store user session information efficiently with quick access times.
    • Example: Managing user authentication tokens, shopping cart data, or user preferences.
  4. Message Queues:

    • Purpose: Implement reliable and scalable message queuing systems for asynchronous processing.
    • Example: Task scheduling, background job processing, or inter-service communication in microservices architectures.
  5. Leaderboard and Counting:

    • Purpose: Maintain dynamic leaderboards and counters with real-time updates.
    • Example: Gaming leaderboards, vote counts, or view counters.
  6. Geospatial Applications:

    • Purpose: Store and query location-based data efficiently.
    • Example: Finding nearby users, tracking delivery vehicles, or mapping services.
  7. Full-Page Caching:

    • Purpose: Cache entire web pages to accelerate web application performance.
    • Example: E-commerce websites serving cached product pages to reduce server load.
  8. Rate Limiting:

    • Purpose: Control the rate of requests to APIs or services to prevent abuse and ensure fair usage.
    • Example: Limiting the number of login attempts or API calls per user within a specific timeframe.
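
Rate limiting (item 8) is often built on INCR plus a key expiry, a so-called fixed-window counter. A minimal sketch, assuming a client object that offers redis-py-style incr(key) and expire(key, seconds). FakeClient, allow_request, and the key layout rate:<user> are hypothetical and exist only so that the sketch runs without a server:

```python
import time


class FakeClient:
    """Tiny in-memory stand-in with redis-py-like incr/expire (illustration only)."""

    def __init__(self):
        self._values = {}
        self._expires_at = {}

    def _purge(self, key):
        # Drop the key if its time-to-live has elapsed.
        deadline = self._expires_at.get(key)
        if deadline is not None and time.monotonic() >= deadline:
            self._values.pop(key, None)
            self._expires_at.pop(key, None)

    def incr(self, key):
        self._purge(key)
        self._values[key] = self._values.get(key, 0) + 1
        return self._values[key]

    def expire(self, key, seconds):
        if key not in self._values:
            return False
        self._expires_at[key] = time.monotonic() + seconds
        return True


def allow_request(client, user_id, limit=3, window_seconds=60):
    """Return True while the user stays within `limit` calls per window."""
    key = f'rate:{user_id}'
    count = client.incr(key)
    if count == 1:
        # First hit in this window: start the countdown.
        client.expire(key, window_seconds)
    return count <= limit


client = FakeClient()
print([allow_request(client, 'alice') for _ in range(4)])
# [True, True, True, False]
```

Against a real server you would pass a redis.Redis client instead of FakeClient. Note that INCR and EXPIRE are two separate commands here, so a production version would need to consider what happens if the client stops between them.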

Fundamentals

Redis Installation

Before diving into Redis, it's essential to install and set it up correctly on your system. Below are the steps to install Redis, particularly focusing on Windows users utilizing the Windows Subsystem for Linux (WSL).

Installation Steps:

  1. Set Up WSL (Windows Subsystem for Linux):

    If you're on Windows, it's recommended to use WSL for a seamless Redis experience. Follow the official guide to install WSL.

  2. Install Redis:

    Open your WSL terminal and execute the following commands:

    sudo apt update
    sudo apt install redis-server
    
  3. Start Redis Server:

    After installation, start the Redis server using:

    sudo service redis-server start
    
  4. Verify Installation:

    Test if Redis is running correctly by pinging the server:

    redis-cli ping
    

    You should receive a response:

    PONG

  5. Install redis-py for Python Integration:

    To interact with Redis using Python, install the redis library:

    pip install redis
    

Basic Python Connection:

Here's a simple Python script to connect to your local Redis server and perform basic operations:

import redis

# Create a connection to the localhost Redis server instance
# By default, Redis runs on port 6379
redis_db = redis.Redis(host="localhost", port=6379, db=0)

# Retrieve all keys (should be empty initially)
print(redis_db.keys())  # Output: []

Explanation:

  • redis.Redis: Creates a client object for the Redis server. The network connection is opened lazily, when the first command is sent.
  • host & port: Specify the server address and port.
  • db: Specifies the database number (default is 0).
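
By default redis-py returns keys and values as bytes; passing decode_responses=True to redis.Redis(...) makes it return str instead. The sketch below shows the conversion you would otherwise do by hand. to_text is a hypothetical helper and the sample values are made up, so no server is needed:

```python
def to_text(value, encoding='utf-8'):
    """Decode a Redis reply to str; pass str and None through unchanged."""
    if isinstance(value, bytes):
        return value.decode(encoding)
    return value


raw_keys = [b'db_url', b'counter']     # shape of keys() without decode_responses
print([to_text(k) for k in raw_keys])  # ['db_url', 'counter']

raw_counter = b'1'                     # shape of get('counter') without decoding
print(int(to_text(raw_counter)) + 1)   # 2
```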


Using the Redis CLI

The Redis Command-Line Interface (CLI) is a powerful tool for interacting with your Redis server directly. Below are common commands and their usages.

Common Redis CLI Commands:

127.0.0.1:6379> FLUSHALL
OK

127.0.0.1:6379> KEYS
(error) ERR wrong number of arguments for 'keys' command

127.0.0.1:6379> KEYS *
(empty array)

127.0.0.1:6379> SET db_url "asgdfhdsaghfgghdsa"
OK

127.0.0.1:6379> KEYS *
1) "db_url"

127.0.0.1:6379> GET db_url
"asgdfhdsaghfgghdsa"

127.0.0.1:6379> SET counter 0
OK

127.0.0.1:6379> INCR counter
(integer) 1

127.0.0.1:6379> GET counter
"1"

Explanation of Commands:

  • FLUSHALL: Removes all keys from all databases.
  • KEYS *: Lists all keys in the current database. KEYS scans the whole key space, so on large production databases prefer SCAN.
  • SET key value: Sets the value of a key.
  • GET key: Retrieves the value of a key.
  • INCR key: Increments the integer value of a key by one.

Understanding Race Conditions

Race conditions occur when multiple clients read and modify the same data concurrently, leading to unexpected results. Each single Redis command is atomic, but a read-modify-write sequence built from several commands (GET, then SET) is not, so other clients can slip in between the steps.

Example Without Race Condition:

# Initialize the counter
SET counter 0

# Client A retrieves the counter
GET counter  # Returns 0

# Client A increments the counter
SET counter 1

# Client B retrieves the updated counter
GET counter  # Returns 1

# Client B increments the counter
SET counter 2

Outcome: The counter goes from 0 to 2 as expected, but only because the two clients happened to run strictly one after the other.

Example With Race Condition:

# Initialize the counter
SET counter 0

# Client A retrieves the counter
GET counter  # Returns 0

# Client B retrieves the counter
GET counter  # Also returns 0

# Both clients increment the counter based on the retrieved value
SET counter 1  # Client A
SET counter 1  # Client B

Outcome: Instead of the counter being 2, it remains at 1 due to both clients reading the same initial value and setting it independently.

Mitigating Race Conditions:

To prevent race conditions, use Redis's atomic operations such as INCR:

# Initialize the counter
SET counter 0

# Client A increments the counter
INCR counter  # Counter becomes 1

# Client B increments the counter
INCR counter  # Counter becomes 2

Outcome: The counter correctly increments to 2, ensuring data integrity.
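
The interleavings above can be replayed in plain Python. This is a model, not a Redis client: the dictionary store stands in for the server, and the function names are hypothetical.

```python
def lost_update():
    """Both clients read first, then both write: one increment is lost."""
    store = {'counter': 0}
    seen_by_a = store['counter']      # Client A: GET counter -> 0
    seen_by_b = store['counter']      # Client B: GET counter -> 0
    store['counter'] = seen_by_a + 1  # Client A: SET counter 1
    store['counter'] = seen_by_b + 1  # Client B: SET counter 1
    return store['counter']


def atomic_increment():
    """Each client uses a single read-modify-write step, like INCR."""
    store = {'counter': 0}

    def incr(key):
        # The server performs read + add + write as one indivisible command.
        store[key] += 1
        return store[key]

    incr('counter')  # Client A
    incr('counter')  # Client B
    return store['counter']


print(lost_update())       # 1
print(atomic_increment())  # 2
```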

Interacting with Redis in Python

Leveraging Redis within Python applications allows for efficient data storage and retrieval. Below are examples of common operations using the redis-py library.

Setting Up the Connection:

import redis

# Create a connection to the localhost Redis server instance
# By default, Redis runs on port 6379
redis_db = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Flush all existing data
print(redis_db.flushall())  # Output: True

# Retrieve all keys (should be empty)
print(redis_db.keys())  # Output: []

# Set a key-value pair
print(redis_db.set('db_url', 'asgdfhdsaghfgghdsa'))  # Output: True

Performing Basic Operations:

# Retrieve all keys
print(redis_db.keys())  # Output: ['db_url']

# Get the value of a key
print(redis_db.get('db_url'))  # Output: 'asgdfhdsaghfgghdsa'

# Initialize a counter
redis_db.set('counter', 0)

# Increment the counter
print(redis_db.incr('counter'))  # Output: 1

# Get the updated counter value
print(redis_db.get('counter'))  # Output: '1'

Explanation:

  • flushall(): Clears all data from Redis.
  • keys(): Retrieves all keys in the current database.
  • set(key, value): Sets the value for a specified key.
  • get(key): Retrieves the value of a specified key.
  • incr(key): Atomically increments the integer value of a key by one.

Advanced Operations:

To handle more complex scenarios and ensure atomicity, consider using transactions or Redis pipelines.

# Using a pipeline for atomic operations
pipeline = redis_db.pipeline()

pipeline.set('user:1000', 'John Doe')
pipeline.incr('counter')
pipeline.get('user:1000')
pipeline.get('counter')

results = pipeline.execute()
print(results)
# Output: [True, 2, 'John Doe', '2']

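Benefits of Pipelines:

  • Performance: Reduces the number of network round trips by sending a batch of commands at once.
  • Atomicity: In redis-py, pipeline() wraps the batch in a MULTI/EXEC transaction by default (transaction=True), so the commands run without other clients' commands in between. Pass transaction=False for pure batching.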
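
To make the performance point concrete, here is a toy model that counts network round trips. CountingClient and ToyPipeline are hypothetical classes, not redis-py. They only mimic the buffering behavior of a pipeline, under the assumption that every unbatched command costs one request/response exchange:

```python
class CountingClient:
    """In-memory stand-in that counts simulated round trips (illustration only)."""

    def __init__(self):
        self.data = {}
        self.round_trips = 0

    def set(self, key, value):
        self.round_trips += 1  # one request/response per command
        self.data[key] = value
        return True

    def pipeline(self):
        return ToyPipeline(self)


class ToyPipeline:
    """Buffers commands and sends them in one simulated round trip."""

    def __init__(self, client):
        self._client = client
        self._queued = []

    def set(self, key, value):
        self._queued.append((key, value))  # nothing is sent yet
        return self

    def execute(self):
        self._client.round_trips += 1      # the whole batch travels together
        results = []
        for key, value in self._queued:
            self._client.data[key] = value
            results.append(True)
        self._queued = []
        return results


one_by_one = CountingClient()
for i in range(100):
    one_by_one.set(f'key:{i}', i)

batched = CountingClient()
pipe = batched.pipeline()
for i in range(100):
    pipe.set(f'key:{i}', i)
pipe.execute()

print(one_by_one.round_trips, batched.round_trips)  # 100 1
```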

Redis Data Types

Understanding Redis data types is fundamental to leveraging Redis effectively in your applications. Redis supports a variety of data types, each optimized for specific use cases and operations. This section will explore these data types, their associated commands, and practical examples of how to use them both in the Redis CLI and within Python using the redis-py library.

Overview of Redis Data Types

Redis categorizes its data structures into the following types:

  1. Strings
  2. Lists
  3. Hashes
  4. Sets
  5. Sorted Sets
  6. Bitmaps
  7. HyperLogLogs

Additionally, Redis provides various commands for managing key spaces and other operations such as expiration and flushing databases.

For a comprehensive overview of Redis data types, refer to the Redis Data Types Introduction.

Querying the Key Space

Before diving into specific data types, it's essential to understand how to interact with the Redis key space. The key space is where all keys and their associated values are stored.

Common Key Space Commands

  • EXISTS key: Checks if a key exists.
  • SET key value: Sets the value of a key.
  • DEL key: Deletes a key.
  • TYPE key: Returns the data type of the value stored at the key.

Examples in Redis CLI

127.0.0.1:6379> SET greeting "Hello, Redis!"
OK

127.0.0.1:6379> EXISTS greeting
(integer) 1

127.0.0.1:6379> TYPE greeting
string

127.0.0.1:6379> DEL greeting
(integer) 1

127.0.0.1:6379> EXISTS greeting
(integer) 0

Examples in Python

import redis

# Connect to Redis
redis_db = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Set a key
redis_db.set('greeting', 'Hello, Redis!')

# Check if the key exists
print(redis_db.exists('greeting'))  # Output: 1

# Get the type of the key
print(redis_db.type('greeting'))    # Output: 'string'

# Delete the key
redis_db.delete('greeting')

# Verify deletion
print(redis_db.exists('greeting'))  # Output: 0

Strings

Description:
Strings are the simplest Redis data type, representing a sequence of bytes. They can store text, numbers, or binary data.

Common Commands:

  • SET key value: Sets the value of a key.
  • GET key: Retrieves the value of a key.
  • MSET key1 value1 key2 value2 ...: Sets multiple keys to multiple values.
  • MGET key1 key2 ...: Retrieves the values of multiple keys.
  • INCR key: Increments the integer value of a key by one.
  • INCRBY key increment: Increments the integer value of a key by a specified amount.

Examples in Redis CLI:

127.0.0.1:6379> SET user:1000 "John Doe"
OK

127.0.0.1:6379> GET user:1000
"John Doe"

127.0.0.1:6379> MSET user:1001 "Jane Smith" user:1002 "Alice Johnson"
OK

127.0.0.1:6379> MGET user:1000 user:1001 user:1002
1) "John Doe"
2) "Jane Smith"
3) "Alice Johnson"

127.0.0.1:6379> SET counter 10
OK

127.0.0.1:6379> INCR counter
(integer) 11

127.0.0.1:6379> INCRBY counter 5
(integer) 16

Examples in Python:

# Set a single key
redis_db.set('user:1000', 'John Doe')

# Get the value of a key
print(redis_db.get('user:1000'))  # Output: 'John Doe'

# Set multiple keys
redis_db.mset({'user:1001': 'Jane Smith', 'user:1002': 'Alice Johnson'})

# Get multiple keys
print(redis_db.mget('user:1000', 'user:1001', 'user:1002'))
# Output: ['John Doe', 'Jane Smith', 'Alice Johnson']

# Increment a counter
redis_db.set('counter', 10)
redis_db.incr('counter')
print(redis_db.get('counter'))  # Output: '11'

# Increment by a specific value
redis_db.incrby('counter', 5)
print(redis_db.get('counter'))  # Output: '16'

Lists

Description: Lists are ordered collections of strings, implemented as linked lists. They are ideal for implementing queues, stacks, and other ordered data structures.

Common Commands:

  • RPUSH key value [value ...]: Appends one or more values to the end of a list.
  • LPUSH key value [value ...]: Prepends one or more values to the beginning of a list.
  • RPOP key: Removes and returns the last element of the list.
  • LPOP key: Removes and returns the first element of the list.
  • LRANGE key start stop: Retrieves a range of elements from the list.
  • LLEN key: Returns the length of the list.
  • LTRIM key start stop: Trims the list to the specified range.
  • BRPOP key timeout: Removes and returns the last element of the list, blocking until an element is available or timeout is reached.

Examples in Redis CLI:

127.0.0.1:6379> RPUSH tasks "task1"
(integer) 1

127.0.0.1:6379> RPUSH tasks "task2" "task3"
(integer) 3

127.0.0.1:6379> LPUSH tasks "task0"
(integer) 4

127.0.0.1:6379> LRANGE tasks 0 -1
1) "task0"
2) "task1"
3) "task2"
4) "task3"

127.0.0.1:6379> LLEN tasks
(integer) 4

127.0.0.1:6379> RPOP tasks
"task3"

127.0.0.1:6379> LRANGE tasks 0 -1
1) "task0"
2) "task1"
3) "task2"

Examples in Python:

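# Start from an empty list so that re-running the example gives the same output
redis_db.delete('tasks')

# Push elements to the right of the list
redis_db.rpush('tasks', 'task1')
redis_db.rpush('tasks', 'task2', 'task3')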

# Push an element to the left of the list
redis_db.lpush('tasks', 'task0')

# Retrieve all elements in the list
print(redis_db.lrange('tasks', 0, -1))
# Output: ['task0', 'task1', 'task2', 'task3']

# Get the length of the list
print(redis_db.llen('tasks'))  # Output: 4

# Pop an element from the right
print(redis_db.rpop('tasks'))  # Output: 'task3'

# Retrieve the updated list
print(redis_db.lrange('tasks', 0, -1))
# Output: ['task0', 'task1', 'task2']
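
Which pair of commands you combine decides whether a list behaves as a queue or as a stack. The sketch models this with collections.deque. It does not contact Redis, and the helper names simply mirror the commands they imitate:

```python
from collections import deque

tasks = deque()


def rpush(items, *values):
    items.extend(values)  # append on the right, like RPUSH
    return len(items)


def lpop(items):
    return items.popleft() if items else None  # take from the left, like LPOP


def rpop(items):
    return items.pop() if items else None  # take from the right, like RPOP


# Queue (FIFO): push right, pop left.
rpush(tasks, 'task1', 'task2', 'task3')
print(lpop(tasks))  # task1

# Stack (LIFO): push right, pop right.
print(rpop(tasks))  # task3
print(list(tasks))  # ['task2']
```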

Hashes

Description:
Hashes are maps between string fields and string values, allowing you to store multiple key-value pairs under a single Redis key. They are ideal for representing objects with multiple attributes.

Common Commands:

  • HSET key field value: Sets the value of a field in a hash.
  • HGET key field: Retrieves the value of a field in a hash.
  • HGETALL key: Retrieves all fields and values in a hash.
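  • HMSET key field1 value1 field2 value2 ...: Sets multiple fields in a hash. HMSET is deprecated since Redis 4.0; HSET accepts multiple field-value pairs and should be used instead.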
  • HDEL key field: Deletes a field from a hash.
  • HINCRBY key field increment: Increments the integer value of a field by a specified amount.

Examples in Redis CLI:

127.0.0.1:6379> HSET user:1000 username "jan"
(integer) 1

127.0.0.1:6379> HSET user:1000 password "1234"
(integer) 1

127.0.0.1:6379> HSET user:1000 birthyear "1977"
(integer) 1

127.0.0.1:6379> HSET user:1000 verified "1"
(integer) 1

127.0.0.1:6379> HGETALL user:1000
1) "username"
2) "jan"
3) "password"
4) "1234"
5) "birthyear"
6) "1977"
7) "verified"
8) "1"

127.0.0.1:6379> HINCRBY user:1000 verified 1
(integer) 2

127.0.0.1:6379> HGET user:1000 verified
"2"

Examples in Python:

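# user:1000 may still hold the string from the Strings examples;
# HSET on a non-hash key fails with a WRONGTYPE error, so remove it first
redis_db.delete('user:1000')

# Set fields in a hash
redis_db.hset('user:1000', 'username', 'jan')
redis_db.hset('user:1000', 'password', '1234')
redis_db.hset('user:1000', 'birthyear', '1977')
redis_db.hset('user:1000', 'verified', '1')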

# Retrieve all fields and values
print(redis_db.hgetall('user:1000'))
# Output: {'username': 'jan', 'password': '1234', 'birthyear': '1977', 'verified': '1'}

# Increment a field
redis_db.hincrby('user:1000', 'verified', 1)
print(redis_db.hget('user:1000', 'verified'))  # Output: '2'
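
Hash fields and values are stored as strings, so typed attributes have to be converted on the way in and out. A minimal sketch of that mapping, using a hypothetical User dataclass that mirrors some of the fields above. The resulting dictionary has the shape you could pass to redis-py's hset(..., mapping=...) and would get back from hgetall when decode_responses=True:

```python
from dataclasses import dataclass, asdict


@dataclass
class User:
    username: str
    birthyear: int
    verified: bool


def to_hash(user):
    """Flatten a User into string fields suitable for a Redis hash."""
    fields = asdict(user)
    fields['birthyear'] = str(user.birthyear)
    fields['verified'] = '1' if user.verified else '0'
    return fields


def from_hash(fields):
    """Rebuild a User from the string fields of a hash."""
    return User(
        username=fields['username'],
        birthyear=int(fields['birthyear']),
        verified=fields['verified'] == '1',
    )


stored = to_hash(User('jan', 1977, True))
print(stored)  # {'username': 'jan', 'birthyear': '1977', 'verified': '1'}
print(from_hash(stored))  # User(username='jan', birthyear=1977, verified=True)
```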

Sets

Description:
Sets are unordered collections of unique strings. They are useful for membership testing, ensuring uniqueness, and performing set operations like unions and intersections.

Common Commands:

  • SADD key member1 member2 ...: Adds one or more members to a set.
  • SREM key member: Removes a member from a set.
  • SISMEMBER key member: Checks if a member exists in a set.
  • SMEMBERS key: Retrieves all members of a set.
  • SUNION key1 key2 ...: Returns the union of multiple sets.
  • SUNIONSTORE destination key1 key2 ...: Stores the union of multiple sets in a destination set.
  • SCARD key: Returns the number of members in a set.

Examples in Redis CLI:

127.0.0.1:6379> SADD tags "python" "redis" "database"
(integer) 3

127.0.0.1:6379> SADD tags "cache" "memory"
(integer) 2

127.0.0.1:6379> SMEMBERS tags
1) "python"
2) "redis"
3) "database"
4) "cache"
5) "memory"

127.0.0.1:6379> SISMEMBER tags "redis"
(integer) 1

127.0.0.1:6379> SREM tags "cache"
(integer) 1

127.0.0.1:6379> SMEMBERS tags
1) "python"
2) "redis"
3) "database"
4) "memory"

127.0.0.1:6379> SCARD tags
(integer) 4

127.0.0.1:6379> SADD tags2 "redis" "scaling"
(integer) 2

127.0.0.1:6379> SUNION tags tags2
1) "python"
2) "redis"
3) "database"
4) "memory"
5) "scaling"

127.0.0.1:6379> SUNIONSTORE tags_union tags tags2
(integer) 5

127.0.0.1:6379> SMEMBERS tags_union
1) "python"
2) "redis"
3) "database"
4) "memory"
5) "scaling"

Examples in Python:

# Add members to a set
redis_db.sadd('tags', 'python', 'redis', 'database')
redis_db.sadd('tags', 'cache', 'memory')

# Retrieve all members
print(redis_db.smembers('tags'))
# Output: {'python', 'redis', 'database', 'cache', 'memory'}

# Check membership
print(redis_db.sismember('tags', 'redis'))  # Output: 1 (truthy)
print(redis_db.sismember('tags', 'flask'))  # Output: 0 (falsy)

# Remove a member
redis_db.srem('tags', 'cache')
print(redis_db.smembers('tags'))
# Output: {'python', 'redis', 'database', 'memory'}

# Get the number of members
print(redis_db.scard('tags'))  # Output: 4

# Perform set union
redis_db.sadd('tags2', 'redis', 'scaling')
union = redis_db.sunion('tags', 'tags2')
print(union)
# Output: {'python', 'redis', 'database', 'memory', 'scaling'}

# Store the union in a new set
redis_db.sunionstore('tags_union', 'tags', 'tags2')
print(redis_db.smembers('tags_union'))
# Output: {'python', 'redis', 'database', 'memory', 'scaling'}
{b'redis', b'memory', b'cache', b'python', b'database'}
1
0
{b'memory', b'redis', b'database', b'python'}
4
{b'redis', b'scaling', b'memory', b'python', b'database'}
{b'redis', b'scaling', b'memory', b'python', b'database'}

Sorted Sets

Description:
Sorted sets are similar to sets but with an associated score for each member. Members are ordered based on their scores, allowing for efficient range queries and ranking.

Common Commands:

  • ZADD key score1 member1 score2 member2 ...: Adds one or more members with associated scores to a sorted set.
  • ZRANGE key start stop [WITHSCORES]: Retrieves a range of members by index.
  • ZREVRANGE key start stop [WITHSCORES]: Retrieves a range of members by index in reverse order.
  • ZRANGEBYSCORE key min max [WITHSCORES]: Retrieves members within a score range.
  • ZRANK key member: Returns the rank of a member.
  • ZREM key member: Removes a member from a sorted set.
  • ZREMRANGEBYSCORE key min max: Removes all members within a score range.
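
The ordering and index rules behind these commands can be sketched in plain Python, without a Redis server. This is only an in-memory model: the helper functions `zrange`, `zrank`, and `zrangebyscore` and the `scores` data are hypothetical names chosen for illustration, not redis-py calls. It assumes members are ordered by score, ties are broken lexicographically by member, ranks are 0-based, and the stop index is inclusive.

```python
def zrange(zset, start, stop):
    """Return members ordered by (score, member); stop is inclusive, as in ZRANGE."""
    ordered = [member for member, _ in sorted(zset.items(), key=lambda kv: (kv[1], kv[0]))]
    n = len(ordered)
    if start < 0:
        start = max(n + start, 0)
    if stop < 0:
        stop = n + stop
    if stop < 0:
        return []
    return ordered[start:stop + 1]


def zrank(zset, member):
    """Return the 0-based rank of a member, or None if it is absent."""
    if member not in zset:
        return None
    return zrange(zset, 0, -1).index(member)


def zrangebyscore(zset, lo, hi):
    """Return members whose score lies in the closed interval [lo, hi]."""
    return [m for m in zrange(zset, 0, -1) if lo <= zset[m] <= hi]


# Hypothetical leaderboard: member -> score
scores = {"kamil": 30, "jan": 10, "ala": 30}

print(zrange(scores, 0, -1))          # ['jan', 'ala', 'kamil']
print(zrank(scores, "kamil"))         # 2
print(zrangebyscore(scores, 20, 40))  # ['ala', 'kamil']
print(zrange(scores, -2, -1))         # ['ala', 'kamil']
```

Note that "ala" and "kamil" share a score, so their relative order is decided by the member name.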

Examples in Redis CLI:

127.0.0.1:6379> ZADD hackers 1912 "Alan Turing"
(integer) 1

127.0.0.1:6379> ZADD hackers 1940 "Alan Kay"
(integer) 1

127.0.0.1:6379> ZADD hackers 1957 "Sophie Wilson"
(integer) 1

127.0.0.1:6379> ZRANGE hackers 0 -1 WITHSCORES
1) "Alan Turing"
2) "1912"
3) "Alan Kay"
4) "1940"
5) "Sophie Wilson"
6) "1957"

127.0.0.1:6379> ZRANK hackers "Alan Kay"
(integer) 1

127.0.0.1:6379> ZREM hackers "Alan Turing"
(integer) 1

127.0.0.1:6379> ZRANGE hackers 0 -1 WITHSCORES
1) "Alan Kay"
2) "1940"
3) "Sophie Wilson"
4) "1957"

Examples in Python:

# Add members with scores to a sorted set
redis_db.zadd('hackers', {'Alan Turing': 1912, 'Alan Kay': 1940, 'Sophie Wilson': 1957})

# Retrieve all members with scores
print(redis_db.zrange('hackers', 0, -1, withscores=True))
# Output: [('Alan Turing', 1912.0), ('Alan Kay', 1940.0), ('Sophie Wilson', 1957.0)]

# Get the rank of a member (0-based)
print(redis_db.zrank('hackers', 'Alan Kay'))  # Output: 1

# Remove a member
redis_db.zrem('hackers', 'Alan Turing')
print(redis_db.zrange('hackers', 0, -1, withscores=True))
# Output: [('Alan Kay', 1940.0), ('Sophie Wilson', 1957.0)]
[(b'Alan Turing', 1912.0), (b'Alan Kay', 1940.0), (b'Sophie Wilson', 1957.0)]
1
[(b'Alan Kay', 1940.0), (b'Sophie Wilson', 1957.0)]

Bitmaps

Description:
Bitmaps allow you to manipulate individual bits within a string value. They are useful for scenarios like tracking user activities, feature flags, and storing compact data.

Common Commands:

  • SETBIT key offset value: Sets or clears the bit at the specified offset.
  • GETBIT key offset: Retrieves the bit value at the specified offset.
  • BITCOUNT key [start end]: Counts the number of set bits (1s) in a string.
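
To make the offset convention concrete, here is a minimal pure-Python sketch of the three commands operating on a `bytearray`. The helper names `setbit`, `getbit`, and `bitcount` are hypothetical stand-ins, not redis-py calls. The sketch assumes the Redis convention that offset 0 is the most significant bit of the first byte, and that the string grows with zero bytes when an offset lies past its end.

```python
def setbit(buf: bytearray, offset: int, value: int) -> int:
    """Set or clear the bit at `offset`; return the previous bit, like SETBIT."""
    byte_index, bit_index = divmod(offset, 8)
    # Grow the buffer with zero bytes if the offset lies past its end.
    if byte_index >= len(buf):
        buf.extend(b"\x00" * (byte_index + 1 - len(buf)))
    mask = 0x80 >> bit_index  # offset 0 is the most significant bit
    previous = 1 if buf[byte_index] & mask else 0
    if value:
        buf[byte_index] |= mask
    else:
        buf[byte_index] &= ~mask & 0xFF
    return previous


def getbit(buf: bytearray, offset: int) -> int:
    """Return the bit at `offset`; bits past the end read as 0."""
    byte_index, bit_index = divmod(offset, 8)
    if byte_index >= len(buf):
        return 0
    return 1 if buf[byte_index] & (0x80 >> bit_index) else 0


def bitcount(buf: bytearray) -> int:
    """Count the bits set to 1 in the whole buffer."""
    return sum(bin(byte).count("1") for byte in buf)


active = bytearray()
print(setbit(active, 7, 1))  # 0 (the previous value of the bit)
print(setbit(active, 0, 1))  # 0
print(bytes(active))         # b'\x81'
print(bitcount(active))      # 2
```

Setting bits 0 and 7 therefore yields the single byte 0x81, which is why a bitmap with a few flags stays very compact.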

Examples in Redis CLI:

127.0.0.1:6379> SETBIT user:1000:active 7 1
(integer) 0

127.0.0.1:6379> GETBIT user:1000:active 7
(integer) 1

127.0.0.1:6379> SETBIT user:1000:active 0 1
(integer) 0

127.0.0.1:6379> GETBIT user:1000:active 0
(integer) 1

127.0.0.1:6379> BITCOUNT user:1000:active
(integer) 2

Examples in Python:

# Set bits at specific offsets
redis_db.setbit('user:1000:active', 7, 1)
redis_db.setbit('user:1000:active', 0, 1)

# Get bits at specific offsets
print(redis_db.getbit('user:1000:active', 7))  # Output: 1
print(redis_db.getbit('user:1000:active', 0))  # Output: 1

# Count the number of set bits
print(redis_db.bitcount('user:1000:active'))  # Output: 2
1
1
2

HyperLogLogs

Description:
HyperLogLogs are probabilistic data structures used to estimate the cardinality (number of unique elements) of a dataset with minimal memory usage. They are ideal for large-scale unique counts, such as website visitors.

Common Commands:

  • PFADD key element1 element2 ...: Adds elements to the HyperLogLog.
  • PFCOUNT key: Returns the approximate cardinality of the HyperLogLog.
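
The idea behind the estimate can be illustrated with a simplified HyperLogLog in pure Python. This is a teaching sketch, not the implementation Redis uses: the class name `TinyHLL`, the choice of SHA-1 as the hash, and the 1024 registers are all assumptions made for the example. Each element is hashed, the first bits select a register, and the register remembers the longest run of leading zeros seen so far.

```python
import hashlib
import math


class TinyHLL:
    """Simplified HyperLogLog with 2**p registers (illustrative only)."""

    def __init__(self, p: int = 10):
        self.p = p
        self.m = 1 << p                  # number of registers
        self.registers = [0] * self.m

    def add(self, item: str) -> None:
        # Take the first 8 bytes of SHA-1 as a 64-bit hash value.
        h = int.from_bytes(hashlib.sha1(item.encode()).digest()[:8], "big")
        index = h >> (64 - self.p)                 # first p bits pick a register
        rest = h & ((1 << (64 - self.p)) - 1)      # remaining 64 - p bits
        # Position of the leftmost 1-bit in the remaining bits (1-based).
        rank = (64 - self.p) - rest.bit_length() + 1
        self.registers[index] = max(self.registers[index], rank)

    def count(self) -> int:
        alpha = 0.7213 / (1 + 1.079 / self.m)      # bias correction for m >= 128
        estimate = alpha * self.m ** 2 / sum(2.0 ** -r for r in self.registers)
        zeros = self.registers.count(0)
        if estimate <= 2.5 * self.m and zeros:
            # Small cardinalities: fall back to linear counting.
            estimate = self.m * math.log(self.m / zeros)
        return round(estimate)


hll = TinyHLL()
for i in range(10000):
    hll.add(f"user{i}")
first_estimate = hll.count()

# Adding the same visitors again does not change the registers.
for i in range(10000):
    hll.add(f"user{i}")

print(first_estimate, hll.count())
```

The sketch stores only 1024 small integers regardless of how many elements are added, at the cost of an error of a few percent; duplicates never change the result.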

Examples in Redis CLI:

127.0.0.1:6379> PFADD unique_visitors user1
(integer) 1

127.0.0.1:6379> PFADD unique_visitors user2 user3 user4
(integer) 1

127.0.0.1:6379> PFCOUNT unique_visitors
(integer) 4

127.0.0.1:6379> PFADD unique_visitors user2 user5
(integer) 1

127.0.0.1:6379> PFCOUNT unique_visitors
(integer) 5

Examples in Python:

# Add elements to HyperLogLog
redis_db.pfadd('unique_visitors', 'user1')
redis_db.pfadd('unique_visitors', 'user2', 'user3', 'user4')

# Get the approximate count
print(redis_db.pfcount('unique_visitors'))  # Output: 4

# Add more elements, including duplicates
redis_db.pfadd('unique_visitors', 'user2', 'user5')

# Get the updated count
print(redis_db.pfcount('unique_visitors'))  # Output: 5
4
5

Other Essential Commands

In addition to data type-specific commands, Redis provides several other commands for managing keys and the database.

Expiration and Time-to-Live (TTL)

  • EXPIRE key seconds: Sets a timeout on a key.
  • TTL key: Retrieves the remaining time-to-live of a key.
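
The return values of TTL are easy to mix up, so here is a small in-memory model of the semantics. The class `TinyStore` and its injectable `clock` parameter are hypothetical, introduced only so the example runs without a server and without waiting. It assumes the Redis conventions that TTL returns -2 for a missing key and -1 for a key without expiration, and that a plain SET clears any existing timeout.

```python
class TinyStore:
    """In-memory model of SET / EXPIRE / TTL (illustrative only)."""

    def __init__(self, clock):
        self._clock = clock        # callable returning the current time in seconds
        self._data = {}
        self._expires_at = {}

    def _purge(self, key):
        # Drop the key lazily once its deadline has passed.
        deadline = self._expires_at.get(key)
        if deadline is not None and self._clock() >= deadline:
            self._data.pop(key, None)
            self._expires_at.pop(key, None)

    def set(self, key, value):
        self._data[key] = value
        self._expires_at.pop(key, None)   # a plain SET removes any timeout

    def expire(self, key, seconds):
        self._purge(key)
        if key not in self._data:
            return 0                      # no such key, nothing to expire
        self._expires_at[key] = self._clock() + seconds
        return 1

    def ttl(self, key):
        self._purge(key)
        if key not in self._data:
            return -2                     # key does not exist
        if key not in self._expires_at:
            return -1                     # key exists but never expires
        return int(self._expires_at[key] - self._clock())


now = [0.0]                               # fake clock we can advance by hand
store = TinyStore(clock=lambda: now[0])

store.set("session:abc123", "user data")
print(store.ttl("session:abc123"))        # -1
print(store.expire("session:abc123", 3600))  # 1
print(store.ttl("session:abc123"))        # 3600
now[0] = 3600.0                           # one hour later
print(store.ttl("session:abc123"))        # -2
```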

Examples in Redis CLI:

127.0.0.1:6379> SET session:abc123 "user data"
OK

127.0.0.1:6379> EXPIRE session:abc123 3600
(integer) 1

127.0.0.1:6379> TTL session:abc123
(integer) 3600

127.0.0.1:6379> TTL non_existent_key
(integer) -2

Examples in Python:

# Set a key with expiration
redis_db.set('session:abc123', 'user data')
redis_db.expire('session:abc123', 3600)

# Get the TTL of a key
print(redis_db.ttl('session:abc123'))  # Output: 3600

# Get TTL of a non-existent key
print(redis_db.ttl('non_existent_key'))  # Output: -2
 

Flushing the Database

  • FLUSHALL: Removes all keys from all databases.
  • FLUSHDB: Removes all keys from the current database.

Examples in Redis CLI:

127.0.0.1:6379> FLUSHDB
OK

127.0.0.1:6379> FLUSHALL
OK

Examples in Python:

# Flush the current database
redis_db.flushdb()

# Flush all databases
redis_db.flushall()
 

Exercise: 🏠 Build a Redis-Powered Twitter Clone

In this exercise, you will implement basic operations of a Twitter-like microblogging platform using Redis as a backend. The operations include:

  • Registering a new user: Provide a username and password and store them in Redis.
  • Logging in: Validate user credentials and generate an authentication token.
  • Following a user: Add a user to the following/follower lists.
  • Creating a new post: Insert a new message and propagate it to the followers’ newsfeeds and the global timeline.
  • Getting your newsfeed: Retrieve the 10 most recent posts from users you follow.
  • Getting the global timeline: Retrieve recent posts from all users, 10 at a time, possibly with pagination.
  • Logging out: Invalidate the user’s authentication token.

Data Layout

  • next_user_id and next_post_id: Integer counters for assigning unique IDs to new users and posts.
  • user:<user_id>: A hash storing username, password, and auth (token).
  • users: A hash mapping username to user_id.
  • followers:<user_id>: A sorted set of all followers of <user_id>, sorted by the time they started following.
  • following:<user_id>: A sorted set of all users that <user_id> follows, sorted by the time.
  • post:<post_id>: A hash storing user_id, username, time, and body.
  • timeline: A global capped list (trimmed to 1000 latest posts).
  • posts:<user_id>: A list of post IDs for the user’s personal newsfeed.
  • auths: A hash mapping tokens to user_ids.

Auth Logic

  • Register: Check if a user exists by username. If not, increment next_user_id, create user:<id>, and map username -> user_id.
  • Login: Given username and password, verify credentials. If correct, generate a token, store it in user:<id> and auths.
  • is_logged_in: Verify that the provided token matches the one stored for the user.
  • Logout: Remove the token from user:<id> and auths.
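
The Register rule above (check the username first, then allocate an ID) can be sketched against a tiny in-memory stand-in for Redis. Everything named here is an assumption for illustration: `FakeRedis` models only the four calls the sketch needs, and `UserAlreadyExists`, `register_checked`, and `generate_token` are hypothetical names. The token helper uses the standard library's `secrets` module, which is one reasonable choice for unguessable tokens.

```python
import secrets


class UserAlreadyExists(Exception):
    pass


class FakeRedis:
    """Minimal in-memory stand-in: only incr, hexists, hset and hget."""

    def __init__(self):
        self.counters = {}
        self.hashes = {}

    def incr(self, key):
        self.counters[key] = self.counters.get(key, 0) + 1
        return self.counters[key]

    def hexists(self, key, field):
        return field in self.hashes.get(key, {})

    def hset(self, key, field=None, value=None, mapping=None):
        target = self.hashes.setdefault(key, {})
        items = dict(mapping or {})
        if field is not None:
            items[field] = value
        added = sum(1 for f in items if f not in target)
        for f, v in items.items():
            target[f] = str(v)            # Redis stores values as strings
        return added                      # number of newly created fields

    def hget(self, key, field):
        return self.hashes.get(key, {}).get(field)


def register_checked(db, username: str, password: str) -> int:
    """Create a user unless the username is already taken."""
    if db.hexists('users', username):
        raise UserAlreadyExists(username)
    user_id = db.incr('next_user_id')
    db.hset(f'user:{user_id}', mapping={'username': username, 'password': password})
    db.hset('users', username, user_id)
    return user_id


def generate_token() -> str:
    """Return a 32-character random hex token."""
    return secrets.token_hex(16)


db = FakeRedis()
print(register_checked(db, 'jan', '1234'))   # 1
print(db.hget('users', 'jan'))               # 1
try:
    register_checked(db, 'jan', 'other')
except UserAlreadyExists:
    print('username taken')                  # username taken
```

Against a real server the check and the write are separate round trips, so two clients could still race for the same username; this sketch ignores that.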

Other Logic

  • Create Post: Ensure the user is authenticated. Create a new post:<id> hash. Push the post ID onto the author’s and all their followers’ posts:<id> lists, as well as onto the global timeline.
  • Follow: Ensure authentication. Update followers:<followed_id> and following:<follower_id> with timestamps.
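
The fan-out and the capped timeline described above can be modeled with ordinary Python lists. The helpers `lpush`, `ltrim`, and `lrange` below are hypothetical stand-ins that mimic the Redis list commands for non-negative indices only, and the `followers` and `feeds` dictionaries are illustrative data. The point to notice is that the stop index is inclusive, so keeping 1000 entries means trimming to the range 0..999.

```python
from collections import defaultdict

TIMELINE_CAP = 1000  # assumed cap, matching "1000 latest posts" in the data layout


def lpush(lst, value):
    """Insert at the head, like LPUSH; return the new length."""
    lst.insert(0, value)
    return len(lst)


def ltrim(lst, start, stop):
    """Keep only indices start..stop (stop inclusive), like LTRIM."""
    lst[:] = lst[start:stop + 1]


def lrange(lst, start, stop):
    """Return indices start..stop (stop inclusive), like LRANGE."""
    return lst[start:stop + 1]


followers = {"jan": ["ala"]}     # author -> list of followers
feeds = defaultdict(list)        # reader -> post IDs, newest first
timeline = []                    # global list of post IDs, newest first


def create_post(author, post_id):
    # Fan-out on write: the author and every follower get the post ID.
    for reader in [author, *followers.get(author, [])]:
        lpush(feeds[reader], post_id)
    lpush(timeline, post_id)
    ltrim(timeline, 0, TIMELINE_CAP - 1)


for post_id in range(1, 1006):
    create_post("jan", post_id)

print(len(timeline))                 # 1000
print(timeline[0], timeline[-1])     # 1005 6
print(lrange(timeline, 10, 19)[:3])  # [995, 994, 993]  (start of page 1)
print(feeds["ala"][:2])              # [1005, 1004]
```

Page `n` of size 10 corresponds to the range `n*10 .. n*10 + 9`, which is why the second page starts at index 10.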

CLI Usage Example

Commands such as register username password, login username password, post message, timeline, newsfeed, follow username, logout, and bye will interact with Redis through the provided CLI.

(Cmd) help register
Creates a new user:  REGISTER username password

(Cmd) register jan 1234
(Cmd) post asdf!
:-( Log in first

(Cmd) login jan 123456
:-( Invalid password

(Cmd) login jan 1234
(Cmd) post Hi, I'm Jan!
(Cmd) post And this is my second post!

(Cmd) timeline
Fri Jan 15 12:11:48 2021  by jan            : And this is my second post!
Fri Jan 15 12:11:45 2021  by jan            : Hi, I'm Jan!

(Cmd) register ala 1234
(Cmd) login ala 1234
(Cmd) post Hi, and I'm Ala!

(Cmd) timeline
Fri Jan 15 12:12:26 2021  by ala            : Hi, and I'm Ala!
Fri Jan 15 12:11:48 2021  by jan            : And this is my second post!
Fri Jan 15 12:11:45 2021  by jan            : Hi, I'm Jan!

(Cmd) newsfeed
Fri Jan 15 12:12:26 2021  by ala            : Hi, and I'm Ala!

(Cmd) follow jan
(Cmd) logout

(Cmd) register kamil 12345
(Cmd) login kamil 12345
(Cmd) post And here is Kamil

(Cmd) login ala 1234
(Cmd) newsfeed
Fri Jan 15 12:12:26 2021  by ala            : Hi, and I'm Ala!

(Cmd) login jan 1234
(Cmd) post Oh, I see, Ala, that you have started following me.

(Cmd) login ala 1234
(Cmd) newsfeed
Fri Jan 15 12:13:56 2021  by jan            : Oh, I see, Ala, that you have started following me.
Fri Jan 15 12:12:26 2021  by ala            : Hi, and I'm Ala!

(Cmd) timeline
Fri Jan 15 12:13:56 2021  by jan            : Oh, I see, Ala, that you have started following me.
Fri Jan 15 12:13:17 2021  by kamil          : And here is Kamil
Fri Jan 15 12:12:26 2021  by ala            : Hi, and I'm Ala!
Fri Jan 15 12:11:48 2021  by jan            : And this is my second post!
Fri Jan 15 12:11:45 2021  by jan            : Hi, I'm Jan!

(Cmd) bye
Bye!

Initial Code

You are given the following starting point:

  • twitter_cli.py provides a command-line interface using Python’s cmd module.
  • twitter.py provides data structures and some helper functions but leaves the core logic (login, require_login, follow, create_post, get_newsfeed, get_timeline, logout) unimplemented.

Your task: Fill in the NotImplementedError sections in twitter.py.

twitter_cli.py (Initial Code)

# twitter_cli.py
import cmd
from datetime import datetime
import redis
from typing import Optional

import twitter


class TwitterCLI(cmd.Cmd):
    def __init__(self, db: redis.Redis):
        self.db = db
        self.token: Optional[twitter.Token] = None
        self.user_id: Optional[twitter.UserId] = None
        super().__init__()

    def do_register(self, arg: str):
        """Creates a new user:  REGISTER username password """

        try:
            username, password = arg.split()
        except ValueError:
            print("Invalid usage. Use HELP REGISTER")
        else:
            twitter.register(self.db, username, password)

    def do_login(self, arg: str):
        """Logs in as a user:  LOGIN username password """

        try:
            username, password = arg.split()
        except ValueError:
            print("Invalid usage. Use HELP LOGIN")
        else:
            try:
                self.token, self.user_id = twitter.login(self.db, username, password)
            except twitter.NoSuchUser:
                print(":-( No such user")
            except twitter.InvalidPassword:
                print(":-( Invalid password")

    def do_follow(self, username: str):
        """Starts following a user:  FOLLOW obama """

        if self.token is None:
            print(f":-( Log in first")
        else:
            try:
                twitter.follow(self.db, self.user_id, self.token, username)
            except twitter.NoSuchUser:
                print(f":-( Looks like there is no {username}")
            except twitter.NotAuthenticated:
                print(f":-( Auth error -- login again")

    def do_post(self, body: str):
        """Posts a new update as the logged in user:  POST This is your message! """

        if self.token is None:
            print(f":-( Log in first")
        else:
            try:
                twitter.create_post(self.db, self.user_id, self.token, body)
            except twitter.NotAuthenticated:
                print(":-( Auth error - login again")

    def do_newsfeed(self, body: str):
        """Lists 10 most recent posts on your timeline:  NEWSFEED """

        if self.token is None:
            print(f":-( Log in first")
        else:
            try:
                posts = twitter.get_newsfeed(self.db, self.user_id, self.token)
            except twitter.NotAuthenticated:
                print(":-( Auth error - login again")
            else:
                for post in posts:
                    date = datetime.fromtimestamp(post.time)
                    print(f"{date.ctime():25} by {post.username:15}: {post.body}")

    def do_timeline(self, body: str):
        """Lists 10 most recent posts on the global timeline:  TIMELINE.

        You can specify the page, so that you can paginate to next 10 posts:  TIMELINE 1 """

        try:
            page = int(body.strip())
        except ValueError:
            page = 0

        if self.token is None:
            print(f":-( Log in first")
        else:
            posts = twitter.get_timeline(self.db, offset=page*10)
            for post in posts:
                date = datetime.fromtimestamp(post.time)
                print(f"{date.ctime():25} by {post.username:15}: {post.body}")

    def do_logout(self, body: str):
        """Logs you out. """

        if self.token is None:
            print(":-( First log in ")
        else:
            try:
                twitter.logout(self.db, self.user_id, self.token)
            except twitter.NotAuthenticated:
                print(":-( Auth error - login again, then logout")

            self.token = None
            self.user_id = None

    def do_bye(self, arg):
        """Exits """

        print("Bye!")
        return True


if __name__ == "__main__":
    redis_db = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
    cli = TwitterCLI(redis_db)
    cli.cmdloop()

twitter.py (Initial Code - You must complete the logic)

# twitter.py - Initial Code
from dataclasses import dataclass, asdict
from datetime import datetime
import random
from typing import NewType, Tuple, List

import redis


TimeStamp = NewType('TimeStamp', int)
Token = NewType('Token', str)
UserId = NewType('UserId', int)
PostId = NewType('PostId', int)


class NoSuchUser(Exception):
    pass

class InvalidPassword(Exception):
    pass

class NotAuthenticated(Exception):
    pass


@dataclass
class User:
    username: str
    password: str

@dataclass
class Post:
    user_id: UserId
    username: str
    time: TimeStamp
    body: str


def current_timestamp() -> int:
    return int(datetime.now().timestamp())

def generate_random_token() -> Token:
    return str(random.randint(0, 10000000000000))


def register(db: redis.Redis, username: str, password: str) -> None:
    user_id = db.incr('next_user_id')
    user = User(username=username, password=password)
    db.hset(f'user:{user_id}', mapping=asdict(user))
    db.hset('users', username, user_id)


def login(db: redis.Redis, username: str, password: str) -> Tuple[Token, UserId]:
    """ May raise NoSuchUser or InvalidPassword. """
    raise NotImplementedError

def require_login(db: redis.Redis, user_id: UserId, token: Token) -> None:
    """ May raise NotAuthenticated. """
    raise NotImplementedError

def follow(db: redis.Redis, follower_id: UserId, token: Token, followed_username: str) -> None:
    """ May raise NoSuchUser or NotAuthenticated """
    require_login(db, follower_id, token)
    raise NotImplementedError

def create_post(db: redis.Redis, user_id: UserId, token: Token, body: str) -> None:
    """ May raise NotAuthenticated """
    require_login(db, user_id, token)
    raise NotImplementedError

def get_newsfeed(db: redis.Redis, user_id: UserId, token: Token) -> List[Post]:
    require_login(db, user_id, token)
    raise NotImplementedError

def get_timeline(db: redis.Redis, offset: int = 0, limit: int = 10) -> List[Post]:
    raise NotImplementedError

def logout(db: redis.Redis, user_id: UserId, token: Token) -> None:
    """ May raise NotAuthenticated. """
    raise NotImplementedError

Solution

from dataclasses import dataclass, asdict
from datetime import datetime
import random
from typing import NewType, Tuple, List

import redis


TimeStamp = NewType('TimeStamp', int)
Token = NewType('Token', str)
UserId = NewType('UserId', int)
PostId = NewType('PostId', int)

class NoSuchUser(Exception):
    pass

class InvalidPassword(Exception):
    pass

class NotAuthenticated(Exception):
    pass


@dataclass
class User:
    username: str
    password: str

@dataclass
class Post:
    user_id: UserId
    username: str
    time: TimeStamp
    body: str


def current_timestamp() -> int:
    return int(datetime.now().timestamp())

def generate_random_token() -> Token:
    return str(random.randint(0, 10000000000000))

def register(db: redis.Redis, username: str, password: str) -> None:
    user_id = db.incr('next_user_id')
    user = User(username=username, password=password)
    db.hset(f'user:{user_id}', mapping=asdict(user))
    db.hset('users', username, user_id)

def login(db: redis.Redis, username: str, password: str) -> Tuple[Token, UserId]:
    user_id = db.hget('users', username)
    if user_id is None:
        raise NoSuchUser
    correct_password = db.hget(f'user:{user_id}', 'password')
    if password != correct_password:
        raise InvalidPassword
    token = generate_random_token()
    db.hset(f'user:{user_id}', 'auth', token)
    db.hset('auths', token, user_id)
    return token, user_id

def require_login(db: redis.Redis, user_id: UserId, token: Token) -> None:
    stored_user_id = db.hget('auths', token)
    if stored_user_id is None:
        raise NotAuthenticated
    real_auth = db.hget(f'user:{user_id}', 'auth')
    if real_auth != token:
        raise NotAuthenticated

def follow(db: redis.Redis, follower_id: UserId, token: Token, followed_username: str) -> None:
    require_login(db, follower_id, token)
    followed_user_id = db.hget('users', followed_username)
    if followed_user_id is None:
        raise NoSuchUser
    now = current_timestamp()
    db.zadd(f'followers:{followed_user_id}', {follower_id: now})
    db.zadd(f'following:{follower_id}', {followed_user_id: now})

def create_post(db: redis.Redis, user_id: UserId, token: Token, body: str) -> None:
    require_login(db, user_id, token)
    post_id = db.incr('next_post_id')
    username = db.hget(f'user:{user_id}', 'username')
    post = Post(
        user_id=int(user_id),
        username=username,
        time=current_timestamp(),
        body=body,
    )
    db.hset(f'post:{post_id}', mapping=asdict(post))
    # Push to all followers
    followers = db.zrange(f'followers:{user_id}', 0, -1)
    for fid in followers:
        db.lpush(f'posts:{fid}', post_id)
    # Also push to author's own feed
    db.lpush(f'posts:{user_id}', post_id)
    # Push to global timeline
    db.lpush('timeline', post_id)
    # LTRIM's stop index is inclusive, so 0..999 keeps the 1000 latest posts
    db.ltrim('timeline', 0, 999)

def _get_posts_by_ids(db: redis.Redis, post_ids: List[PostId]) -> List[Post]:
    posts = []
    for pid in post_ids:
        data = db.hgetall(f'post:{pid}')
        posts.append(Post(
            user_id=int(data['user_id']),
            username=data['username'],
            time=int(data['time']),
            body=data['body'],
        ))
    return posts

def get_newsfeed(db: redis.Redis, user_id: UserId, token: Token) -> List[Post]:
    require_login(db, user_id, token)
    post_ids = db.lrange(f'posts:{user_id}', 0, 9) # 10 posts
    return _get_posts_by_ids(db, post_ids)

def get_timeline(db: redis.Redis, offset: int = 0, limit: int = 10) -> List[Post]:
    post_ids = db.lrange('timeline', offset, offset+limit-1)
    return _get_posts_by_ids(db, post_ids)

def logout(db: redis.Redis, user_id: UserId, token: Token) -> None:
    require_login(db, user_id, token)
    db.hdel(f'user:{user_id}', 'auth')
    db.hdel('auths', token)

Publish - Subscribe

Redis provides a built-in Publish/Subscribe (Pub/Sub) messaging paradigm. This allows you to broadcast messages to multiple subscribers listening on different channels. Subscribers can listen for new messages as they arrive, enabling real-time notifications, chat systems, streaming logs, and more.
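
The broadcast behaviour can be modeled in a few lines of plain Python before touching a real server. `MiniBroker` and the module-level `get_message` helper are hypothetical names for this sketch; they only imitate the delivery rules assumed here: a published message is copied to every current subscriber of the channel, `publish` reports how many subscribers received it, and a message published while nobody is subscribed is simply dropped.

```python
from collections import defaultdict, deque


class MiniBroker:
    """In-memory model of channel-based fan-out (illustrative only)."""

    def __init__(self):
        self._subscribers = defaultdict(list)   # channel -> list of inboxes

    def subscribe(self, channel):
        """Register a new subscriber and return its private inbox."""
        inbox = deque()
        self._subscribers[channel].append(inbox)
        return inbox

    def publish(self, channel, data):
        """Deliver to all current subscribers; return how many received it."""
        inboxes = self._subscribers.get(channel, [])
        for inbox in inboxes:
            inbox.append({'type': 'message', 'channel': channel, 'data': data})
        return len(inboxes)


def get_message(inbox):
    """Non-blocking read: the next message, or None if the inbox is empty."""
    return inbox.popleft() if inbox else None


broker = MiniBroker()
print(broker.publish('mychannel', 'nobody hears this'))  # 0

first = broker.subscribe('mychannel')
second = broker.subscribe('mychannel')
print(broker.publish('mychannel', 'hello'))              # 2

print(get_message(first)['data'])    # hello
print(get_message(second)['data'])   # hello
print(get_message(first))            # None
```

The first message is never delivered because it was published before anyone subscribed, which is the main difference between this model and a persistent queue.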

Python: get_message

Below is an example of using Redis Pub/Sub with Python and the redis-py library. We first create a Redis connection and subscribe to a channel. By calling get_message(), we can fetch the next available message in a non-blocking manner.

import redis

r = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p = r.pubsub()
p.subscribe('mychannel')
# Initially, we get a subscribe message:
msg = p.get_message()
print(msg)
{'type': 'subscribe', 'pattern': None, 'channel': 'mychannel', 'data': 1}
print(p.get_message())  # Non-blocking attempt to get another message
# None (since no new messages are available)
None
# Publish a message from another client:
r2 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
r2.publish('mychannel', 'my message content')
4
# Now, retrieving the message:
msg = p.get_message()
print(msg)
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'my message content'}

Python: listen

If you want to listen continuously to incoming messages, you can use the listen() method. This method blocks and yields new messages as soon as they arrive.

r3 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p3 = r3.pubsub()
p3.subscribe('mychannel')

for msg in p3.listen():
    print(msg)
{'type': 'subscribe', 'pattern': None, 'channel': 'mychannel', 'data': 1}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message'}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message 1'}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message 2'}
{'type': 'message', 'pattern': None, 'channel': 'mychannel', 'data': 'message 2'}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[6], line 5
      2 p3 = r3.pubsub()
      3 p3.subscribe('mychannel')
----> 5 for msg in p3.listen():
      6     print(msg)

KeyboardInterrupt: 

Note: listen() is blocking. To stop it, you’ll need to interrupt the execution (e.g., KeyboardInterrupt). If you need to do other work between messages, poll with get_message(timeout=...) instead.

Python: Registered Callbacks

When subscribing, you can also register callbacks for specific channels. Messages published to those channels will trigger the callback function.

def callback(msg):
    print('inside callback', msg)

r4 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p4 = r4.pubsub()

# Subscribe to 'myanother' and register callback() as the handler for that channel
p4.subscribe('myanother', myanother=callback)

for msg in p4.listen():
    print('inside listen', msg)
inside listen {'type': 'subscribe', 'pattern': None, 'channel': 'myanother', 'data': 1}
inside callback {'type': 'message', 'pattern': None, 'channel': 'myanother', 'data': 'message 1'}
inside callback {'type': 'message', 'pattern': None, 'channel': 'myanother', 'data': 'message 2'}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[15], line 10
      7 # Subscribe to 'myanother' and assign a callback to 'mychannel'
      8 p4.subscribe('myanother', myanother=callback)
---> 10 for msg in p4.listen():
     11     print('inside listen', msg)

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/client.py:1026, in PubSub.listen(self)
   1024 "Listen for messages on channels this client has been subscribed to"
   1025 while self.subscribed:
-> 1026     response = self.handle_message(self.parse_response(block=True))
   1027     if response is not None:
   1028         yield response

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/client.py:865, in PubSub.parse_response(self, block, timeout)
    862         conn.connect()
    863     return conn.read_response(disconnect_on_error=False, push_request=True)
--> 865 response = self._execute(conn, try_read)
    867 if self.is_health_check_response(response):
    868     # ignore the health check message as user might not expect it
    869     self.health_check_response_counter -= 1

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/client.py:841, in PubSub._execute(self, conn, command, *args, **kwargs)
    833 def _execute(self, conn, command, *args, **kwargs):
    834     """
    835     Connect manually upon disconnection. If the Redis server is down,
    836     this will fail and raise a ConnectionError as desired.
   (...)
    839     patterns we were previously listening to
    840     """
--> 841     return conn.retry.call_with_retry(
    842         lambda: command(*args, **kwargs),
    843         lambda error: self._disconnect_raise_connect(conn, error),
    844     )

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/retry.py:62, in Retry.call_with_retry(self, do, fail)
     60 while True:
     61     try:
---> 62         return do()
     63     except self._supported_errors as error:
     64         failures += 1

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/client.py:842, in PubSub._execute.<locals>.<lambda>()
    833 def _execute(self, conn, command, *args, **kwargs):
    834     """
    835     Connect manually upon disconnection. If the Redis server is down,
    836     this will fail and raise a ConnectionError as desired.
   (...)
    839     patterns we were previously listening to
    840     """
    841     return conn.retry.call_with_retry(
--> 842         lambda: command(*args, **kwargs),
    843         lambda error: self._disconnect_raise_connect(conn, error),
    844     )

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/client.py:863, in PubSub.parse_response.<locals>.try_read()
    861 else:
    862     conn.connect()
--> 863 return conn.read_response(disconnect_on_error=False, push_request=True)

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/connection.py:592, in AbstractConnection.read_response(self, disable_decoding, disconnect_on_error, push_request)
    588         response = self._parser.read_response(
    589             disable_decoding=disable_decoding, push_request=push_request
    590         )
    591     else:
--> 592         response = self._parser.read_response(disable_decoding=disable_decoding)
    593 except socket.timeout:
    594     if disconnect_on_error:

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/_parsers/resp2.py:15, in _RESP2Parser.read_response(self, disable_decoding)
     13 pos = self._buffer.get_pos() if self._buffer else None
     14 try:
---> 15     result = self._read_response(disable_decoding=disable_decoding)
     16 except BaseException:
     17     if self._buffer:

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/_parsers/resp2.py:25, in _RESP2Parser._read_response(self, disable_decoding)
     24 def _read_response(self, disable_decoding=False):
---> 25     raw = self._buffer.readline()
     26     if not raw:
     27         raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/_parsers/socket.py:115, in SocketBuffer.readline(self)
    112 data = buf.readline()
    113 while not data.endswith(SYM_CRLF):
    114     # there's more data in the socket that we need
--> 115     self._read_from_socket()
    116     data += buf.readline()
    118 return data[:-2]

File ~/python_training/testing/.venv/lib/python3.12/site-packages/redis/_parsers/socket.py:65, in SocketBuffer._read_from_socket(self, length, timeout, raise_on_timeout)
     63 try:
     64     while True:
---> 65         data = self._sock.recv(socket_read_size)
     66         # an empty string indicates the server shutdown the socket
     67         if isinstance(data, bytes) and len(data) == 0:

KeyboardInterrupt: 

If the callback is long-running, it will delay processing subsequent messages. You can introduce a delay and observe the blocking behavior.

Python: run_in_thread

To avoid blocking the main thread, you can use run_in_thread() to handle messages in a separate thread. This makes your main code non-blocking while still processing incoming messages asynchronously.

from time import sleep

def callback(msg):
    print('inside callback', msg)
    sleep(5)  # Simulate a long-running task

r5 = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
p5 = r5.pubsub()
p5.subscribe(mychannel=callback)

# run_in_thread() will start a separate thread to listen to messages
thread = p5.run_in_thread()
print('This line executes immediately because run_in_thread is non-blocking.')
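The worker-thread pattern behind run_in_thread() can be sketched without a Redis server. The following is only a simulation using the standard library: a queue.Queue stands in for the subscribed channel, and the names slow_callback, run_worker, and inbox are hypothetical. It shows that a slow handler running in its own thread does not hold up the code that hands messages over. In redis-py itself, the thread object returned by run_in_thread() has a stop() method for shutting the listener down.

```python
import queue
import threading
import time


def slow_callback(msg, handled):
    time.sleep(0.2)  # stand-in for a long-running handler
    handled.append(msg)


def run_worker(inbox, handled):
    """Start a daemon thread that passes each queued message to slow_callback."""
    def loop():
        while True:
            msg = inbox.get()
            if msg is None:  # sentinel value: stop the worker
                break
            slow_callback(msg, handled)

    worker = threading.Thread(target=loop, daemon=True)
    worker.start()
    return worker


inbox = queue.Queue()  # simulates the channel; not a Redis object
handled = []
worker = run_worker(inbox, handled)

start = time.perf_counter()
for text in ("message 1", "message 2"):
    inbox.put(text)  # returns immediately, the worker handles it later
elapsed_publishing = time.perf_counter() - start

inbox.put(None)  # ask the worker to finish
worker.join()    # wait until both messages are processed
```

Handing over both messages takes far less time than a single callback run, while the messages are still processed in order by the worker.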


Working with FastAPI

In this chapter, we'll explore the core concepts of building APIs with FastAPI, a modern, high-performance Python web framework based on standard type hints. By the end of this chapter, you'll understand how to define API routes, handle parameters and request bodies, validate data, and return responses in a clean, efficient, and Pythonic way.

First Steps

Objective: Set up a basic FastAPI application and understand the fundamental building blocks.

  • Installation:
    pip install fastapi uvicorn
    
  • Hello World Example:

    from fastapi import FastAPI
    
    app = FastAPI()
    
    @app.get("/")
    def read_root():
        return {"message": "Hello World"}
    
  • Running the Server:
    uvicorn fast:app --reload  # "fast" is the module name (the server lives in fast.py); use main:app if your file is main.py
    

By visiting http://127.0.0.1:8000/ in your browser, you'll see your API in action. FastAPI also generates interactive documentation at http://127.0.0.1:8000/docs.

Path Parameters

Objective: Learn how to define dynamic URL paths.

  • Defining Path Parameters:
    @app.get("/items/{item_id}")
    def read_item(item_id: int):
        return {"item_id": item_id}
    
  • Type Hints for Validation:
    Specifying item_id: int automatically converts and validates the parameter.

Query Parameters

Objective: Retrieve optional parameters from the query string.

  • Defining Query Parameters:
    @app.get("/items")
    def read_items(skip: int = 0, limit: int = 10):
        return {"skip": skip, "limit": limit}
    
  • Optional Parameters and Defaults:
    If not provided, skip defaults to 0 and limit to 10.

Request Body

Objective: Send JSON data in the request body and parse it into Python objects.

  • Pydantic Models:

    from pydantic import BaseModel
    
    class Item(BaseModel):
        name: str
        description: str | None = None
    
    @app.post("/items/")
    def create_item(item: Item):
        return {"item": item}
    
  • Automatic Data Parsing and Validation:
    FastAPI validates request data and converts it into Item instances.

Query Parameters and String Validations

Objective: Add validations to query parameters, such as length constraints.

  • String Validations:

    from fastapi import Query
    
    @app.get("/users")
    def read_users(q: str | None = Query(None, min_length=3, max_length=50)):
        return {"q": q}
    
  • Using Query for Metadata and Validation: You can set default values, add pattern (regular expression) validations, and provide descriptions.

Path Parameters and Numeric Validations

Objective: Validate numeric path parameters with constraints like minimum and maximum values.

  • Integer Validations:

    from fastapi import Path
    
    @app.get("/items/{item_id}")
    def read_items(item_id: int = Path(..., ge=1, le=1000)):
        return {"item_id": item_id}
    

Query Parameter Models

Requires FastAPI 0.115.0 or later.

Objective: Use pydantic models to handle complex query parameters.

  • Complex Query Parameters:

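    from typing import Annotated
    
    from fastapi import Query
    from pydantic import BaseModel
    
    class UserFilter(BaseModel):
        name: str
        age: int
    
    @app.get("/search")
    # Without Query(), FastAPI would read the model from the request body
    def search_users(filters: Annotated[UserFilter, Query()]):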
        return {"filters": filters}
    
  • Dependency Injection for Parsing:
    You can also use dependencies to parse complex query parameters into models.

Body - Multiple Parameters

Objective: Combine query parameters, path parameters, and request bodies in a single endpoint.

  • Mixing Parameters:
    @app.post("/orders/{order_id}")
    def update_order(order_id: int, item: Item, confirm: bool = True):
        return {"order_id": order_id, "item": item, "confirm": confirm}
    

Body - Fields

Objective: Add metadata and validations to individual fields in a request body model.

  • Using Field:

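    from pydantic import BaseModel, Field
    
    class Product(BaseModel):
        # Pydantic v2 expects a list under "examples"; the "example" keyword is deprecated
        name: str = Field(..., examples=["Laptop"])
        price: float = Field(..., gt=0, examples=[999.99])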
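Field constraints are enforced by Pydantic itself, so they can be tried without starting a server. The sketch below uses the examples=[...] keyword (the Pydantic v2 spelling) and collects the location of the failing field; the names valid and error_fields are illustrative.

```python
from pydantic import BaseModel, Field, ValidationError


class Product(BaseModel):
    name: str = Field(..., examples=["Laptop"])
    price: float = Field(..., gt=0, examples=[999.99])


valid = Product(name="Laptop", price=999.99)

try:
    Product(name="Laptop", price=0)  # violates gt=0
    error_fields = []
except ValidationError as exc:
    # each error records which field failed, as a tuple of path segments
    error_fields = [err["loc"] for err in exc.errors()]
```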
    

Body - Nested Models

Objective: Create nested data structures and validate them using Pydantic models.

  • Nested Models:

    class Manufacturer(BaseModel):
        name: str
        country: str
    
    class Product(BaseModel):
        name: str
        price: float
        manufacturer: Manufacturer
    
    @app.post("/products/")
    def create_product(product: Product):
        return {"product": product}
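Nested validation can be checked directly on the models. This sketch assumes Pydantic v2 (model_validate) and shows that a nested dict becomes a Manufacturer instance and that an error inside the nested object is reported with its full path. The sample payload values are made up.

```python
from pydantic import BaseModel, ValidationError


class Manufacturer(BaseModel):
    name: str
    country: str


class Product(BaseModel):
    name: str
    price: float
    manufacturer: Manufacturer


payload = {
    "name": "Laptop",
    "price": 999.99,
    "manufacturer": {"name": "Acme", "country": "PL"},
}
product = Product.model_validate(payload)

try:
    # "country" is missing inside the nested object
    Product.model_validate(
        {"name": "Laptop", "price": 999.99, "manufacturer": {"name": "Acme"}}
    )
    error_locs = []
except ValidationError as exc:
    error_locs = [err["loc"] for err in exc.errors()]
```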
    

Declare Request Example Data

Objective: Provide examples for request bodies to display in the interactive docs.

  • Examples in Pydantic Models:

    class Item(BaseModel):
        name: str
        description: str | None = None
    
        model_config = {
            "json_schema_extra": {
                "examples": [
                    {
                        "name": "Example Item",
                        "description": "A sample item.",
                    }
                ]
            }
        }
    

Header Parameters

Objective: Extract headers from the request.

  • Header Parameter:

    from fastapi import Header
    
    @app.get("/headers")
    def read_headers(user_agent: str | None = Header(None)):
        return {"User-Agent": user_agent}

Response Model - Return Type

Objective: Define the response model to ensure consistent response formats.

  • Response Model:
    # User is assumed to be a Pydantic model with id and name fields
    @app.get("/users/{user_id}", response_model=User)
    def get_user(user_id: int):
        return User(id=user_id, name="John Doe")
    

FastAPI automatically filters and validates the response.

Response Status Code

Objective: Set HTTP status codes for your responses.

  • Specifying Status Codes:

    from fastapi import status
    
    @app.post("/items/", status_code=status.HTTP_201_CREATED)
    def create_item(item: Item):
        return item

Unit Tests

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/")
async def read_main():
    return {"msg": "Hello World"}


client = TestClient(app)


def test_read_main():
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"msg": "Hello World"}